Unverified Commit 18962129 authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
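Most of the type-hint changes below follow one pattern: parameters that receive a class to be instantiated later (criterion, optimizer) are annotated with Type[...] instead of the instance type. A minimal sketch of the distinction; the build_criterion helper is illustrative, not part of the commit:

from typing import Type
import torch.nn as nn

def build_criterion(criterion_cls: Type[nn.Module]) -> nn.Module:
    # The caller passes a class such as nn.CrossEntropyLoss; Type[nn.Module]
    # tells the checker that calling it yields an nn.Module instance.
    return criterion_cls()

loss = build_criterion(nn.CrossEntropyLoss)    # accepted: a class
# build_criterion(nn.CrossEntropyLoss())      # rejected by checkers: an instance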
......@@ -4,10 +4,11 @@
import os
import warnings
from pathlib import Path
from typing import Dict, Union, Optional, List, Callable
from typing import Dict, Union, Optional, List, Callable, Type
import pytorch_lightning as pl
import torch.nn as nn
import torch.nn.functional as nn_functional
import torch.optim as optim
import torchmetrics
import torch.utils.data as torch_data
......@@ -124,12 +125,12 @@ class Lightning(Evaluator):
if other is None:
return False
if hasattr(self, "function") and hasattr(other, "function"):
eq_func = (self.function == other.function)
eq_func = getattr(self, "function") == getattr(other, "function")
elif not (hasattr(self, "function") or hasattr(other, "function")):
eq_func = True
if hasattr(self, "arguments") and hasattr(other, "arguments"):
eq_args = (self.arguments == other.arguments)
eq_args = getattr(self, "arguments") == getattr(other, "arguments")
elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
eq_args = True
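# Sketch of why getattr() replaces the direct attribute access above: many
# static checkers do not narrow types from a hasattr() guard on attributes
# that are not declared on the class, so self.function would still be
# flagged inside the guard, while getattr(self, "function") is always
# well-typed (it returns Any). Hypothetical minimal reproduction:
#
#     if hasattr(obj, "function"):
#         obj.function              # flagged: attribute not declared
#         getattr(obj, "function")  # accepted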
......@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
### The following are some commonly used Lightning modules ###
class _SupervisedLearningModule(LightningModule):
def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
trainer: pl.Trainer
def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: Union[Path, str, bool, None] = None):
super().__init__()
self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
......@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
self.log('test_' + name, metric(y_hat, y), prog_bar=True)
def configure_optimizers(self):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
nni.report_intermediate_result(self._get_validation_metrics())
......@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
class _AccuracyWithLogits(torchmetrics.Accuracy):
def update(self, pred, target):
return super().update(nn.functional.softmax(pred), target)
return super().update(nn_functional.softmax(pred), target)
@nni.trace
class _ClassificationModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'acc': _AccuracyWithLogits},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
......@@ -275,10 +279,10 @@ class Classification(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
......@@ -291,10 +295,10 @@ class Classification(Lightning):
@nni.trace
class _RegressionModule(_SupervisedLearningModule):
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
export_onnx: bool = True):
super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
......@@ -328,10 +332,10 @@ class Regression(Lightning):
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, criterion: nn.Module = nn.MSELoss,
def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
learning_rate: float = 0.001,
weight_decay: float = 0.,
optimizer: optim.Optimizer = optim.Adam,
optimizer: Type[optim.Optimizer] = optim.Adam,
train_dataloader: Optional[DataLoader] = None,
val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
export_onnx: bool = True,
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .api import *
......@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation_summary = get_mutation_summary(model)
assert model.evaluator is not None, 'Model evaluator can not be None'
return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)
@classmethod
......
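The assert added above is the usual way to narrow an Optional attribute for the type checker while failing fast at runtime. A minimal sketch of the pattern, with a hypothetical pack helper:

from typing import Optional

def pack(evaluator: Optional[str]) -> str:
    # Before the assert, evaluator is Optional[str] and .upper() is flagged;
    # after it, the checker treats evaluator as str, and a violated invariant
    # raises AssertionError here instead of an obscure error later.
    assert evaluator is not None, 'evaluator must be set before packing'
    return evaluator.upper()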
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import os
import random
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast
from ..graph import Model
from ..integration_api import receive_trial_parameters
......@@ -39,6 +42,9 @@ class BenchmarkGraphData:
def load(data) -> 'BenchmarkGraphData':
return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])
def __repr__(self) -> str:
return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
class BenchmarkExecutionEngine(BaseExecutionEngine):
"""
......@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
@classmethod
def trial_execute_graph(cls) -> None:
graph_data = BenchmarkGraphData.load(receive_trial_parameters())
assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
final, intermediates = cls.query_in_benchmark(graph_data)
......@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
arch = t
if arch is None:
raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
print(arch)
return _convert_to_final_and_intermediates(
query_nb101_trial_stats(arch, 108, include_intermediates=True),
'valid_acc'
......@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
benchmark_result = random.choice(benchmark_result)
else:
benchmark_result = benchmark_result[0]
benchmark_result = cast(dict, benchmark_result)
return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
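The cast(dict, benchmark_result) above is purely a static-typing hint. A short sketch of its runtime semantics:

from typing import cast

raw: object = {'valid_acc': 0.94, 'intermediates': []}
# cast() returns its argument unchanged and performs no runtime checking;
# it only instructs the type checker to treat raw as a dict, so an
# incorrect cast defers errors to the first real use instead of raising.
result = cast(dict, raw)
print(result['valid_acc'])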
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import random
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Dict, Any, Type
import torch.nn as nn
......@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
@classmethod
def pack_model_data(cls, model: Model) -> Any:
mutation = get_mutation_dict(model)
graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
assert model.evaluator is not None, 'Model evaluator is not available.'
graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
return graph_data
@classmethod
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, List
from ..graph import Model
......
......@@ -11,7 +11,7 @@ from dataclasses import dataclass
from pathlib import Path
from subprocess import Popen
from threading import Thread
from typing import Any, List, Optional, Union
from typing import Any, List, Optional, Union, cast
import colorama
import psutil
......@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
from nni.experiment.config import utils
from nni.experiment.config.base import ConfigBase
from nni.experiment.config.training_service import TrainingServiceConfig
from nni.experiment.config.training_services import RemoteConfig
from nni.experiment.pipe import Pipe
from nni.tools.nnictl.command_utils import kill_command
......@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
Examples
--------
Multi-trial NAS:
>>> base_model = Net()
>>> search_strategy = strategy.Random()
>>> model_evaluator = FunctionalEvaluator(evaluate_model)
......@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config, 8081)
One-shot NAS:
>>> base_model = Net()
>>> search_strategy = strategy.DARTS()
>>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
......@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
>>> exp.run(exp_config)
Export top models:
>>> for model_dict in exp.export_top_models(formatter='dict'):
... print(model_dict)
>>> with nni.retiarii.fixed_arch(model_dict):
... final_model = Net()
"""
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
trainer: BaseOneShotTrainer = None):
def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
if trainer is not None:
warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
......@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
raise ValueError('Evaluator should not be None.')
# TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
self.config: RetiariiExeConfig = None
self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
self.port: Optional[int] = None
self.base_model = base_model
self.evaluator: Evaluator = evaluator
self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
self.applied_mutators = applied_mutators
self.strategy = strategy
# FIXME: this is only a workaround
from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
if not isinstance(strategy, OneShotStrategy):
self._dispatcher = RetiariiAdvisor()
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
else:
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread: Optional[Thread] = None
self._proc: Optional[Popen] = None
self._pipe: Optional[Pipe] = None
self.url_prefix = None
......@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
assert self.config.training_service.platform == 'remote', \
"CGO execution engine currently only supports remote training service"
assert self.config.batch_waiting_time is not None
assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
devices = self._construct_devices()
engine = CGOExecutionEngine(devices,
max_concurrency=self.config.max_concurrency_cgo,
......@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
engine = PurePythonExecutionEngine()
elif self.config.execution_engine == 'benchmark':
from ..execution.benchmark import BenchmarkExecutionEngine
assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
engine = BenchmarkExecutionEngine(self.config.benchmark)
else:
raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
set_execution_engine(engine)
self.id = management.generate_experiment_id()
......@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
def _construct_devices(self):
devices = []
if hasattr(self.config.training_service, 'machine_list'):
for machine in self.config.training_service.machine_list:
for machine in cast(RemoteConfig, self.config.training_service).machine_list:
assert machine.gpu_indices is not None, \
'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
for gpu_idx in machine.gpu_indices:
devices.append(GPUDevice(machine.host, gpu_idx))
return devices
......@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
def _create_dispatcher(self):
return self._dispatcher
def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
"""
Run the experiment.
This function will block until experiment finish or error.
......@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
This function will block until experiment finish or error.
Return `True` when experiment done; or return `False` when experiment failed.
"""
assert self._proc is not None
try:
while True:
time.sleep(10)
......@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
_logger.warning('KeyboardInterrupt detected')
finally:
self.stop()
raise RuntimeError('Check experiment status failed.')
def stop(self) -> None:
"""
......@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
if self._pipe is not None:
self._pipe.close()
self.id = None
self.port = None
self.id = cast(str, None)
self.port = cast(int, None)
self._proc = None
self._pipe = None
self._dispatcher = None
self._dispatcher = cast(RetiariiAdvisor, None)
self._dispatcher_thread = None
_logger.info('Experiment stopped')
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from pathlib import Path
......
......@@ -5,10 +5,16 @@
Model representation.
"""
from __future__ import annotations
import abc
import json
from enum import Enum
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
Optional, Set, Tuple, Type, Union, cast, overload)
if TYPE_CHECKING:
from .mutator import Mutator
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import uid
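The TYPE_CHECKING guard added above keeps an annotation-only import out of the runtime import graph, avoiding a circular import between the graph and mutator modules. Sketched in isolation:

from __future__ import annotations  # annotations stay unevaluated strings
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Executed only by static analyzers, never at runtime, so this import
    # cannot participate in an import cycle.
    from .mutator import Mutator

def apply_mutator(mutator: Mutator) -> None:
    ...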
......@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
pass
@abc.abstractmethod
def _execute(self, model_cls: type) -> Any:
def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
pass
@abc.abstractmethod
......@@ -203,7 +209,7 @@ class Model:
matched_nodes.extend(nodes)
return matched_nodes
def get_node_by_name(self, node_name: str) -> 'Node':
def get_node_by_name(self, node_name: str) -> 'Node' | None:
"""
Traverse all the nodes to find the matched node with the given name.
"""
......@@ -217,7 +223,7 @@ class Model:
else:
return None
def get_node_by_python_name(self, python_name: str) -> 'Node':
def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
"""
Traverse all the nodes to find the matched node with the given python_name.
"""
......@@ -297,7 +303,7 @@ class Graph:
The name of torch.nn.Module, should have one-to-one mapping with items in python model.
"""
def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
assert _internal, '`Graph()` is private'
self.model: Model = model
......@@ -338,9 +344,9 @@ class Graph:
@overload
def add_node(self, name: str, operation: Operation) -> 'Node': ...
@overload
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def add_node(self, name, operation_or_type, parameters=None):
def add_node(self, name, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
......@@ -350,9 +356,10 @@ class Graph:
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
@overload
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node': # type: ignore
if isinstance(operation_or_type, Operation):
op = operation_or_type
else:
......@@ -405,7 +412,7 @@ class Graph:
def get_nodes_by_name(self, name: str) -> List['Node']:
return [node for node in self.hidden_nodes if node.name == name]
def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
return [node for node in self.nodes if node.python_name == python_name]
def topo_sort(self) -> List['Node']:
......@@ -594,7 +601,7 @@ class Node:
return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))
@property
def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)
@property
......@@ -610,19 +617,19 @@ class Node:
assert isinstance(self.operation, Cell)
return self.graph.model.graphs[self.operation.parameters['cell']]
def update_label(self, label: str) -> None:
def update_label(self, label: Optional[str]) -> None:
self.label = label
@overload
def update_operation(self, operation: Operation) -> None: ...
@overload
def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...
def update_operation(self, operation_or_type, parameters=None):
def update_operation(self, operation_or_type, parameters=None): # type: ignore
if isinstance(operation_or_type, Operation):
self.operation = operation_or_type
else:
self.operation = Operation.new(operation_or_type, parameters)
self.operation = Operation.new(operation_or_type, cast(dict, parameters))
# mutation
def remove(self) -> None:
......@@ -663,7 +670,13 @@ class Node:
return node
def _dump(self) -> Any:
ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
ret: Dict[str, Any] = {
'operation': {
'type': self.operation.type,
'parameters': self.operation.parameters,
'attributes': self.operation.attributes
}
}
if isinstance(self.operation, Cell):
ret['operation']['cell_name'] = self.operation.cell_name
if self.label is not None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Tuple, Optional, Callable
from typing import Tuple, Optional, Callable, cast
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
......@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
bn_momentum: float = 0.1):
super().__init__()
self.widths = [
self.widths = cast(nn.ChoiceOf[int], [
nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
for i, base_width in enumerate(base_widths)
]
])
self.expand_ratios = expand_ratios
blocks = [
......@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
self.classifier = nn.Sequential(
nn.Dropout(dropout_rate),
nn.Linear(self.widths[7], num_labels),
nn.Linear(cast(int, self.widths[7]), num_labels),
)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
......
......@@ -3,6 +3,7 @@
import math
import torch
import torch.nn as nn
from nni.retiarii import model_wrapper
from nni.retiarii.nn.pytorch import NasBench101Cell
......@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
__all__ = ['NasBench101']
def truncated_normal_(tensor, mean=0, std=1):
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
......@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
out = self.gap(out).view(bs, -1)
out = self.classifier(out)
return out
def reset_parameters(self):
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = self.config.bn_eps
module.momentum = self.config.bn_momentum
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Dict
import torch
import torch.nn as nn
......@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
C_prev, C_curr, label='cell')
ops: Dict[str, Callable[[int, int], nn.Module]] = {
prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr
......
......@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
"""
from collections import OrderedDict
from typing import Tuple, List, Union, Iterable, Dict, Callable
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
try:
from typing import Literal
......@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
See :class:`CellBuilder` on how to calculate those channel numbers.
"""
def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
super().__init__()
if last_cell_reduce:
self.pre0 = FactorizedReduce(C_pprev, C)
self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
else:
self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)
def forward(self, cells):
assert len(cells) == 2
......@@ -283,15 +283,19 @@ class CellBuilder:
Note that the builder is ephemeral; it can only be called once for every index.
"""
def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
num_nodes: int, merge_op: Literal['all', 'loose_end'],
def __init__(self, op_candidates: List[str],
C_prev_in: nn.MaybeChoice[int],
C_in: nn.MaybeChoice[int],
C: nn.MaybeChoice[int],
num_nodes: int,
merge_op: Literal['all', 'loose_end'],
first_cell_reduce: bool, last_cell_reduce: bool):
self.C_prev_in = C_prev_in # This is the out channels of the cell before last cell.
self.C_in = C_in # This is the out channels of last cell.
self.C = C # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
self.op_candidates = op_candidates
self.num_nodes = num_nodes
self.merge_op = merge_op
self.merge_op: Literal['all', 'loose_end'] = merge_op
self.first_cell_reduce = first_cell_reduce
self.last_cell_reduce = last_cell_reduce
self._expect_idx = 0
......@@ -312,7 +316,7 @@ class CellBuilder:
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
op: # make final chosen ops named with their aliases
lambda node_index, op_index, input_index:
OPS[op](self.C, 2 if is_reduction_cell and (
......@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """
class NDS(nn.Module):
"""
__doc__ = """
The unified version of NASNet search space.
We follow the implementation in
......@@ -378,8 +382,8 @@ class NDS(nn.Module):
op_candidates: List[str],
merge_op: Literal['all', 'loose_end'] = 'all',
num_nodes_per_cell: int = 4,
width: Union[Tuple[int], int] = 16,
num_cells: Union[Tuple[int], int] = 20,
width: Union[Tuple[int, ...], int] = 16,
num_cells: Union[Tuple[int, ...], int] = 20,
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
auxiliary_loss: bool = False):
super().__init__()
......@@ -394,30 +398,31 @@ class NDS(nn.Module):
else:
C = width
self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
if isinstance(num_cells, Iterable):
num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]
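# A quick check of the split above, assuming a plain int depth (with a
# ValueChoice the same expressions are built symbolically and evaluated
# per sampled depth):
#     num_cells = 20
#     [i * 20 // 3 - (i - 1) * 20 // 3 for i in range(3)]
#     i=0: 0 - (-7) = 7;  i=1: 6 - 0 = 6;  i=2: 13 - 6 = 7
# giving [7, 6, 7], which always sums back to num_cells.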
# auxiliary head is different for networks targeted at different datasets
if dataset == 'imagenet':
self.stem0 = nn.Sequential(
nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C // 2),
nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(cast(int, C // 2)),
nn.ReLU(inplace=True),
nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_pprev = C_prev = C_curr = C
last_cell_reduce = True
elif dataset == 'cifar':
self.stem = nn.Sequential(
nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
nn.BatchNorm2d(3 * C)
nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
nn.BatchNorm2d(cast(int, 3 * C))
)
C_pprev = C_prev = 3 * C
C_curr = C
......@@ -439,7 +444,7 @@ class NDS(nn.Module):
# C_pprev is output channel number of last second cell among all the cells already built.
if len(stage) > 1:
# Contains more than one cell
C_pprev = len(stage[-2].output_node_indices) * C_curr
C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
else:
# Look up in the out channels of last stage.
C_pprev = C_prev
......@@ -447,7 +452,7 @@ class NDS(nn.Module):
# This was originally,
# C_prev = num_nodes_per_cell * C_curr.
# but due to loose end, it becomes,
C_prev = len(stage[-1].output_node_indices) * C_curr
C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr
# Useful in aligning the pprev and prev cell.
last_cell_reduce = cell_builder.last_cell_reduce
......@@ -457,11 +462,11 @@ class NDS(nn.Module):
if auxiliary_loss:
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
self.stages[2] = SequentialBreakdown(self.stages[2])
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(C_prev, self.num_labels)
self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)
def forward(self, inputs):
if self.dataset == 'imagenet':
......@@ -483,7 +488,7 @@ class NDS(nn.Module):
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if self.training and self.auxiliary_loss:
return logits, logits_aux
return logits, logits_aux # type: ignore
else:
return logits
......@@ -524,8 +529,8 @@ class NASNet(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.NASNET_OPS,
......@@ -555,8 +560,8 @@ class ENAS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.ENAS_OPS,
......@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
......@@ -626,8 +631,8 @@ class PNAS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.PNAS_OPS,
......@@ -660,8 +665,8 @@ class DARTS(NDS):
]
def __init__(self,
width: Union[Tuple[int], int] = (16, 24, 32),
num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.DARTS_OPS,
......
......@@ -2,7 +2,7 @@
# Licensed under the MIT license.
import math
from typing import Optional, Callable, List, Tuple
from typing import Optional, Callable, List, Tuple, cast
import torch
import nni.retiarii.nn.pytorch as nn
......@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
groups: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
groups: nn.MaybeChoice[int] = 1,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
......@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
if activation_layer is None:
activation_layer = nn.ReLU6
super().__init__(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
bias=False),
norm_layer(out_channels),
nn.Conv2d(
cast(int, in_channels),
cast(int, out_channels),
cast(int, kernel_size),
stride,
cast(int, padding),
dilation=dilation,
groups=cast(int, groups),
bias=False
),
norm_layer(cast(int, out_channels)),
activation_layer(inplace=True)
)
self.out_channels = out_channels
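# Note on the cast(int, ...) calls above: in_channels and friends are
# nn.MaybeChoice[int], i.e. either a plain int or a symbolic ValueChoice.
# torch's Conv2d stubs expect int, so cast() satisfies the checker while
# the retiarii wrappers accept the symbolic value at runtime. Hypothetical
# usage sketch:
#
#     width = nn.ValueChoice([16, 24, 32], label='w')  # a MaybeChoice[int]
#     conv = ConvBNReLU(3, width, kernel_size=3)       # type-checks and runs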
......@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__(
......@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
def __init__(
self,
in_channels: int,
out_channels: int,
expand_ratio: int,
kernel_size: int = 3,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
expand_ratio: nn.MaybeChoice[float],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
norm_layer: Optional[Callable[..., nn.Module]] = None,
squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
......@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
self.out_channels = out_channels
assert stride in [1, 2]
hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))
# FIXME: check whether this equal works
# Residual connection is added when stride = 1 and input channels and output channels are the same.
......@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)
blocks = [
blocks: List[nn.Module] = [
# first stage is fixed
SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
]
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
......@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
When stride = 1, the block expects an input with ``2 * input channels``; otherwise, ``input channels``.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
super().__init__()
assert stride in [1, 2]
......@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc = c = self.channels
pc: int = self.channels
c: int = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = self.mid_channels
c = cast(int, self.mid_channels)
if token == "d":
# depth-wise conv
if isinstance(pc, int) and isinstance(c, int):
......@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
super().__init__(in_channels, out_channels, mid_channels,
kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
......@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
nn.ReLU(inplace=True),
)
self.features = []
feature_blocks = []
global_block_idx = 0
for stage_idx, num_repeat in enumerate(self.stage_repeats):
......@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
else:
mid_channels = int(base_mid_channels)
mid_channels = cast(nn.MaybeChoice[int], mid_channels)
choice_block = nn.LayerChoice([
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
], label=f'layer_{global_block_idx}')
self.features.append(choice_block)
feature_blocks.append(choice_block)
self.features = nn.Sequential(*self.features)
self.features = nn.Sequential(*feature_blocks)
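# Why the plain list was replaced: a Python list attribute is invisible to
# nn.Module's registration machinery, so its submodules would be missing
# from parameters() and state_dict(). Collecting into a local list and
# wrapping with nn.Sequential (or nn.ModuleList) registers them. Sketch:
#
#     blocks = [nn.Linear(8, 8) for _ in range(3)]
#     self.features = nn.Sequential(*blocks)  # parameters registered
#     self.features = blocks                  # NOT registered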
# final layers
last_conv_channels = self.stage_out_channels[-1]
......@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
torch.nn.init.constant_(m.running_mean, 0)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Useful type hints
......@@ -3,7 +3,7 @@
import logging
import os
from typing import Any, Callable
from typing import Any, Callable, Optional
import nni
from nni.common.serializer import PayloadTooLarge
......@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
register_advisor(self) # register the current advisor as the "global only" advisor
self.search_space = None
self.send_trial_callback: Callable[[dict], None] = None
self.request_trial_jobs_callback: Callable[[int], None] = None
self.trial_end_callback: Callable[[int, bool], None] = None
self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
self.final_metric_callback: Callable[[int, MetricData], None] = None
self.send_trial_callback: Optional[Callable[[dict], None]] = None
self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None
self.parameters_count = 0
......@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
def handle_trial_end(self, data):
_logger.debug('Trial end: %s', data)
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
if self.trial_end_callback is not None:
self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'], # pylint: disable=not-callable
data['event'] == 'SUCCEEDED')
def handle_report_metric_data(self, data):
_logger.debug('Metric reported: %s', data)
if data['type'] == MetricType.REQUEST_PARAMETER:
raise ValueError('Request parameter not supported')
elif data['type'] == MetricType.PERIODICAL:
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
if self.intermediate_metric_callback is not None:
self.intermediate_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
elif data['type'] == MetricType.FINAL:
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
if self.final_metric_callback is not None:
self.final_metric_callback(data['parameter_id'], # pylint: disable=not-callable
self._process_value(data['value']))
@staticmethod
def _process_value(value) -> Any: # hopefully a float
......
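The callbacks above were annotated as plain Callable but initialized to None; declaring them Optional[...] and guarding each call site makes the "not registered yet" case explicit. A minimal sketch of the pattern:

from typing import Callable, Optional

class Dispatcher:
    def __init__(self) -> None:
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None

    def handle_trial_end(self, parameter_id: int, success: bool) -> None:
        # The None check narrows the Optional for the checker and turns an
        # unregistered callback into an explicit no-op instead of a TypeError.
        if self.trial_end_callback is not None:
            self.trial_end_callback(parameter_id, success)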
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional, Tuple)
import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)
from .graph import Model, Mutation, ModelStatus
......@@ -44,9 +45,11 @@ class Mutator:
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
"""
def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
self.sampler: Optional[Sampler] = sampler
self.label: Optional[str] = label
if label is None:
warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
self.label: str = label
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import itertools
import math
import operator
import warnings
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar, Sequence
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
NoReturn, Optional, Sequence, SupportsRound, TypeVar,
Union, cast)
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import Translatable
from nni.retiarii.serializer import basic_unit
from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError
from nni.retiarii.utils import (STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace,
NoContextError)
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'ModelParameterChoice', 'Placeholder', 'ChosenInputs']
__all__ = [
    # APIs
    'LayerChoice',
    'InputChoice',
    'ValueChoice',
    'ModelParameterChoice',
    'Placeholder',
    # Fixed module
    'ChosenInputs',
    # Type utils
    'ReductionType',
    'MaybeChoice',
    'ChoiceOf',
]
class LayerChoice(Mutable):
......@@ -130,26 +147,16 @@ class LayerChoice(Mutable):
self.names.append(str(i))
else:
raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
self._first_module = self._modules[self.names[0]] # to make the dummy forward meaningful
@property
def key(self):
return self._key()
@torch.jit.ignore
def _key(self):
warnings.warn('Using key to access the identifier of LayerChoice is deprecated. Please use label instead.',
category=DeprecationWarning)
return self._label
self._first_module = cast(nn.Module, self._modules[self.names[0]]) # to make the dummy forward meaningful
@property
def label(self):
return self._label
def __getitem__(self, idx):
def __getitem__(self, idx: Union[int, str]) -> nn.Module:
if isinstance(idx, str):
return self._modules[idx]
return list(self)[idx]
return cast(nn.Module, self._modules[idx])
return cast(nn.Module, list(self)[idx])
def __setitem__(self, idx, module):
key = idx if isinstance(idx, str) else self.names[idx]
......@@ -173,15 +180,6 @@ class LayerChoice(Mutable):
def __iter__(self):
return map(lambda name: self._modules[name], self.names)
@property
def choices(self):
return self._choices()
@torch.jit.ignore
def _choices(self):
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", category=DeprecationWarning)
return list(self)
def forward(self, x):
"""
The forward of layer choice is simply running the first candidate module.
......@@ -266,16 +264,6 @@ class InputChoice(Mutable):
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = generate_new_label(label)
@property
def key(self):
return self._key()
@torch.jit.ignore
def _key(self):
warnings.warn('Using key to access the identifier of InputChoice is deprecated. Please use label instead.',
category=DeprecationWarning)
return self._label
@property
def label(self):
return self._label
......@@ -350,7 +338,7 @@ def _valuechoice_codegen(*, _internal: bool = False):
'truediv': '/', 'floordiv': '//', 'mod': '%',
'lshift': '<<', 'rshift': '>>',
'and': '&', 'xor': '^', 'or': '|',
# no reflection
# no reverse
'lt': '<', 'le': '<=', 'eq': '==',
'ne': '!=', 'ge': '>=', 'gt': '>',
# NOTE
......@@ -358,14 +346,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
# Might support them in future when we actually need them.
}
binary_template = """ def __{op}__(self, other: Any) -> 'ValueChoiceX':
binary_template = """ def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""
binary_r_template = """ def __r{op}__(self, other: Any) -> 'ValueChoiceX':
binary_r_template = """ def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""
unary_template = """ def __{op}__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.{op}, '{sym}{{}}', [self])"""
unary_template = """ def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""
for op, sym in MAPPING.items():
if op in ['neg', 'pos', 'invert']:
......@@ -377,8 +365,14 @@ def _valuechoice_codegen(*, _internal: bool = False):
print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')
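These templates are printed rather than executed; the output is pasted into the "generated with codegen" region further down. For example, binary_template.format(op='add', opt='add', sym='+') produces exactly the method that appears there:

def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
    return ValueChoiceX(operator.add, '{} + {}', [self, other])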
def _valuechoice_staticmethod_helper(orig_func):
orig_func.__doc__ += """
_func = TypeVar('_func')
_cand = TypeVar('_cand')
_value = TypeVar('_value')
def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
if orig_func.__doc__ is not None:
orig_func.__doc__ += """
Notes
-----
This function performs lazy evaluation.
......@@ -388,7 +382,7 @@ def _valuechoice_staticmethod_helper(orig_func):
return orig_func
class ValueChoiceX(Translatable, nn.Module):
class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
"""Internal API. Implementation note:
The transformed (X) version of value choice.
......@@ -408,7 +402,10 @@ class ValueChoiceX(Translatable, nn.Module):
This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
"""
def __init__(self, function: Callable[..., Any], repr_template: str, arguments: List[Any], dry_run: bool = True):
def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
repr_template: str = cast(str, None),
arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
dry_run: bool = True):
super().__init__()
if function is None:
......@@ -431,7 +428,7 @@ class ValueChoiceX(Translatable, nn.Module):
def inner_choices(self) -> Iterable['ValueChoice']:
"""
Return an iterable of all leaf value choices.
Return a generator of all leaf value choices.
Useful for composition of value choices.
No deduplication on labels. Mutators should take care.
"""
......@@ -439,18 +436,18 @@ class ValueChoiceX(Translatable, nn.Module):
if isinstance(arg, ValueChoiceX):
yield from arg.inner_choices()
def dry_run(self) -> Any:
def dry_run(self) -> _cand:
"""
Dry run the value choice to get one of its possible evaluation results.
"""
# values are not used
return self._evaluate(iter([]), True)
def all_options(self) -> Iterable[Any]:
def all_options(self) -> Iterable[_cand]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices: Dict[str, List[Any]] = {}
dedup_inner_choices: Dict[str, List[_cand]] = {}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels: List[str] = []
......@@ -470,14 +467,14 @@ class ValueChoiceX(Translatable, nn.Module):
chosen = dict(zip(dedup_labels, chosen))
yield self.evaluate([chosen[label] for label in all_labels])
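# Sketch of the exploration step above with two deduplicated leaf choices;
# leaves sharing a label are sampled once and reused (itertools.product is
# an illustrative stand-in for the actual loop):
#
#     import itertools
#     dedup_inner_choices = {'kernel': [3, 5], 'width': [16, 32]}
#     for combo in itertools.product(*dedup_inner_choices.values()):
#         chosen = dict(zip(dedup_inner_choices, combo))
#         # {'kernel': 3, 'width': 16}, {'kernel': 3, 'width': 32}, ...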
def evaluate(self, values: Iterable[Any]) -> Any:
def evaluate(self, values: Iterable[_cand]) -> _cand:
"""
Evaluate the result of this group.
``values`` should in the same order of ``inner_choices()``.
"""
return self._evaluate(iter(values), False)
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
# "values" iterates in the recursion
eval_args = []
for arg in self.arguments:
......@@ -497,7 +494,7 @@ class ValueChoiceX(Translatable, nn.Module):
"""
return self.dry_run()
def __repr__(self):
def __repr__(self) -> str:
reprs = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
......@@ -513,7 +510,7 @@ class ValueChoiceX(Translatable, nn.Module):
# Special operators that can be useful in place of built-in conditional operators.
@staticmethod
@_valuechoice_staticmethod_helper
def to_int(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', int]:
def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
"""
Convert a ``ValueChoice`` to an integer.
"""
......@@ -523,7 +520,7 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def to_float(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', float]:
def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
"""
Convert a ``ValueChoice`` to a float.
"""
......@@ -533,9 +530,9 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def condition(pred: 'ValueChoiceOrAny',
true: 'ValueChoiceOrAny',
false: 'ValueChoiceOrAny') -> 'ValueChoiceOrAny':
def condition(pred: 'MaybeChoice[bool]',
true: 'MaybeChoice[_value]',
false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Return ``true`` if the predicate ``pred`` is true else ``false``.
......@@ -549,35 +546,39 @@ class ValueChoiceX(Translatable, nn.Module):
@staticmethod
@_valuechoice_staticmethod_helper
def max(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the maximum value from a list of value choices.
The usage should be similar to Python's built-in ``max`` function,
where the parameters could be an iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.max(*list(arg0))
lst = [arg0] + list(args)
if not isinstance(arg0, Iterable):
raise TypeError('Expect more than one item to compare max')
return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(max, 'max({})', lst)
return max(lst)
return max(cast(Any, lst))
@staticmethod
@_valuechoice_staticmethod_helper
def min(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the minimum value from a list of value choices.
The usage should be similar to Python's built-in ``min`` function,
where the parameters could be an iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.min(*list(arg0))
lst = [arg0] + list(args)
if not isinstance(arg0, Iterable):
raise TypeError('Expect more than one item to compare min')
return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(min, 'min({})', lst)
return min(lst)
return min(cast(Any, lst))
def __hash__(self):
# this is required because we have implemented ``__eq__``
......@@ -589,24 +590,25 @@ class ValueChoiceX(Translatable, nn.Module):
# - Implementation effort is too huge.
# As a result, in-place operators like += and *=, and magic methods like `__getattr__`, are not included in this list.
def __getitem__(self, key: Any) -> 'ValueChoiceX':
def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])
# region implement int, float, round, trunc, floor, ceil
# because I believe sometimes we need them to calculate #channels
# `__int__` and `__float__` are not supported because `__int__` is required to return int.
def __round__(self, ndigits: Optional[Any] = None) -> 'ValueChoiceX':
def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
if ndigits is not None:
return ValueChoiceX(round, 'round({}, {})', [self, ndigits])
return ValueChoiceX(round, 'round({})', [self])
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))
def __trunc__(self) -> 'ValueChoiceX':
def __trunc__(self) -> NoReturn:
raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")
def __floor__(self) -> 'ValueChoiceX':
def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.floor, 'math.floor({})', [self])
def __ceil__(self) -> 'ValueChoiceX':
def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.ceil, 'math.ceil({})', [self])
def __index__(self) -> NoReturn:
......@@ -622,132 +624,133 @@ class ValueChoiceX(Translatable, nn.Module):
# region the following code is generated with codegen (see above)
# Annotated with "region" because I want to collapse them in vscode
def __neg__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.neg, '-{}', [self])
def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))
def __pos__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.pos, '+{}', [self])
def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))
def __invert__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.invert, '~{}', [self])
def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))
def __add__(self, other: Any) -> 'ValueChoiceX':
def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [self, other])
def __radd__(self, other: Any) -> 'ValueChoiceX':
def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [other, self])
def __sub__(self, other: Any) -> 'ValueChoiceX':
def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [self, other])
def __rsub__(self, other: Any) -> 'ValueChoiceX':
def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [other, self])
def __mul__(self, other: Any) -> 'ValueChoiceX':
def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [self, other])
def __rmul__(self, other: Any) -> 'ValueChoiceX':
def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [other, self])
def __matmul__(self, other: Any) -> 'ValueChoiceX':
def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])
def __rmatmul__(self, other: Any) -> 'ValueChoiceX':
def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])
def __truediv__(self, other: Any) -> 'ValueChoiceX':
def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} / {}', [self, other])
def __rtruediv__(self, other: Any) -> 'ValueChoiceX':
def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} / {}', [other, self])
def __floordiv__(self, other: Any) -> 'ValueChoiceX':
def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} // {}', [self, other])
def __rfloordiv__(self, other: Any) -> 'ValueChoiceX':
def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} // {}', [other, self])
def __mod__(self, other: Any) -> 'ValueChoiceX':
def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [self, other])
def __rmod__(self, other: Any) -> 'ValueChoiceX':
def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [other, self])
def __lshift__(self, other: Any) -> 'ValueChoiceX':
def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [self, other])
def __rlshift__(self, other: Any) -> 'ValueChoiceX':
def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [other, self])
def __rshift__(self, other: Any) -> 'ValueChoiceX':
def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])
def __rrshift__(self, other: Any) -> 'ValueChoiceX':
def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])
def __and__(self, other: Any) -> 'ValueChoiceX':
def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [self, other])
def __rand__(self, other: Any) -> 'ValueChoiceX':
def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [other, self])
def __xor__(self, other: Any) -> 'ValueChoiceX':
def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])
def __rxor__(self, other: Any) -> 'ValueChoiceX':
def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])
def __or__(self, other: Any) -> 'ValueChoiceX':
def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [self, other])
def __ror__(self, other: Any) -> 'ValueChoiceX':
def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [other, self])
def __lt__(self, other: Any) -> 'ValueChoiceX':
def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lt, '{} < {}', [self, other])
def __le__(self, other: Any) -> 'ValueChoiceX':
def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.le, '{} <= {}', [self, other])
def __eq__(self, other: Any) -> 'ValueChoiceX':
def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.eq, '{} == {}', [self, other])
def __ne__(self, other: Any) -> 'ValueChoiceX':
def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ne, '{} != {}', [self, other])
def __ge__(self, other: Any) -> 'ValueChoiceX':
def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ge, '{} >= {}', [self, other])
def __gt__(self, other: Any) -> 'ValueChoiceX':
def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.gt, '{} > {}', [self, other])
# endregion
# __pow__, __divmod__, __abs__ are special ones.
# Not easy to cover those cases with codegen.
def __pow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])
def __rpow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])
def __divmod__(self, other: Any) -> 'ValueChoiceX':
def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])
def __rdivmod__(self, other: Any) -> 'ValueChoiceX':
def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])
def __abs__(self) -> 'ValueChoiceX':
def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(abs, 'abs({})', [self])
ValueChoiceOrAny = TypeVar('ValueChoiceOrAny', ValueChoiceX, Any)
ChoiceOf = ValueChoiceX
MaybeChoice = Union[ValueChoiceX[_cand], _cand]
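With these aliases, ChoiceOf[int] reads as "a symbolic expression evaluating to int" and MaybeChoice[int] as "an int or such an expression". A usage sketch; the double_width helper is illustrative:

def double_width(base: 'MaybeChoice[int]') -> 'MaybeChoice[int]':
    # On a ValueChoice, * builds another symbolic ChoiceOf[int];
    # on a plain int, it multiplies eagerly.
    return base * 2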
class ValueChoice(ValueChoiceX, Mutable):
class ValueChoice(ValueChoiceX[_cand], Mutable):
"""
ValueChoice is to choose one from ``candidates``. The most common use cases are:
......@@ -865,14 +868,14 @@ class ValueChoice(ValueChoiceX, Mutable):
# FIXME: prior is designed but not supported yet
@classmethod
def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
value = get_fixed_value(label)
if value not in candidates:
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
return value
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__(None, None, None)
def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
......@@ -894,10 +897,10 @@ class ValueChoice(ValueChoiceX, Mutable):
# yield self because self is the only value choice here
yield self
def dry_run(self) -> Any:
def dry_run(self) -> _cand:
return self.candidates[0]
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
if dry_run:
return self.candidates[0]
try:
......@@ -986,6 +989,7 @@ class ModelParameterChoice:
Examples
--------
Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
>>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
>>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
"""
......@@ -1016,12 +1020,14 @@ class ModelParameterChoice:
if default not in candidates:
# could be callable
try:
default = default(candidates)
default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)
except TypeError as e:
if 'not callable' in str(e):
raise TypeError("`default` is not in `candidates`, and it's also not callable.")
raise
default = cast(ValueType, default)
label = generate_new_label(label)
parameter_spec = ParameterSpec(
label, # name
......