Unverified commit 18962129, authored by Yuge Zhang, committed by GitHub

Add license header and typehints for NAS (#4774)

parent 8c2f717d
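Most of the signature changes in this diff follow one pattern: parameters that receive a class to be instantiated later (a criterion class, an optimizer class) are re-annotated with `typing.Type`. A minimal sketch of the idea, with invented names (not code from this commit):

```python
from typing import Iterable, Type

import torch
import torch.nn as nn
import torch.optim as optim

def make_training_objects(params: Iterable[torch.Tensor],
                          criterion_cls: Type[nn.Module] = nn.CrossEntropyLoss,
                          optimizer_cls: Type[optim.Optimizer] = optim.Adam):
    # Both arguments are classes, not instances, so Type[...] is the
    # accurate hint; a bare nn.Module would describe an instance.
    return criterion_cls(), optimizer_cls(params, lr=1e-3)

criterion, optimizer = make_training_objects([torch.zeros(2, requires_grad=True)])
```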
@@ -4,10 +4,11 @@
 import os
 import warnings
 from pathlib import Path
-from typing import Dict, Union, Optional, List, Callable
+from typing import Dict, Union, Optional, List, Callable, Type
 import pytorch_lightning as pl
 import torch.nn as nn
+import torch.nn.functional as nn_functional
 import torch.optim as optim
 import torchmetrics
 import torch.utils.data as torch_data
@@ -124,12 +125,12 @@ class Lightning(Evaluator):
         if other is None:
             return False
         if hasattr(self, "function") and hasattr(other, "function"):
-            eq_func = (self.function == other.function)
+            eq_func = getattr(self, "function") == getattr(other, "function")
         elif not (hasattr(self, "function") or hasattr(other, "function")):
             eq_func = True
         if hasattr(self, "arguments") and hasattr(other, "arguments"):
-            eq_args = (self.arguments == other.arguments)
+            eq_args = getattr(self, "arguments") == getattr(other, "arguments")
         elif not (hasattr(self, "arguments") or hasattr(other, "arguments")):
             eq_args = True
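The switch to `getattr` above is presumably for the benefit of static checkers: `function` and `arguments` are only defined on some instances, and `getattr` returns `Any`, so the access is not flagged. A small illustration with invented names:

```python
class Trial:
    def __init__(self, function=None):
        if function is not None:
            self.function = function  # attribute exists only on some instances

def same_function(a: Trial, b: Trial) -> bool:
    # getattr returns Any, so checkers accept access to an attribute
    # that is not declared on the class itself.
    if hasattr(a, 'function') and hasattr(b, 'function'):
        return getattr(a, 'function') == getattr(b, 'function')
    return not (hasattr(a, 'function') or hasattr(b, 'function'))
```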
@@ -159,10 +160,13 @@ def _check_dataloader(dataloader):
 ### The following are some commonly used Lightning modules ###

 class _SupervisedLearningModule(LightningModule):
-    def __init__(self, criterion: nn.Module, metrics: Dict[str, torchmetrics.Metric],
+    trainer: pl.Trainer
+
+    def __init__(self, criterion: Type[nn.Module], metrics: Dict[str, Type[torchmetrics.Metric]],
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: Union[Path, str, bool, None] = None):
         super().__init__()
         self.save_hyperparameters('criterion', 'optimizer', 'learning_rate', 'weight_decay')
@@ -214,7 +218,7 @@ class _SupervisedLearningModule(LightningModule):
             self.log('test_' + name, metric(y_hat, y), prog_bar=True)

     def configure_optimizers(self):
-        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)
+        return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay)  # type: ignore

     def on_validation_epoch_end(self):
         nni.report_intermediate_result(self._get_validation_metrics())
@@ -233,15 +237,15 @@ class _SupervisedLearningModule(LightningModule):
 class _AccuracyWithLogits(torchmetrics.Accuracy):
     def update(self, pred, target):
-        return super().update(nn.functional.softmax(pred), target)
+        return super().update(nn_functional.softmax(pred), target)

 @nni.trace
 class _ClassificationModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'acc': _AccuracyWithLogits},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
@@ -275,10 +279,10 @@ class Classification(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.CrossEntropyLoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.CrossEntropyLoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
@@ -291,10 +295,10 @@ class Classification(Lightning):

 @nni.trace
 class _RegressionModule(_SupervisedLearningModule):
-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  export_onnx: bool = True):
         super().__init__(criterion, {'mse': torchmetrics.MeanSquaredError},
                          learning_rate=learning_rate, weight_decay=weight_decay, optimizer=optimizer,
@@ -328,10 +332,10 @@ class Regression(Lightning):
     `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
     """

-    def __init__(self, criterion: nn.Module = nn.MSELoss,
+    def __init__(self, criterion: Type[nn.Module] = nn.MSELoss,
                  learning_rate: float = 0.001,
                  weight_decay: float = 0.,
-                 optimizer: optim.Optimizer = optim.Adam,
+                 optimizer: Type[optim.Optimizer] = optim.Adam,
                  train_dataloader: Optional[DataLoader] = None,
                  val_dataloaders: Union[DataLoader, List[DataLoader], None] = None,
                  export_onnx: bool = True,
......
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from .api import *
@@ -129,6 +129,7 @@ class BaseExecutionEngine(AbstractExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation_summary = get_mutation_summary(model)
+        assert model.evaluator is not None, 'Model evaluator can not be None'
         return BaseGraphData(codegen.model_to_pytorch_script(model), model.evaluator, mutation_summary)

     @classmethod
......
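The asserts added in these execution-engine hunks double as type narrowing: after `assert x is not None`, a checker treats `x` as non-Optional for the rest of the scope, and the invariant fails fast at runtime. A generic sketch (names invented):

```python
from typing import Optional

def pack_evaluator(evaluator: Optional[str]) -> str:
    # The assert narrows Optional[str] to str for the type checker
    # and raises immediately if the invariant is ever violated.
    assert evaluator is not None, 'evaluator must be set before packing'
    return evaluator.upper()

print(pack_evaluator('functional'))  # FUNCTIONAL
```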
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import os
 import random
-from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable
+from typing import Dict, Any, List, Optional, Union, Tuple, Callable, Iterable, cast

 from ..graph import Model
 from ..integration_api import receive_trial_parameters
@@ -39,6 +42,9 @@ class BenchmarkGraphData:
     def load(data) -> 'BenchmarkGraphData':
         return BenchmarkGraphData(data['mutation'], data['benchmark'], data['metric_name'], data['db_path'])

+    def __repr__(self) -> str:
+        return f"BenchmarkGraphData({self.mutation}, {self.benchmark}, {self.db_path})"
+
 class BenchmarkExecutionEngine(BaseExecutionEngine):
     """
@@ -67,6 +73,7 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
     @classmethod
     def trial_execute_graph(cls) -> None:
         graph_data = BenchmarkGraphData.load(receive_trial_parameters())
+        assert graph_data.db_path is not None, f'Invalid graph data because db_path is None: {graph_data}'
         os.environ['NASBENCHMARK_DIR'] = graph_data.db_path
         final, intermediates = cls.query_in_benchmark(graph_data)
@@ -89,7 +96,6 @@ class BenchmarkExecutionEngine(BaseExecutionEngine):
                 arch = t
         if arch is None:
             raise ValueError(f'Cannot identify architecture from mutation dict: {graph_data.mutation}')
-        print(arch)
         return _convert_to_final_and_intermediates(
             query_nb101_trial_stats(arch, 108, include_intermediates=True),
             'valid_acc'
@@ -146,4 +152,5 @@ def _convert_to_final_and_intermediates(benchmark_result: Iterable[Any], metric_
         benchmark_result = random.choice(benchmark_result)
     else:
         benchmark_result = benchmark_result[0]
+    benchmark_result = cast(dict, benchmark_result)
     return benchmark_result[metric_name], [i[metric_name] for i in benchmark_result['intermediates'] if i[metric_name] is not None]
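`typing.cast`, used here and in many hunks below, is purely a hint for static analysis: at runtime it returns its second argument unchanged. A small sketch:

```python
from typing import Any, List, cast

def first_accuracy(results: List[Any]) -> float:
    result = results[0]
    # cast() has no runtime effect; it only tells the checker that
    # result is a dict so the subscript below type-checks.
    result = cast(dict, result)
    return result['valid_acc']

print(first_accuracy([{'valid_acc': 0.93}]))
```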
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import logging
 import os
 import random
......
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Dict, Any, Type

 import torch.nn as nn
@@ -49,7 +52,8 @@ class PurePythonExecutionEngine(BaseExecutionEngine):
     @classmethod
     def pack_model_data(cls, model: Model) -> Any:
         mutation = get_mutation_dict(model)
-        graph_data = PythonGraphData(model.python_class, model.python_init_params, mutation, model.evaluator)
+        assert model.evaluator is not None, 'Model evaluator is not available.'
+        graph_data = PythonGraphData(model.python_class, model.python_init_params or {}, mutation, model.evaluator)
         return graph_data

     @classmethod
......
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 from typing import Any, List

 from ..graph import Model
......
@@ -11,7 +11,7 @@ from dataclasses import dataclass
 from pathlib import Path
 from subprocess import Popen
 from threading import Thread
-from typing import Any, List, Optional, Union
+from typing import Any, List, Optional, Union, cast

 import colorama
 import psutil
@@ -23,6 +23,7 @@ from nni.experiment import Experiment, launcher, management, rest
 from nni.experiment.config import utils
 from nni.experiment.config.base import ConfigBase
 from nni.experiment.config.training_service import TrainingServiceConfig
+from nni.experiment.config.training_services import RemoteConfig
 from nni.experiment.pipe import Pipe
 from nni.tools.nnictl.command_utils import kill_command
@@ -222,6 +223,7 @@ class RetiariiExperiment(Experiment):
     Examples
     --------
     Multi-trial NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.Random()
     >>> model_evaluator = FunctionalEvaluator(evaluate_model)
@@ -233,6 +235,7 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config, 8081)

     One-shot NAS:
+
     >>> base_model = Net()
     >>> search_strategy = strategy.DARTS()
     >>> evaluator = pl.Classification(train_dataloader=train_loader, val_dataloaders=valid_loader)
@@ -242,15 +245,16 @@ class RetiariiExperiment(Experiment):
     >>> exp.run(exp_config)

     Export top models:
+
     >>> for model_dict in exp.export_top_models(formatter='dict'):
     ...     print(model_dict)
     >>> with nni.retarii.fixed_arch(model_dict):
     ...     final_model = Net()
     """

-    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = None,
-                 applied_mutators: List[Mutator] = None, strategy: BaseStrategy = None,
-                 trainer: BaseOneShotTrainer = None):
+    def __init__(self, base_model: nn.Module, evaluator: Union[BaseOneShotTrainer, Evaluator] = cast(Evaluator, None),
+                 applied_mutators: List[Mutator] = cast(List[Mutator], None), strategy: BaseStrategy = cast(BaseStrategy, None),
+                 trainer: BaseOneShotTrainer = cast(BaseOneShotTrainer, None)):
         if trainer is not None:
             warnings.warn('Usage of `trainer` in RetiariiExperiment is deprecated and will be removed soon. '
                           'Please consider specifying it as a positional argument, or use `evaluator`.', DeprecationWarning)
@@ -260,21 +264,22 @@ class RetiariiExperiment(Experiment):
                 raise ValueError('Evaluator should not be none.')

         # TODO: The current design of init interface of Retiarii experiment needs to be reviewed.
-        self.config: RetiariiExeConfig = None
+        self.config: RetiariiExeConfig = cast(RetiariiExeConfig, None)
         self.port: Optional[int] = None

         self.base_model = base_model
-        self.evaluator: Evaluator = evaluator
+        self.evaluator: Union[Evaluator, BaseOneShotTrainer] = evaluator
         self.applied_mutators = applied_mutators
         self.strategy = strategy

         # FIXME: this is only a workaround
         from nni.retiarii.oneshot.pytorch.strategy import OneShotStrategy
         if not isinstance(strategy, OneShotStrategy):
             self._dispatcher = RetiariiAdvisor()
-            self._dispatcher_thread: Optional[Thread] = None
-            self._proc: Optional[Popen] = None
-            self._pipe: Optional[Pipe] = None
+        else:
+            self._dispatcher = cast(RetiariiAdvisor, None)
+        self._dispatcher_thread: Optional[Thread] = None
+        self._proc: Optional[Popen] = None
+        self._pipe: Optional[Pipe] = None
         self.url_prefix = None
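The `cast(Evaluator, None)`-style defaults above are a deliberate trade-off: the declared type stays non-Optional so downstream code needs no None-checks, while `None` still serves as a "not provided yet" sentinel. A hedged sketch of the trade-off, with invented class names:

```python
from typing import cast

class Advisor:
    def advise(self) -> str:
        return 'ok'

class ExperimentSketch:
    def __init__(self, advisor: Advisor = cast(Advisor, None)):
        # Callers that always pass an advisor get check-free attribute
        # access; if the sentinel leaks through, the failure surfaces
        # as an AttributeError at the use site rather than a type error.
        self.advisor = advisor
```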
@@ -325,7 +330,7 @@ class RetiariiExperiment(Experiment):
             assert self.config.training_service.platform == 'remote', \
                 "CGO execution engine currently only supports remote training service"
-            assert self.config.batch_waiting_time is not None
+            assert self.config.batch_waiting_time is not None and self.config.max_concurrency_cgo is not None
             devices = self._construct_devices()
             engine = CGOExecutionEngine(devices,
                                         max_concurrency=self.config.max_concurrency_cgo,
@@ -335,7 +340,10 @@ class RetiariiExperiment(Experiment):
             engine = PurePythonExecutionEngine()
         elif self.config.execution_engine == 'benchmark':
             from ..execution.benchmark import BenchmarkExecutionEngine
+            assert self.config.benchmark is not None, '"benchmark" must be set when benchmark execution engine is used.'
             engine = BenchmarkExecutionEngine(self.config.benchmark)
+        else:
+            raise ValueError(f'Unsupported engine type: {self.config.execution_engine}')
         set_execution_engine(engine)

         self.id = management.generate_experiment_id()
@@ -377,9 +385,10 @@ class RetiariiExperiment(Experiment):
     def _construct_devices(self):
         devices = []
         if hasattr(self.config.training_service, 'machine_list'):
-            for machine in self.config.training_service.machine_list:
+            for machine in cast(RemoteConfig, self.config.training_service).machine_list:
                 assert machine.gpu_indices is not None, \
                     'gpu_indices must be set in RemoteMachineConfig for CGO execution engine'
+                assert isinstance(machine.gpu_indices, list), 'gpu_indices must be a list'
                 for gpu_idx in machine.gpu_indices:
                     devices.append(GPUDevice(machine.host, gpu_idx))
         return devices
@@ -387,7 +396,7 @@ class RetiariiExperiment(Experiment):
     def _create_dispatcher(self):
         return self._dispatcher

-    def run(self, config: RetiariiExeConfig = None, port: int = 8080, debug: bool = False) -> str:
+    def run(self, config: Optional[RetiariiExeConfig] = None, port: int = 8080, debug: bool = False) -> None:
         """
         Run the experiment.
         This function will block until experiment finish or error.
@@ -420,6 +429,7 @@ class RetiariiExperiment(Experiment):
         This function will block until experiment finish or error.
         Return `True` when experiment done; or return `False` when experiment failed.
         """
+        assert self._proc is not None
         try:
             while True:
                 time.sleep(10)
@@ -437,6 +447,7 @@ class RetiariiExperiment(Experiment):
             _logger.warning('KeyboardInterrupt detected')
         finally:
             self.stop()
+        raise RuntimeError('Check experiment status failed.')

     def stop(self) -> None:
         """
@@ -466,11 +477,11 @@ class RetiariiExperiment(Experiment):
         if self._pipe is not None:
             self._pipe.close()

-        self.id = None
-        self.port = None
+        self.id = cast(str, None)
+        self.port = cast(int, None)
         self._proc = None
         self._pipe = None
-        self._dispatcher = None
+        self._dispatcher = cast(RetiariiAdvisor, None)
         self._dispatcher_thread = None
         _logger.info('Experiment stopped')
......
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
 import json
 import logging
 from pathlib import Path
......
@@ -5,10 +5,16 @@
 Model representation.
 """

+from __future__ import annotations
+
 import abc
 import json
 from enum import Enum
-from typing import (Any, Dict, Iterable, List, Optional, Tuple, Type, Union, overload)
+from typing import (TYPE_CHECKING, Any, Callable, Dict, Iterable, List,
+                    Optional, Set, Tuple, Type, Union, cast, overload)
+
+if TYPE_CHECKING:
+    from .mutator import Mutator

 from .operation import Cell, Operation, _IOPseudoOperation
 from .utils import uid
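The `if TYPE_CHECKING:` block is the standard recipe for breaking an import cycle (here, between the graph and mutator modules): the import runs only under static analysis, and `from __future__ import annotations` keeps annotations lazy at runtime. Schematically, with a hypothetical package name:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Hypothetical sibling module, imported only by type checkers,
    # so the runtime import cycle never happens.
    from mypackage.mutator import Mutator

def apply_mutator(mutator: Mutator) -> None:
    # PEP 563 makes the annotation above a plain string at runtime,
    # so the missing runtime import is harmless.
    mutator.apply()
```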
@@ -63,7 +69,7 @@ class Evaluator(abc.ABC):
         pass

     @abc.abstractmethod
-    def _execute(self, model_cls: type) -> Any:
+    def _execute(self, model_cls: Union[Callable[[], Any], Any]) -> Any:
         pass

     @abc.abstractmethod
@@ -203,7 +209,7 @@ class Model:
             matched_nodes.extend(nodes)
         return matched_nodes

-    def get_node_by_name(self, node_name: str) -> 'Node':
+    def get_node_by_name(self, node_name: str) -> 'Node' | None:
         """
         Traverse all the nodes to find the matched node with the given name.
         """
@@ -217,7 +223,7 @@ class Model:
         else:
             return None

-    def get_node_by_python_name(self, python_name: str) -> 'Node':
+    def get_node_by_python_name(self, python_name: str) -> Optional['Node']:
         """
         Traverse all the nodes to find the matched node with the given python_name.
         """
@@ -297,7 +303,7 @@ class Graph:
         The name of torch.nn.Module, should have one-to-one mapping with items in python model.
     """

-    def __init__(self, model: Model, graph_id: int, name: str = None, _internal: bool = False):
+    def __init__(self, model: Model, graph_id: int, name: str = cast(str, None), _internal: bool = False):
         assert _internal, '`Graph()` is private'

         self.model: Model = model
@@ -338,9 +344,9 @@ class Graph:
     @overload
     def add_node(self, name: str, operation: Operation) -> 'Node': ...
     @overload
-    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def add_node(self, name: str, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def add_node(self, name, operation_or_type, parameters=None):
+    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:
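The `# type: ignore` on the implementation line is a common companion to `@overload`: the untyped implementation backs two typed signatures, and silencing the checker there is simpler than spelling out the union by hand. A compact sketch:

```python
from typing import Any, Dict, Optional, overload

class TinyGraph:
    @overload
    def add_node(self, name: str, operation: object) -> None: ...
    @overload
    def add_node(self, name: str, type_name: str, parameters: Optional[Dict[str, Any]] = None) -> None: ...

    def add_node(self, name, operation_or_type, parameters=None):  # type: ignore
        # One untyped implementation serves both declared signatures.
        pass
```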
@@ -350,9 +356,10 @@ class Graph:
     @overload
     def insert_node_on_edge(self, edge: 'Edge', name: str, operation: Operation) -> 'Node': ...
     @overload
-    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str, parameters: Dict[str, Any] = None) -> 'Node': ...
+    def insert_node_on_edge(self, edge: 'Edge', name: str, type_name: str,
+                            parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Node': ...

-    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':
+    def insert_node_on_edge(self, edge, name, operation_or_type, parameters=None) -> 'Node':  # type: ignore
         if isinstance(operation_or_type, Operation):
             op = operation_or_type
         else:
@@ -405,7 +412,7 @@ class Graph:
     def get_nodes_by_name(self, name: str) -> List['Node']:
         return [node for node in self.hidden_nodes if node.name == name]

-    def get_nodes_by_python_name(self, python_name: str) -> Optional['Node']:
+    def get_nodes_by_python_name(self, python_name: str) -> List['Node']:
         return [node for node in self.nodes if node.python_name == python_name]

     def topo_sort(self) -> List['Node']:
@@ -594,7 +601,7 @@ class Node:
         return sorted(set(edge.tail for edge in self.outgoing_edges), key=(lambda node: node.id))

     @property
-    def successor_slots(self) -> List[Tuple['Node', Union[int, None]]]:
+    def successor_slots(self) -> Set[Tuple['Node', Union[int, None]]]:
         return set((edge.tail, edge.tail_slot) for edge in self.outgoing_edges)

     @property
@@ -610,19 +617,19 @@ class Node:
         assert isinstance(self.operation, Cell)
         return self.graph.model.graphs[self.operation.parameters['cell']]

-    def update_label(self, label: str) -> None:
+    def update_label(self, label: Optional[str]) -> None:
         self.label = label

     @overload
     def update_operation(self, operation: Operation) -> None: ...
     @overload
-    def update_operation(self, type_name: str, parameters: Dict[str, Any] = None) -> None: ...
+    def update_operation(self, type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None)) -> None: ...

-    def update_operation(self, operation_or_type, parameters=None):
+    def update_operation(self, operation_or_type, parameters=None):  # type: ignore
         if isinstance(operation_or_type, Operation):
             self.operation = operation_or_type
         else:
-            self.operation = Operation.new(operation_or_type, parameters)
+            self.operation = Operation.new(operation_or_type, cast(dict, parameters))

     # mutation
     def remove(self) -> None:
@@ -663,7 +670,13 @@ class Node:
         return node

     def _dump(self) -> Any:
-        ret = {'operation': {'type': self.operation.type, 'parameters': self.operation.parameters, 'attributes': self.operation.attributes}}
+        ret: Dict[str, Any] = {
+            'operation': {
+                'type': self.operation.type,
+                'parameters': self.operation.parameters,
+                'attributes': self.operation.attributes
+            }
+        }
         if isinstance(self.operation, Cell):
             ret['operation']['cell_name'] = self.operation.cell_name
         if self.label is not None:
......
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import Tuple, Optional, Callable
+from typing import Tuple, Optional, Callable, cast

 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper
@@ -75,10 +75,10 @@ class MobileNetV3Space(nn.Module):
                  bn_momentum: float = 0.1):
         super().__init__()

-        self.widths = [
+        self.widths = cast(nn.ChoiceOf[int], [
             nn.ValueChoice([make_divisible(base_width * mult, 8) for mult in width_multipliers], label=f'width_{i}')
             for i, base_width in enumerate(base_widths)
-        ]
+        ])

         self.expand_ratios = expand_ratios

         blocks = [
@@ -115,7 +115,7 @@ class MobileNetV3Space(nn.Module):
         self.classifier = nn.Sequential(
             nn.Dropout(dropout_rate),
-            nn.Linear(self.widths[7], num_labels),
+            nn.Linear(cast(int, self.widths[7]), num_labels),
         )

         reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
......
@@ -3,6 +3,7 @@
 import math

 import torch
 import torch.nn as nn
+from nni.retiarii import model_wrapper
 from nni.retiarii.nn.pytorch import NasBench101Cell
@@ -11,7 +12,7 @@ from nni.retiarii.nn.pytorch import NasBench101Cell
 __all__ = ['NasBench101']

-def truncated_normal_(tensor, mean=0, std=1):
+def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
     # https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
     size = tensor.shape
     tmp = tensor.new_empty(size + (4,)).normal_()
@@ -117,9 +118,3 @@ class NasBench101(nn.Module):
         out = self.gap(out).view(bs, -1)
         out = self.classifier(out)
         return out
-
-    def reset_parameters(self):
-        for module in self.modules():
-            if isinstance(module, nn.BatchNorm2d):
-                module.eps = self.config.bn_eps
-                module.momentum = self.config.bn_momentum
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import Callable, Dict
+
 import torch
 import torch.nn as nn
@@ -176,8 +178,10 @@ class NasBench201(nn.Module):
             if reduction:
                 cell = ResNetBasicblock(C_prev, C_curr, 2)
             else:
-                cell = NasBench201Cell({prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES},
-                                       C_prev, C_curr, label='cell')
+                ops: Dict[str, Callable[[int, int], nn.Module]] = {
+                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                }
+                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
             self.cells.append(cell)
             C_prev = C_curr
......
@@ -8,7 +8,7 @@ It's called ``nasnet.py`` simply because NASNet is the first to propose such structure.
 """

 from collections import OrderedDict
-from typing import Tuple, List, Union, Iterable, Dict, Callable
+from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast

 try:
     from typing import Literal
@@ -250,14 +250,14 @@ class CellPreprocessor(nn.Module):
     See :class:`CellBuilder` on how to calculate those channel numbers.
     """

-    def __init__(self, C_pprev: int, C_prev: int, C: int, last_cell_reduce: bool) -> None:
+    def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
         super().__init__()

         if last_cell_reduce:
-            self.pre0 = FactorizedReduce(C_pprev, C)
+            self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
         else:
-            self.pre0 = ReLUConvBN(C_pprev, C, 1, 1, 0)
-        self.pre1 = ReLUConvBN(C_prev, C, 1, 1, 0)
+            self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
+        self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)

     def forward(self, cells):
         assert len(cells) == 2
@@ -283,15 +283,19 @@ class CellBuilder:
     Note that the builder is ephemeral, it can only be called once for every index.
     """

-    def __init__(self, op_candidates: List[str], C_prev_in: int, C_in: int, C: int,
-                 num_nodes: int, merge_op: Literal['all', 'loose_end'],
+    def __init__(self, op_candidates: List[str],
+                 C_prev_in: nn.MaybeChoice[int],
+                 C_in: nn.MaybeChoice[int],
+                 C: nn.MaybeChoice[int],
+                 num_nodes: int,
+                 merge_op: Literal['all', 'loose_end'],
                  first_cell_reduce: bool, last_cell_reduce: bool):
         self.C_prev_in = C_prev_in  # This is the out channels of the cell before last cell.
         self.C_in = C_in  # This is the out channesl of last cell.
         self.C = C  # This is NOT C_out of this stage, instead, C_out = C * len(cell.output_node_indices)
         self.op_candidates = op_candidates
         self.num_nodes = num_nodes
-        self.merge_op = merge_op
+        self.merge_op: Literal['all', 'loose_end'] = merge_op
         self.first_cell_reduce = first_cell_reduce
         self.last_cell_reduce = last_cell_reduce
         self._expect_idx = 0
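The explicit annotation on `self.merge_op` matters because checkers would otherwise widen the attribute to plain `str`, and it could then no longer be passed back into `Literal`-typed parameters. Sketch:

```python
from typing import Literal

MergeOp = Literal['all', 'loose_end']

class CellConfigSketch:
    def __init__(self, merge_op: MergeOp) -> None:
        # Without the explicit annotation the attribute is inferred as
        # str, which no longer satisfies MergeOp at later call sites.
        self.merge_op: MergeOp = merge_op
```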
@@ -312,7 +316,7 @@ class CellBuilder:
         # self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
         preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)

-        ops_factory: Dict[str, Callable[[int, int, int], nn.Module]] = {
+        ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {
             op:  # make final chosen ops named with their aliases
             lambda node_index, op_index, input_index:
                 OPS[op](self.C, 2 if is_reduction_cell and (
@@ -353,7 +357,7 @@ _INIT_PARAMETER_DOCS = """

 class NDS(nn.Module):
-    """
+    __doc__ = """
     The unified version of NASNet search space.

     We follow the implementation in
@@ -378,8 +382,8 @@ class NDS(nn.Module):
                  op_candidates: List[str],
                  merge_op: Literal['all', 'loose_end'] = 'all',
                  num_nodes_per_cell: int = 4,
-                 width: Union[Tuple[int], int] = 16,
-                 num_cells: Union[Tuple[int], int] = 20,
+                 width: Union[Tuple[int, ...], int] = 16,
+                 num_cells: Union[Tuple[int, ...], int] = 20,
                  dataset: Literal['cifar', 'imagenet'] = 'imagenet',
                  auxiliary_loss: bool = False):
         super().__init__()
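`Tuple[int]` means a tuple of exactly one int, so the old hints rejected defaults like `(16, 24, 32)`; `Tuple[int, ...]` means any number of ints. For example:

```python
from typing import Tuple, Union

widths: Tuple[int, ...] = (16, 24, 32)  # Tuple[int] would only allow (16,)

def first_width(width: Union[Tuple[int, ...], int]) -> int:
    return width if isinstance(width, int) else width[0]

print(first_width(widths))  # 16
```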
@@ -394,30 +398,31 @@ class NDS(nn.Module):
         else:
             C = width

+        self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
         if isinstance(num_cells, Iterable):
-            num_cells = nn.ValueChoice(list(num_cells), label='depth')
-        num_cells_per_stage = [i * num_cells // 3 - (i - 1) * num_cells // 3 for i in range(3)]
+            self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
+        num_cells_per_stage = [i * self.num_cells // 3 - (i - 1) * self.num_cells // 3 for i in range(3)]

         # auxiliary head is different for network targetted at different datasets
         if dataset == 'imagenet':
             self.stem0 = nn.Sequential(
-                nn.Conv2d(3, C // 2, kernel_size=3, stride=2, padding=1, bias=False),
-                nn.BatchNorm2d(C // 2),
+                nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, C // 2)),
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C // 2, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             self.stem1 = nn.Sequential(
                 nn.ReLU(inplace=True),
-                nn.Conv2d(C, C, 3, stride=2, padding=1, bias=False),
+                nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
                 nn.BatchNorm2d(C),
             )
             C_pprev = C_prev = C_curr = C
             last_cell_reduce = True
         elif dataset == 'cifar':
             self.stem = nn.Sequential(
-                nn.Conv2d(3, 3 * C, 3, padding=1, bias=False),
-                nn.BatchNorm2d(3 * C)
+                nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
+                nn.BatchNorm2d(cast(int, 3 * C))
             )
             C_pprev = C_prev = 3 * C
             C_curr = C
@@ -439,7 +444,7 @@ class NDS(nn.Module):
             # C_pprev is output channel number of last second cell among all the cells already built.
             if len(stage) > 1:
                 # Contains more than one cell
-                C_pprev = len(stage[-2].output_node_indices) * C_curr
+                C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
             else:
                 # Look up in the out channels of last stage.
                 C_pprev = C_prev
@@ -447,7 +452,7 @@ class NDS(nn.Module):
             # This was originally,
             # C_prev = num_nodes_per_cell * C_curr.
             # but due to loose end, it becomes,
-            C_prev = len(stage[-1].output_node_indices) * C_curr
+            C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr

             # Useful in aligning the pprev and prev cell.
             last_cell_reduce = cell_builder.last_cell_reduce
@@ -457,11 +462,11 @@ class NDS(nn.Module):
         if auxiliary_loss:
             assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
-            self.stages[2] = SequentialBreakdown(self.stages[2])
-            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)
+            self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
+            self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset)  # type: ignore

         self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
-        self.classifier = nn.Linear(C_prev, self.num_labels)
+        self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)

     def forward(self, inputs):
         if self.dataset == 'imagenet':
@@ -483,7 +488,7 @@ class NDS(nn.Module):
         out = self.global_pooling(s1)
         logits = self.classifier(out.view(out.size(0), -1))
         if self.training and self.auxiliary_loss:
-            return logits, logits_aux
+            return logits, logits_aux  # type: ignore
         else:
             return logits
@@ -524,8 +529,8 @@ class NASNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.NASNET_OPS,
@@ -555,8 +560,8 @@ class ENAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.ENAS_OPS,
@@ -590,8 +595,8 @@ class AmoebaNet(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
@@ -626,8 +631,8 @@ class PNAS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.PNAS_OPS,
@@ -660,8 +665,8 @@ class DARTS(NDS):
     ]

     def __init__(self,
-                 width: Union[Tuple[int], int] = (16, 24, 32),
-                 num_cells: Union[Tuple[int], int] = (4, 8, 12, 16, 20),
+                 width: Union[Tuple[int, ...], int] = (16, 24, 32),
+                 num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
                  dataset: Literal['cifar', 'imagenet'] = 'cifar',
                  auxiliary_loss: bool = False):
         super().__init__(self.DARTS_OPS,
......
@@ -2,7 +2,7 @@
 # Licensed under the MIT license.

 import math
-from typing import Optional, Callable, List, Tuple
+from typing import Optional, Callable, List, Tuple, cast

 import torch
 import nni.retiarii.nn.pytorch as nn
@@ -31,12 +31,12 @@ class ConvBNReLU(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        groups: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        groups: nn.MaybeChoice[int] = 1,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
         dilation: int = 1,
     ) -> None:
@@ -46,9 +46,17 @@ class ConvBNReLU(nn.Sequential):
         if activation_layer is None:
             activation_layer = nn.ReLU6
         super().__init__(
-            nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation=dilation, groups=groups,
-                      bias=False),
-            norm_layer(out_channels),
+            nn.Conv2d(
+                cast(int, in_channels),
+                cast(int, out_channels),
+                cast(int, kernel_size),
+                stride,
+                cast(int, padding),
+                dilation=dilation,
+                groups=cast(int, groups),
+                bias=False
+            ),
+            norm_layer(cast(int, out_channels)),
             activation_layer(inplace=True)
         )
         self.out_channels = out_channels
@@ -62,11 +70,11 @@ class SeparableConv(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__(
@@ -101,13 +109,13 @@ class InvertedResidual(nn.Sequential):
     def __init__(
         self,
-        in_channels: int,
-        out_channels: int,
-        expand_ratio: int,
-        kernel_size: int = 3,
+        in_channels: nn.MaybeChoice[int],
+        out_channels: nn.MaybeChoice[int],
+        expand_ratio: nn.MaybeChoice[float],
+        kernel_size: nn.MaybeChoice[int] = 3,
         stride: int = 1,
-        squeeze_and_excite: Optional[Callable[[int], nn.Module]] = None,
-        norm_layer: Optional[Callable[..., nn.Module]] = None,
+        squeeze_and_excite: Optional[Callable[[nn.MaybeChoice[int]], nn.Module]] = None,
+        norm_layer: Optional[Callable[[int], nn.Module]] = None,
         activation_layer: Optional[Callable[..., nn.Module]] = None,
     ) -> None:
         super().__init__()
@@ -115,7 +123,7 @@ class InvertedResidual(nn.Sequential):
         self.out_channels = out_channels
         assert stride in [1, 2]

-        hidden_ch = nn.ValueChoice.to_int(round(in_channels * expand_ratio))
+        hidden_ch = nn.ValueChoice.to_int(round(cast(int, in_channels * expand_ratio)))

         # FIXME: check whether this equal works
         # Residual connection is added here stride = 1 and input channels and output channels are the same.
@@ -215,7 +223,7 @@ class ProxylessNAS(nn.Module):
         self.first_conv = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)

-        blocks = [
+        blocks: List[nn.Module] = [
             # first stage is fixed
             SeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
         ]
......
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+from typing import cast
+
 import torch
 import nni.retiarii.nn.pytorch as nn
 from nni.retiarii import model_wrapper
@@ -14,7 +16,7 @@ class ShuffleNetBlock(nn.Module):
     When stride = 1, the block expects an input with ``2 * input channels``. Otherwise input channels.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *,
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
                  kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
         super().__init__()
         assert stride in [1, 2]
@@ -57,14 +59,15 @@ class ShuffleNetBlock(nn.Module):
     def _decode_point_depth_conv(self, sequence):
         result = []
         first_depth = first_point = True
-        pc = c = self.channels
+        pc: int = self.channels
+        c: int = self.channels
         for i, token in enumerate(sequence):
             # compute output channels of this conv
             if i + 1 == len(sequence):
                 assert token == "p", "Last conv must be point-wise conv."
                 c = self.oup_main
             elif token == "p" and first_point:
-                c = self.mid_channels
+                c = cast(int, self.mid_channels)
             if token == "d":
                 # depth-wise conv
                 if isinstance(pc, int) and isinstance(c, int):
@@ -101,7 +104,7 @@ class ShuffleXceptionBlock(ShuffleNetBlock):
     `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
     """

-    def __init__(self, in_channels: int, out_channels: int, mid_channels: int, *, stride: int, affine: bool = True):
+    def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
         super().__init__(in_channels, out_channels, mid_channels,
                          kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
@@ -154,7 +157,7 @@ class ShuffleNetSpace(nn.Module):
             nn.ReLU(inplace=True),
         )

-        self.features = []
+        feature_blocks = []

         global_block_idx = 0
         for stage_idx, num_repeat in enumerate(self.stage_repeats):
@@ -175,15 +178,17 @@ class ShuffleNetSpace(nn.Module):
                 else:
                     mid_channels = int(base_mid_channels)

+                mid_channels = cast(nn.MaybeChoice[int], mid_channels)
+
                 choice_block = nn.LayerChoice([
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
                     ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
                     ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
                 ], label=f'layer_{global_block_idx}')
-                self.features.append(choice_block)
+                feature_blocks.append(choice_block)

-        self.features = nn.Sequential(*self.features)
+        self.features = nn.Sequential(*feature_blocks)

         # final layers
         last_conv_channels = self.stage_out_channels[-1]
@@ -226,13 +231,15 @@ class ShuffleNetSpace(nn.Module):
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.BatchNorm1d):
                 if m.weight is not None:
                     torch.nn.init.constant_(m.weight, 1)
                 if m.bias is not None:
                     torch.nn.init.constant_(m.bias, 0.0001)
-                torch.nn.init.constant_(m.running_mean, 0)
+                if m.running_mean is not None:
+                    torch.nn.init.constant_(m.running_mean, 0)
             elif isinstance(m, nn.Linear):
                 torch.nn.init.normal_(m.weight, 0, 0.01)
                 if m.bias is not None:
......
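The new `running_mean` guards reflect that the attribute is `Optional[Tensor]` in PyTorch: it genuinely is `None` when running statistics are disabled, so the unconditional `constant_` call was unsafe both for mypy and at runtime. A quick demonstration:

```python
import torch.nn as nn

bn_tracked = nn.BatchNorm2d(8)
bn_untracked = nn.BatchNorm2d(8, track_running_stats=False)

print(bn_tracked.running_mean is None)    # False: a zero-initialized tensor
print(bn_untracked.running_mean is None)  # True: hence the None-guard
```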
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+# Useful type hints
@@ -3,7 +3,7 @@
 import logging
 import os
-from typing import Any, Callable
+from typing import Any, Callable, Optional

 import nni
 from nni.common.serializer import PayloadTooLarge
@@ -53,11 +53,11 @@ class RetiariiAdvisor(MsgDispatcherBase):
         register_advisor(self)  # register the current advisor as the "global only" advisor
         self.search_space = None

-        self.send_trial_callback: Callable[[dict], None] = None
-        self.request_trial_jobs_callback: Callable[[int], None] = None
-        self.trial_end_callback: Callable[[int, bool], None] = None
-        self.intermediate_metric_callback: Callable[[int, MetricData], None] = None
-        self.final_metric_callback: Callable[[int, MetricData], None] = None
+        self.send_trial_callback: Optional[Callable[[dict], None]] = None
+        self.request_trial_jobs_callback: Optional[Callable[[int], None]] = None
+        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None
+        self.intermediate_metric_callback: Optional[Callable[[int, MetricData], None]] = None
+        self.final_metric_callback: Optional[Callable[[int, MetricData], None]] = None

         self.parameters_count = 0
@@ -158,19 +158,22 @@ class RetiariiAdvisor(MsgDispatcherBase):
     def handle_trial_end(self, data):
         _logger.debug('Trial end: %s', data)
-        self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
-                                data['event'] == 'SUCCEEDED')
+        if self.trial_end_callback is not None:
+            self.trial_end_callback(nni.load(data['hyper_params'])['parameter_id'],  # pylint: disable=not-callable
+                                    data['event'] == 'SUCCEEDED')

     def handle_report_metric_data(self, data):
         _logger.debug('Metric reported: %s', data)
         if data['type'] == MetricType.REQUEST_PARAMETER:
             raise ValueError('Request parameter not supported')
         elif data['type'] == MetricType.PERIODICAL:
-            self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                              self._process_value(data['value']))
+            if self.intermediate_metric_callback is not None:
+                self.intermediate_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                                  self._process_value(data['value']))
         elif data['type'] == MetricType.FINAL:
-            self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
-                                       self._process_value(data['value']))
+            if self.final_metric_callback is not None:
+                self.final_metric_callback(data['parameter_id'],  # pylint: disable=not-callable
+                                           self._process_value(data['value']))

     @staticmethod
     def _process_value(value) -> Any:  # hopefully a float
......
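The Optional annotations and the None-guards in the dispatcher hunks above are two halves of one refactoring: the callbacks start unset, the types now say so, and every call site checks before calling. Condensed into a sketch:

```python
from typing import Callable, Optional

class DispatcherSketch:
    def __init__(self) -> None:
        # Unset until the engine registers it; Optional makes that explicit.
        self.trial_end_callback: Optional[Callable[[int, bool], None]] = None

    def handle_trial_end(self, trial_id: int, succeeded: bool) -> None:
        if self.trial_end_callback is not None:
            self.trial_end_callback(trial_id, succeeded)
```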
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

-from typing import (Any, Iterable, List, Optional, Tuple)
+import warnings
+from typing import (Any, Iterable, List, Optional, Tuple, cast)

 from .graph import Model, Mutation, ModelStatus
@@ -44,9 +45,11 @@ class Mutator:
     If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
     """

-    def __init__(self, sampler: Optional[Sampler] = None, label: Optional[str] = None):
+    def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
         self.sampler: Optional[Sampler] = sampler
-        self.label: Optional[str] = label
+        if label is None:
+            warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
+        self.label: str = label
         self._cur_model: Optional[Model] = None
         self._cur_choice_idx: Optional[int] = None
......