Unverified commit 18962129, authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
@@ -4,10 +4,11 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable
+from typing import Dict, Union, Optional, List, Callable, Type

 import pytorch_lightning as pl
 import torch.nn as nn
+import torch.nn.functional as nn_functional
 import torch.optim as optim
 import torchmetrics
 import torch.utils.data as torch_data

@@ -124,12 +125,12 @@ class Lightning(Evaluator):
         if other is None:
             return False
         if hasattr(self, "function") and hasattr(other, "function"):
-            eq_func = (self.function == other.function)
+            eq_func = getattr(self, "function") == getattr(other, "function")
         elif not (hasattr(self, "function") or hasattr(other, "function")):
             eq_func = True
         if hasattr(self, "arguments") and hasattr(other, "arguments"):
-            eq_args = (self.arguments == other.arguments)
+            eq_args = getattr(self, "arguments") == getattr(other, "arguments")
         elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
             eq_args = True

@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
+    trainer: pl.Trainer
+
+    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: Union[Path, str, bool, None] = None):
         super().__init__()
         self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')

@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
             self.log('test_' + name, metric(y_hat, y), prog_bar=True)

     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
+        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
         nni.report_intermediate_result(self._get_validation_metrics())

@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
 class _AccuracyWithLogits(torchmetrics.Accuracy):
     def update(self, pred, target):
-        return super().update(nn.functional.softmax(pred), target)
+        return super().update(nn_functional.softmax(pred), target)


 @nni.trace
 class _ClassificationModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'acc': _AccuracyWithLogits},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,

@@ -275,10 +279,10 @@ class Classification(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,

@@ -291,10 +295,10 @@ class Classification(Lightning):
 @nni.trace
 class _RegressionModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,

@@ -328,10 +332,10 @@ class Regression(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
...
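The switch from `nn.Module` to `Type[nn.Module]` reflects how these evaluators are used: the caller passes the loss *class* (for example `nn.CrossEntropyLoss`, not an instance), and instantiation happens inside the module. A minimal sketch of the distinction (illustrative names, not NNI's API):

```python
from typing import Type

import torch.nn as nn

def make_loss(criterion: Type[nn.Module]) -> nn.Module:
    # The caller passes the class itself, and instantiation happens
    # here -- hence the annotation Type[nn.Module] rather than nn.Module.
    return criterion()

loss_fn = make_loss(nn.CrossEntropyLoss)  # OK: a class is expected
print(type(loss_fn).__name__)             # CrossEntropyLoss
```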
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .api import *

@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation_summary = get_mutation_summary(model)
+        assert model.evaluator is not None, 'Model evaluator can not be None'
         return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)

     @classmethod
...
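The added assert is not only a runtime guard: static checkers such as mypy narrow `Optional[T]` to `T` after an `assert x is not None`, which is presumably why it accompanies the new typehints. A minimal sketch of the narrowing:

```python
from typing import Optional

class Model:
    def __init__(self, evaluator: Optional[str] = None) -> None:
        self.evaluator = evaluator

def pack(model: Model) -> str:
    # Without the assert, mypy reports that Optional[str] is not str.
    assert model.evaluator is not None, 'Model evaluator can not be None'
    return model.evaluator  # narrowed from Optional[str] to str

print(pack(Model('functional')))
```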
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import os
 import random
-from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
+from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast

 from ..graph import Model
 from ..integration_api import receive_trial_parameters

@@ -39,6 +42,9 @@ class BenchmarkGraphData:
     def load(data) -> 'BenchmarkGraphData':
         return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])

+    def __repr__(self) -> str:
+        return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
+

 class BenchmarkExecutionEngine(BaseExecutionEngine):
     """

@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
     @classmethod
     def trial_execute_graph(cls) -> None:
         graph_data = BenchmarkGraphData.load(receive_trial_parameters())
+        assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
         os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
         final, intermediates = cls.query_in_benchmark(graph_data)

@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
                 arch = t
         if arch is None:
             raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
-        print(arch)
         return _convert_to_final_and_intermediates(
             query_nb101_trial_stats(arch, 108, include_intermediates=True),
             'valid_acc'

@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
         benchmark_result = random.choice(benchmark_result)
     else:
         benchmark_result = benchmark_result[0]
+    benchmark_result = cast(dict, benchmark_result)
     return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
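Worth remembering when reading the many `cast(...)` calls this commit introduces: `typing.cast` only informs the type checker and has no runtime effect, so the behaviour of lines like the one above is unchanged. A minimal demonstration:

```python
from typing import cast

benchmark_result = {"valid_acc": 0.91, "intermediates": []}
# typing.cast returns its second argument unchanged at runtime;
# it exists purely so the checker treats the value as the given type.
narrowed = cast(dict, benchmark_result)
assert narrowed is benchmark_result
```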
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

 import logging
 import os
 import random
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Dict, Any, Type

 import torch.nn as nn

@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation = get_mutation_dict(model)
-        graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
+        assert model.evaluator is not None, 'Model evaluator is not available.'
+        graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
         return graph_data

     @classmethod
...
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Any, List

 from ..graph import Model
...
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Union, cast

 import colorama
 import psutil

@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
 from nni.experiment.config import utils
 from nni.experiment.config.base import ConfigBase
 from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment.config.training_services import RemoteConfig
 from nni.experiment.pipe import Pipe
 from nni.tools.nnictl.command_utils import kill_command

@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
     Examples
     --------
     Multi-trial NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.Random()
     >>> model_evaluator = FunctionalEvaluator(evaluate_model)

@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config, 8081)

     One-shot NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.DARTS()
     >>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)

@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config)

     Export top models:
+
     >>> for model_dict in exp.export_top_models(formatter='dict'):
     ...     print(model_dict)
     >>> with nni.retarii.fixed_arch(model_dict):
     ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
-                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
-                 trainer: BaseOneShotTrainer = None):
+    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
+                 applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
+                 trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)

@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
             raise ValueError('Evaluator should not be none.')

         # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = None
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
         self.port: Optional[int] = None

         self.base_model = base_model
-        self.evaluator: Evaluator = evaluator
+        self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

-        # FIXME: this is only a workaround
         from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
         if not isinstance(strategy, OneShotStrategy):
             self._dispatcher = RetiariiAdvisor()
+        else:
+            self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread: Optional[Thread] = None
         self._proc: Optional[Popen] = None
         self._pipe: Optional[Pipe] = None

         self.url_prefix = None

@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
             assert self.config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None
+            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
             devices = self._construct_devices()
             engine = CGOExecutionEngine(devices,
                                         max_concurrency=self.config.max_concurrency_cgo,

@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
             engine = PurePythonExecutionEngine()
         elif self.config.execution_engine == 'benchmark':
             from ..execution.benchmark import BenchmarkExecutionEngine
+            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
             engine = BenchmarkExecutionEngine(self.config.benchmark)
+        else:
+            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
         set_execution_engine(engine)

         self.id = management.generate_experiment_id()

@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
     def _construct_devices(self):
         devices = []
         if hasattr(self.config.training_service, 'machine_list'):
-            for machine in self.config.training_service.machine_list:
+            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
                 assert machine.gpu_indices is not None, \
                     'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                 for gpu_idx in machine.gpu_indices:
                     devices.append(GPUDevice(machine.host, gpu_idx))
         return devices

@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
     def _create_dispatcher(self):
         return self._dispatcher

-    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
+    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.

@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
         This function will block until experiment finish or error.
         Return `True` when experiment done; or return `False` when experiment failed.
         """
+        assert self._proc is not None
         try:
             while True:
                 time.sleep(10)

@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
             _logger.warning('KeyboardInterrupt detected')
         finally:
             self.stop()
+        raise RuntimeError('Check experiment status failed.')

     def stop(self) -> None:
         """

@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
         if self._pipe is not None:
             self._pipe.close()

-        self.id = None
-        self.port = None
+        self.id = cast(str, None)
+        self.port = cast(int, None)
         self._proc = None
         self._pipe = None

-        self._dispatcher = None
+        self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread = None

         _logger.info('Experiment stopped')
...
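The recurring `cast(SomeType, None)` idiom above keeps an attribute's declared type non-Optional while still allowing a `None` placeholder before initialization; the trade-off is that the checker no longer flags uses before assignment. A minimal sketch of the pattern and its intent:

```python
from typing import cast

class Advisor:
    def ping(self) -> str:
        return 'pong'

class Experiment:
    def __init__(self) -> None:
        # Declared as Advisor (not Optional[Advisor]) so later call sites
        # need no None-checks; cast() makes the None placeholder type-check.
        self._dispatcher: Advisor = cast(Advisor, None)

    def start(self) -> None:
        self._dispatcher = Advisor()
        print(self._dispatcher.ping())

Experiment().start()
```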
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import json
 import logging
 from pathlib import Path
...
@@ -5,10 +5,16 @@
 Model representation.
 """

+from __future__ import annotations
+
 import abc
 import json
 from enum import Enum
-from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union, cast, overload)
+
+if TYPE_CHECKING:
+    from .mutator import Mutator

 from .operation import Cell, Operation, _IOPseudoOperation
 from .utils import uid
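`TYPE_CHECKING` is `False` at runtime, so the `Mutator` import above is executed only by static type checkers; combined with `from __future__ import annotations`, annotations stay strings at runtime, which avoids a circular import between `graph.py` and `mutator.py`. A self-contained sketch of the pattern (module name is illustrative):

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by the type checker, never at runtime,
    # so a circular import between two modules cannot occur.
    from mutator import Mutator  # hypothetical module

def apply_mutator(m: Mutator) -> None:
    # The annotation above stays a string at runtime thanks to
    # the __future__ import, so Mutator need not be importable here.
    ...
```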
@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
         pass

     @abc.abstractmethod
-    def _execute(self, model_cls: type) -> Any:
+    def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
         pass

     @abc.abstractmethod

@@ -203,7 +209,7 @@ class Model:
             matched_nodes.extend(nodes)
         return matched_nodes

-    def get_node_by_name(self, node_name: str) -> 'Node':
+    def get_node_by_name(self, node_name: str) -> 'Node' | None:
         """
         Traverse all the nodes to find the matched node with the given name.
         """

@@ -217,7 +223,7 @@ class Model:
         else:
             return None

-    def get_node_by_python_name(self, python_name: str) -> 'Node':
+    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
         """
         Traverse all the nodes to find the matched node with the given python_name.
         """

@@ -297,7 +303,7 @@ class Graph:
         The name of torch.nn.Module, should have one-to-one mapping with items in python model.
     """

-    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
+    def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
         assert _internal, '`Graph()` is private'

         self.model: Model = model

@@ -338,9 +344,9 @@ class Graph:
     @overload
     def add_node(self, name: str, operation: Operation) -> 'Node': ...
     @overload
-    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def add_node(self, name, operation_or_type, parameters=None):
+    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:
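The `# type: ignore` on the implementation line is needed because the untyped implementation does not match either `@overload` stub exactly; the stubs are what callers (and the checker) see. A stripped-down sketch of the same pattern:

```python
from typing import Any, Dict, Optional, overload

class Operation:
    """Stand-in for the real Operation class."""

@overload
def add_node(name: str, operation: Operation) -> str: ...
@overload
def add_node(name: str, type_name: str, parameters: Optional[Dict[str, Any]] = None) -> str: ...

def add_node(name, operation_or_type, parameters=None):  # type: ignore
    # One runtime implementation dispatching on the argument type;
    # the overload stubs above describe the two call signatures.
    if isinstance(operation_or_type, Operation):
        return f'{name}: {operation_or_type!r}'
    return f'{name}: {operation_or_type} with {parameters or {}}'

print(add_node('conv1', 'Conv2d', {'kernel_size': 3}))
```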
@@ -350,9 +356,10 @@ class Graph:
     @overload
     def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
     @overload
-    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
+                            parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
+    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:

@@ -405,7 +412,7 @@ class Graph:
     def get_nodes_by_name(self, name: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.name == name]

-    def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
+    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
         return [node for node in self.nodes if node.python_name == python_name]

     def topo_sort(self) -> List['Node']:

@@ -594,7 +601,7 @@ class Node:
         return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

     @property
-    def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
+    def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
         return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)

     @property

@@ -610,19 +617,19 @@ class Node:
         assert isinstance(self.operation, Cell)
         return self.graph.model.graphs[self.operation.parameters['cell']]

-    def update_label(self, label: str) -> None:
+    def update_label(self, label: Optional[str]) -> None:
         self.label = label

     @overload
     def update_operation(self, operation: Operation) -> None: ...
     @overload
-    def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
+    def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...

-    def update_operation(self, operation_or_type, parameters=None):
+    def update_operation(self, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             self.operation = operation_or_type
         else:
-            self.operation = Operation.new(operation_or_type, parameters)
+            self.operation = Operation.new(operation_or_type, cast(dict, parameters))

     # mutation
     def remove(self) -> None:

@@ -663,7 +670,13 @@ class Node:
         return node

     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
+        ret: Dict[str, Any] = {
+            'operation': {
+                'type': self.operation.type,
+                'parameters': self.operation.parameters,
+                'attributes': self.operation.attributes
+            }
+        }
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, cast

 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper

@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
                  bn_momentum: float = 0.1):
         super().__init__()

-        self.widths = [
+        self.widths = cast(nn.ChoiceOf[int], [
             nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
             for i, base_width in enumerate(base_widths)
-        ]
+        ])
         self.expand_ratios = expand_ratios

         blocks = [

@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
         self.classifier = nn.Sequential(
             nn.Dropout(dropout_rate),
-            nn.Linear(self.widths[7], num_labels),
+            nn.Linear(cast(int, self.widths[7]), num_labels),
         )

         reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
...
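The `cast(nn.ChoiceOf[int], ...)` and `cast(int, self.widths[7])` pair reflects that a `ValueChoice` stands in for an `int` at search time but is not one statically. A hedged sketch of the same idea, assuming the `nni.retiarii` API of this commit's era:

```python
from typing import cast

import nni.retiarii.nn.pytorch as nn

# A ValueChoice behaves like an int during search, so call sites that
# are annotated to require a plain int (such as Linear's in_features)
# are satisfied with a cast, which is a no-op at runtime.
width = nn.ValueChoice([8, 16, 32], label='width_0')
linear = nn.Linear(cast(int, width), 10)
```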
@@ -3,6 +3,7 @@
 import math

+import torch
 import torch.nn as nn
 from nni.retiarii import model_wrapper
 from nni.retiarii.nn.pytorch import NasBench101Cell

@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
 __all__ = ['NasBench101']

-def truncated_normal_(tensor, mean=0, std=1):
+def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
     # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
     size = tensor.shape
     tmp = tensor.new_empty(size + (4,)).normal_()

@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
         out = self.gap(out).view(bs, -1)
         out = self.classifier(out)
         return out
-
-    def reset_parameters(self):
-        for module in self.modules():
-            if isinstance(module, nn.BatchNorm2d):
-                module.eps = self.config.bn_eps
-                module.momentum = self.config.bn_momentum

 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import Callable, Dict
+
 import torch
 import torch.nn as nn

@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
             if reduction:
                 cell = ResNetBasicblock(C_prev, C_curr, 2)
             else:
-                cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
-                                       C_prev, C_curr, label='cell')
+                ops: Dict[str, Callable[[int, int], nn.Module]] = {
+                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                }
+                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
             self.cells.append(cell)
             C_prev = C_curr
...
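One thing to watch when reading the extracted `ops` dict: lambdas created in a comprehension close over the loop variable itself, not its per-iteration value, so each factory would see the final `prim` unless the value is bound explicitly (whether this matters here depends on how `NasBench201Cell` consumes the factories). A minimal illustration of the gotcha and the usual default-argument fix:

```python
OPS = {'a': lambda x: x + 1, 'b': lambda x: x * 2}

# Late binding: every lambda reads `prim` when called, after the
# comprehension has finished, so all entries see the last value.
broken = {prim: (lambda x: OPS[prim](x)) for prim in OPS}
print(broken['a'](10), broken['b'](10))  # 20 20 -- both used 'b'

# Binding `prim` as a default argument captures each iteration's value.
fixed = {prim: (lambda x, prim=prim: OPS[prim](x)) for prim in OPS}
print(fixed['a'](10), fixed['b'](10))    # 11 20
```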
@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such str
 """

 from collections import OrderedDict
-from typing import Tuple, List, Union, Iterable, Dict, Callable
+from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast

 try:
     from typing import Literal

@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
     See :class:`CellBuilder` on how to calculate those channel numbers.
     """

-    def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
+    def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
         super().__init__()

         if last_cell_reduce:
-            self.pre0 = FactorizedReduce(C_pprev, C)
+            self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
         else:
-            self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
-        self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+            self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
+        self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)

     def forward(self, cells):
         assert len(cells) == 2

@@ -283,15 +283,19 @@ class CellBuilder:
     Note that the builder is ephemeral, it can only be called once for every index.
     """

-    def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
-                 num_nodes: int, merge_op: Literal['all', 'loose_end'],
+    def __init__(self, op_candidates: List[str],
+                 C_prev_in: nn.MaybeChoice[int],
+                 C_in: nn.MaybeChoice[int],
+                 C: nn.MaybeChoice[int],
+                 num_nodes: int,
+                 merge_op: Literal['all', 'loose_end'],
                  first_cell_reduce: bool, last_cell_reduce: bool):
         self.C_prev_in = C_prev_in      # This is the out channels of the cell before last cell.
         self.C_in = C_in                # This is the out channesl of last cell.
         self.C = C                      # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
         self.op_candidates = op_candidates
         self.num_nodes = num_nodes
-        self.merge_op = merge_op
+        self.merge_op: Literal['all', 'loose_end'] = merge_op
         self.first_cell_reduce = first_cell_reduce
         self.last_cell_reduce = last_cell_reduce
         self._expect_idx = 0

@@ -312,7 +316,7 @@ class CellBuilder:
         # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
         preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)

-        ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
+        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
             op:  # make final chosen ops named with their aliases
             lambda node_index, op_index, input_index:
             OPS[op](self.C, 2 if is_reduction_cell and (

@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """

 class NDS(nn.Module):
-    """
+    __doc__ = """
     The unified version of NASNet search space.

     We follow the implementation in

@@ -378,8 +382,8 @@ class NDS(nn.Module):
                  op_candidates: List[str],
                  merge_op: Literal['all', 'loose_end'] = 'all',
                  num_nodes_per_cell: int = 4,
-                 width: Union[Tuple[int], int] = 16,
-                 num_cells: Union[Tuple[int], int] = 20,
+                 width: Union[Tuple[int, ...], int] = 16,
+                 num_cells: Union[Tuple[int, ...], int] = 20,
                  dataset: Literal['cifar', 'imagenet'] = 'imagenet',
                  auxiliary_loss: bool = False):
         super().__init__()

@@ -394,30 +398,31 @@ class NDS(nn.Module):
         else:
             C = width

+        self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
         if isinstance(num_cells, Iterable):
-            num_cells = nn.ValueChoice(list(num_cells), label='depth')
-        num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
+            self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
+        num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]

         # auxiliary head is different for network targetted at different datasets
         if dataset == 'imagenet':
             self.stem0 = nn.Sequential(
-                nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
-                nn.BatchNorm2d(C // 2),
+                nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, C // 2)),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             self.stem1 = nn.Sequential(
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             C_pprev = C_prev = C_curr = C
             last_cell_reduce = True
         elif dataset == 'cifar':
             self.stem = nn.Sequential(
-                nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
-                nn.BatchNorm2d(3 * C)
+                nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, 3 * C))
             )
             C_pprev = C_prev = 3 * C
             C_curr = C

@@ -439,7 +444,7 @@ class NDS(nn.Module):
             # C_pprev is output channel number of last second cell among all the cells already built.
             if len(stage) > 1:
                 # Contains more than one cell
-                C_pprev = len(stage[-2].output_node_indices) * C_curr
+                C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
             else:
                 # Look up in the out channels of last stage.
                 C_pprev = C_prev

@@ -447,7 +452,7 @@ class NDS(nn.Module):
             # This was originally,
             # C_prev = num_nodes_per_cell * C_curr.
             # but due to loose end, it becomes,
-            C_prev = len(stage[-1].output_node_indices) * C_curr
+            C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr

             # Useful in aligning the pprev and prev cell.
             last_cell_reduce = cell_builder.last_cell_reduce

@@ -457,11 +462,11 @@ class NDS(nn.Module):
         if auxiliary_loss:
             assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
-            self.stages[2] = SequentialBreakdown(self.stages[2])
-            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
+            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
+            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore

         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
-        self.classifier = nn.Linear(C_prev, self.num_labels)
+        self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)

     def forward(self, inputs):
         if self.dataset == 'imagenet':

@@ -483,7 +488,7 @@ class NDS(nn.Module):
         out = self.global_pooling(s1)
         logits = self.classifier(out.view(out.size(0), -1))
         if self.training and self.auxiliary_loss:
-            return logits, logits_aux
+            return logits, logits_aux  # type: ignore
         else:
             return logits

@@ -524,8 +529,8 @@ class NASNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.NASNET_OPS,

@@ -555,8 +560,8 @@ class ENAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.ENAS_OPS,

@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):

@@ -626,8 +631,8 @@ class PNAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.PNAS_OPS,

@@ -660,8 +665,8 @@ class DARTS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.DARTS_OPS,
...
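The `Tuple[int, ...]` fix matters because `Tuple[int]` means a tuple of exactly one int, so the old annotations rejected the three-element defaults such as `(16, 24, 32)`. For reference:

```python
from typing import Tuple

def total_width(width: Tuple[int, ...] = (16, 24, 32)) -> int:
    # Tuple[int] would only accept a 1-tuple such as (16,);
    # Tuple[int, ...] accepts any homogeneous tuple of ints.
    return sum(width)

print(total_width())        # 72
print(total_width((8, 8)))  # 16
```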
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import math
-from typing import Optional, Callable, List, Tuple
+from typing import Optional, Callable, List, Tuple, cast

 import torch
 import nni.retiarii.nn.pytorch as nn

@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        groups: nn.MaybeChoice[int] = 1,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
         dilation: int = 1,
     ) -> None:

@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super().__init__(
-            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_channels),
+            nn.Conv2d(
+                cast(int, in_channels),
+                cast(int, out_channels),
+                cast(int, kernel_size),
+                stride,
+                cast(int, padding),
+                dilation=dilation,
+                groups=cast(int, groups),
+                bias=False
+            ),
+            norm_layer(cast(int, out_channels)),
             activation_layer(inplace=True)
         )
         self.out_channels = out_channels

@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__(

@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        expand_ratio: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        expand_ratio: nn.MaybeChoice[float],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()

@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
         self.out_channels = out_channels
         assert stride in [1, 2]

-        hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
+        hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))

         # FIXME: check whether this equal works
         # Residual connection is added here stride = 1 and input channels and output channels are the same.

@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
         self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)

-        blocks = [
+        blocks: List[nn.Module] = [
             # first stage is fixed
             SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
         ]
...
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import cast
+
 import torch
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper

@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
     When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
                  kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
         super().__init__()
         assert stride in [1, 2]

@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
     def _decode_point_depth_conv(self, sequence):
         result = []
         first_depth = first_point = True
-        pc = c = self.channels
+        pc: int = self.channels
+        c: int = self.channels
         for i, token in enumerate(sequence):
             # compute output channels of this conv
             if i + 1 == len(sequence):
                 assert token == "p", "Last conv must be point-wise conv."
                 c = self.oup_main
             elif token == "p" and first_point:
-                c = self.mid_channels
+                c = cast(int, self.mid_channels)
             if token == "d":
                 # depth-wise conv
                 if isinstance(pc, int) and isinstance(c, int):

@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
     `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
         super().__init__(in_channels, out_channels, mid_channels,
                          kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)

@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
             nn.ReLU(inplace=True),
         )

-        self.features = []
+        feature_blocks = []

         global_block_idx = 0
         for stage_idx, num_repeat in enumerate(self.stage_repeats):

@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
                 else:
                     mid_channels = int(base_mid_channels)

+                mid_channels = cast(nn.MaybeChoice[int], mid_channels)
+
                 choice_block = nn.LayerChoice([
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
                     ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
                 ], label=f'layer_{global_block_idx}')
-                self.features.append(choice_block)
+                feature_blocks.append(choice_block)

-        self.features = nn.Sequential(*self.features)
+        self.features = nn.Sequential(*feature_blocks)

         # final layers
         last_conv_channels = self.stage_out_channels[-1]

@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.BatchNorm1d):
                 if m.weight is not None:
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.Linear):
                 torch.nn.init.normal_(m.weight, 0, 0.01)
                 if m.bias is not None:
...
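The new `if m.running_mean is not None` guards are more than checker appeasement: batch-norm layers created with `track_running_stats=False` really do have `running_mean = None`, which the stricter typehints now surface. A small sketch:

```python
import torch
import torch.nn as nn

for bn in (nn.BatchNorm2d(8), nn.BatchNorm2d(8, track_running_stats=False)):
    # running_mean is a tensor in the first case and None in the second,
    # so initialization code must guard before calling constant_().
    if bn.running_mean is not None:
        torch.nn.init.constant_(bn.running_mean, 0)
    print(type(bn.running_mean))
```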
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Useful type hints

@@ -3,7 +3,7 @@
 import logging
 import os
-from typing import Any, Callable
+from typing import Any, Callable, Optional

 import nni
 from nni.common.serializer import PayloadTooLarge

@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
         register_advisor(self)  # register the current advisor as the "global only" advisor
         self.search_space = None

-        self.send_trial_callback: Callable[[dict], None] = None
-        self.request_trial_jobs_callback: Callable[[int], None] = None
-        self.trial_end_callback: Callable[[int, bool], None] = None
-        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
-        self.final_metric_callback: Callable[[int, MetricData], None] = None
+        self.send_trial_callback: Optional[Callable[[dict], None]] = None
+        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
+        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
+        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
+        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

         self.parameters_count = 0

@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
     def handle_trial_end(self, data):
         _logger.debug('Trial end: %s', data)
-        self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
-                                data['event'] == 'SUCCEEDED')
+        if self.trial_end_callback is not None:
+            self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
+                                    data['event'] == 'SUCCEEDED')

     def handle_report_metric_data(self, data):
         _logger.debug('Metric reported: %s', data)
         if data['type'] == MetricType.REQUEST_PARAMETER:
             raise ValueError('Request parameter not supported')
         elif data['type'] == MetricType.PERIODICAL:
-            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                              self._process_value(data['value']))
+            if self.intermediate_metric_callback is not None:
+                self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                                  self._process_value(data['value']))
         elif data['type'] == MetricType.FINAL:
-            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                       self._process_value(data['value']))
+            if self.final_metric_callback is not None:
+                self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                           self._process_value(data['value']))

     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
...
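With the callbacks now typed `Optional[...]` and defaulting to `None`, every call site gains an explicit `is not None` guard, which both satisfies the checker and makes the advisor safe to drive before callbacks are registered. A minimal sketch of the pattern:

```python
from typing import Callable, Optional

class Advisor:
    def __init__(self) -> None:
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None

    def handle_trial_end(self, parameter_id: int, succeeded: bool) -> None:
        # The guard narrows Optional[...] to a callable and avoids
        # "NoneType is not callable" before registration.
        if self.trial_end_callback is not None:
            self.trial_end_callback(parameter_id, succeeded)

a = Advisor()
a.handle_trial_end(0, True)  # safe no-op before a callback is registered
```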
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Iterable, List, Optional, Tuple)
+import warnings
+from typing import (Any, Iterable, List, Optional, Tuple, cast)

 from .graph import Model, Mutation, ModelStatus

@@ -44,9 +45,11 @@ class Mutator:
     If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
     """

-    def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
+    def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
         self.sampler: Optional[Sampler] = sampler
-        self.label: Optional[str] = label
+        if label is None:
+            warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
+        self.label: str = label
         self._cur_model: Optional[Model] = None
         self._cur_choice_idx: Optional[int] = None
...