Unverified Commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import copy
import random
from typing import Any, List, Dict, Sequence, cast
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation
__all__ = [
'PathSamplingLayer', 'PathSamplingInput',
'PathSamplingRepeat', 'PathSamplingCell',
'MixedOpPathSamplingPolicy'
]
class PathSamplingLayer(BaseSuperNetModule):
"""
Mixed layer, in which the forward pass is computed by exactly one inner layer, or a sum of multiple (sampled) layers.
If multiple modules are selected, their results are summed and returned.
Attributes
----------
_sampled : str or list of str
    Sampled module name(s).
label : str
Name of the choice.
"""
def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
super().__init__()
self.op_names = []
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self._sampled: list[str] | str | None = None  # sampled can be either a list of names or a single name
self.label = label
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = random.choice(self.op_names)
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: random.choice(self.op_names)}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
return cls(list(module.named_children()), module.label)
def reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def forward(self, *args, **kwargs):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
# str(samp) is needed because samp can sometimes be an integer, while attribute names are always str
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
return self.reduction(res, sampled)
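# A minimal usage sketch (illustrative only; the candidate names and the label
# below are hypothetical): ``resample`` picks one registered path and ``forward``
# routes through it.
def _path_sampling_layer_example():
    layer = PathSamplingLayer(
        [('conv', nn.Conv2d(3, 8, 3, padding=1)), ('identity', nn.Identity())],
        label='demo_layer')
    print(layer.resample(memo={}))           # e.g. {'demo_layer': 'conv'}
    return layer(torch.randn(1, 3, 32, 32))  # forward through the sampled path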
class PathSamplingInput(BaseSuperNetModule):
"""
Mixed input. Takes a list of tensors as input, selects some of them, and returns the reduction (e.g., the sum).
Attributes
----------
_sampled : int or list of int
Sampled input indices.
"""
def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
super().__init__()
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction_type = reduction_type
self._sampled: list[int] | int | None = None
self.label = label
def _random_choose_n(self):
sampling = list(range(self.n_candidates))
random.shuffle(sampling)
sampling = sorted(sampling[:self.n_chosen])
if len(sampling) == 1:
return sampling[0]
else:
return sampling
def resample(self, memo):
"""Random choose one path / multiple paths if label is not found in memo.
If one path is selected, only one integer will be in ``self._sampled``.
If multiple paths are selected, a list will be in ``self._sampled``.
"""
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = self._random_choose_n()
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self._random_choose_n()}
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, InputChoice):
if module.reduction not in ['sum', 'mean', 'concat']:
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
if module.n_chosen is None:
raise ValueError('n_chosen=None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
"""Override this to implement customized reduction."""
if len(items) == 1:
return items[0]
else:
if self.reduction_type == 'sum':
return sum(items)
elif self.reduction_type == 'mean':
return sum(items) / len(items)
elif self.reduction_type == 'concat':
return torch.cat(items, 1)
raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
def forward(self, input_tensors):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
if len(input_tensors) != self.n_candidates:
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = [input_tensors[samp] for samp in sampled]
return self.reduction(res, sampled)
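# Illustrative sketch (the label is hypothetical): choose 2 of 3 candidate
# tensors and reduce them with a sum.
def _path_sampling_input_example():
    inp = PathSamplingInput(n_candidates=3, n_chosen=2, reduction_type='sum', label='demo_input')
    print(inp.resample(memo={}))                        # e.g. {'demo_input': [0, 2]}
    return inp([torch.randn(1, 4) for _ in range(3)])   # sum of the two sampled tensors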
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
"""Implementes the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices".
We sample the leaf nodes, and composits them into the values on arguments.
"""
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys as `operation.mutable_arguments`
self._sampled: dict[str, Any] | None = None
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Random sample for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(space_spec[label].values)
# compose into kwargs
# example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
self._sampled = {}
for key, value in operation.mutable_arguments.items():
self._sampled[key] = evaluate_value_choice_with_dict(value, result)
return result
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
if label not in memo:
result[label] = random.choice(space_spec[label].values)
return result
def forward_argument(self, operation: MixedOperation, name: str) -> Any:
# NOTE: we don't support sampling a list here.
if self._sampled is None:
raise ValueError('Need to call resample() before running forward')
if name in operation.mutable_arguments:
return self._sampled[name]
return operation.init_arguments[name]
class PathSamplingRepeat(BaseSuperNetModule):
"""
Implementation of ``Repeat`` in a path-sampling supernet.
Samples one or several prefixes of the repeated blocks.
Attributes
----------
_sampled : int or list of int
Sampled depth.
"""
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._sampled: list[int] | int | None = None
def resample(self, memo):
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
result = {}
for label in self._space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(self._space_spec[label].values)
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
return result
def export(self, memo):
"""Random choose one if every choice not in memo."""
result = {}
for label in self._space_spec:
if label not in memo:
result[label] = random.choice(self._space_spec[label].values)
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def forward(self, x):
if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = []
for cur_depth, block in enumerate(self.blocks, start=1):
x = block(x)
if cur_depth in sampled:
res.append(x)
if not any(d > cur_depth for d in sampled):
break
return self.reduction(res, sampled)
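# Sketch of the sampled-prefix semantics above (instance construction is omitted,
# since ``depth`` must be a value choice created elsewhere): if depth 2 is sampled
# out of 3 blocks, forward runs block 1 and block 2 and then stops.
def _path_sampling_repeat_example(repeat: 'PathSamplingRepeat', x: torch.Tensor):
    repeat.resample(memo={})  # samples a depth from the underlying value choice
    return repeat(x)          # runs only the first ``depth`` blocks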
class PathSamplingCell(BaseSuperNetModule):
"""The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.
When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
because for a different input index, the cell factory could instantiate different operations (e.g., with different strides).
On export, we first pick the best (operation, input) pairs, then select the best ``num_ops_per_node`` of them.
``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
We assume ``loose_end`` to be ``all`` regardless of its configuration.
A supernet cell can't slim its own weights to fit into a sub-network, which is also a known issue.
"""
def __init__(
self,
op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
num_nodes: int,
num_ops_per_node: int,
num_predecessors: int,
preprocessor: Any,
postprocessor: Any,
concat_dim: int,
memo: dict, # although not used here, useful in subclass
mutate_kwargs: dict, # same as memo
label: str,
):
super().__init__()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
self.preprocessor = preprocessor
self.ops = nn.ModuleList()
self.postprocessor = postprocessor
self.concat_dim = concat_dim
self.op_names: list[str] = cast(List[str], None)
self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))
# Create a fully-connected graph.
# Each edge is a ModuleDict with op candidates.
# Can not reuse LayerChoice here, because the spec, resample, export all need to be customized.
# InputChoice is implicit in this graph.
for i in self.output_node_indices:
self.ops.append(nn.ModuleList())
for k in range(i + self.num_predecessors):
# Second argument in (i, **0**, k) is always 0.
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
self.op_names = list(ops.keys())
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
self.label = label
self._sampled: dict[str, str | int] = {}
def search_space_spec(self) -> dict[str, ParameterSpec]:
# TODO: Recreating the space here.
# The spec should be moved to definition of Cell itself.
space_spec = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'{self.label}/op_{i}_{k}'
input_label = f'{self.label}/input_{i}_{k}'
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
return space_spec
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
self._sampled = {}
new_sampled = {}
for label, param_spec in self.search_space_spec().items():
if label in memo:
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
self._sampled[label] = memo[label]
else:
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
return new_sampled
def export(self, memo):
"""Randomly choose one to export."""
return self.resample(memo)
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for k in range(self.num_ops_per_node):
# Select op list based on the input chosen
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
op_candidates = ops[cast(int, input_index)]
# Select op from op list based on the op chosen
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
op = op_candidates[cast(str, op_index)]
current_state.append(op(states[cast(int, input_index)]))
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
op_factory = module.op_candidates_factory
assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
'Only support op_factory of type list or dict.'
elif module.merge_op == 'loose_end':
op_candidates_lc = module.ops[-1][-1] # type: ignore
assert isinstance(op_candidates_lc, LayerChoice)
op_factory = {  # create a factory; bind ``name`` eagerly to avoid Python's late-binding closure pitfall
    name: lambda _, __, ___, _name=name: copy.deepcopy(op_candidates_lc[_name])
    for name in op_candidates_lc.names
}
if op_factory is not None:
return cls(
op_factory,
module.num_nodes,
module.num_ops_per_node,
module.num_predecessors,
module.preprocessor,
module.postprocessor,
module.concat_dim,
memo,
mutate_kwargs,
module.label
)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base_mutator import BaseMutator
from .base_trainer import BaseTrainer
from .fixed import apply_fixed_architecture
from .mutables import Mutable, LayerChoice, InputChoice
from .mutator import Mutator
from .trainer import Trainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
from nni.nas.pytorch.utils import StructuredMutableTreeNode
logger = logging.getLogger(__name__)
class BaseMutator(nn.Module):
"""
A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
callbacks that are called in ``forward`` in mutables.
Parameters
----------
model : nn.Module
PyTorch model to apply mutator on.
"""
def __init__(self, model):
super().__init__()
self.__dict__["model"] = model
self._structured_mutables = self._parse_search_space(self.model)
def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
@property
def mutables(self):
"""
A generator of all modules inheriting :class:`~nni.nas.pytorch.mutables.Mutable`.
Modules are yielded in the order that they are defined in ``__init__``.
For mutables whose keys appear multiple times, only the first occurrence will be yielded.
"""
return self._structured_mutables
@property
def undedup_mutables(self):
return self._structured_mutables.traverse(deduplicate=False)
def forward(self, *inputs):
"""
Warnings
--------
Don't call forward of a mutator.
"""
raise RuntimeError("Forward is undefined for mutators.")
def __setattr__(self, name, value):
if name == "model":
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include you network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope : MutableScope
The mutable scope that is entered.
"""
pass
def exit_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is exited.
Parameters
----------
mutable_scope : MutableScope
The mutable scope that is exited.
"""
pass
def on_forward_layer_choice(self, mutable, *args, **kwargs):
"""
Callbacks of forward in LayerChoice.
Parameters
----------
mutable : nni.nas.pytorch.mutables.LayerChoice
Module whose forward is called.
args : list of torch.Tensor
The arguments of its forward function.
kwargs : dict
The keyword arguments of its forward function.
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output tensor and mask.
"""
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list):
"""
Callbacks of forward in InputChoice.
Parameters
----------
mutable : nni.nas.pytorch.mutables.InputChoice
Mutable that is called.
tensor_list : list of torch.Tensor
The arguments mutable is called with.
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output tensor and mask.
"""
raise NotImplementedError
def export(self):
"""
Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
network can be fully determined with these decisions for further training from scratch.
Returns
-------
dict
Mappings from mutable keys to decisions.
"""
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
class BaseTrainer(ABC):
@abstractmethod
def train(self):
"""
Override the method to train.
"""
raise NotImplementedError
@abstractmethod
def validate(self):
"""
Override the method to validate.
"""
raise NotImplementedError
@abstractmethod
def export(self, file):
"""
Override the method to export to file.
Parameters
----------
file : str
File path to export to.
"""
raise NotImplementedError
@abstractmethod
def checkpoint(self):
"""
Override to dump a checkpoint.
"""
raise NotImplementedError
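# A skeletal subclass sketch showing the abstract contract (the class and its
# bodies are placeholders, not part of the original module):
class _NoopTrainer(BaseTrainer):
    def train(self):
        pass              # one search/training phase
    def validate(self):
        pass              # evaluation phase
    def export(self, file):
        pass              # dump the chosen architecture to ``file``
    def checkpoint(self):
        pass              # dump trainer state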
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
import torch
import torch.nn as nn
_logger = logging.getLogger(__name__)
class Callback:
"""
Callback provides an easy way to react to events like begin/end of epochs.
"""
def __init__(self):
self.model = None
self.mutator = None
self.trainer = None
def build(self, model, mutator, trainer):
"""
Callback needs to be built with model, mutator, trainer, to get updates from them.
Parameters
----------
model : nn.Module
Model to be trained.
mutator : nn.Module
Mutator that mutates the model.
trainer : BaseTrainer
Trainer that is to call the callback.
"""
self.model = model
self.mutator = mutator
self.trainer = trainer
def on_epoch_begin(self, epoch):
"""
Implement this to do something at the beginning of an epoch.
Parameters
----------
epoch : int
Epoch number, starting from 0.
"""
pass
def on_epoch_end(self, epoch):
"""
Implement this to do something at the end of an epoch.
Parameters
----------
epoch : int
Epoch number, starting from 0.
"""
pass
def on_batch_begin(self, epoch):
pass
def on_batch_end(self, epoch):
pass
class LRSchedulerCallback(Callback):
"""
Calls ``scheduler.step()`` at the end of every epoch.
Parameters
----------
scheduler : LRScheduler
Scheduler to be called.
"""
def __init__(self, scheduler, mode="epoch"):
super().__init__()
assert mode == "epoch"
self.scheduler = scheduler
self.mode = mode
def on_epoch_end(self, epoch):
"""
Call ``self.scheduler.step()`` on epoch end.
"""
self.scheduler.step()
class ArchitectureCheckpoint(Callback):
"""
Calls ``trainer.export()`` at the end of every epoch.
Parameters
----------
checkpoint_dir : str
Location to save checkpoints.
"""
def __init__(self, checkpoint_dir):
super().__init__()
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
"""
Dump the architecture to ``{checkpoint_dir}/epoch_{number}.json`` on epoch end.
"""
dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))
_logger.info("Saving architecture to %s", dest_path)
self.trainer.export(dest_path)
class ModelCheckpoint(Callback):
"""
Saves the model state dict at the end of every epoch.
Parameters
----------
checkpoint_dir : str
Location to save checkpoints.
"""
def __init__(self, checkpoint_dir):
super().__init__()
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
"""
Dump the state dict to ``{checkpoint_dir}/epoch_{number}.pth.tar`` on every epoch end.
``DataParallel`` objects will have their inner modules exported.
"""
if isinstance(self.model, nn.DataParallel):
state_dict = self.model.module.state_dict()
else:
state_dict = self.model.state_dict()
dest_path = os.path.join(self.checkpoint_dir, "epoch_{}.pth.tar".format(epoch))
_logger.info("Saving model to %s", dest_path)
torch.save(state_dict, dest_path)
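# Usage sketch (``model`` and the trainer wiring are hypothetical): callbacks are
# constructed by the user and handed to a trainer, which calls ``build`` and then
# the ``on_*`` hooks, e.g.:
#
#     optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
#     scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=50)
#     callbacks = [LRSchedulerCallback(scheduler), ArchitectureCheckpoint('./checkpoints')]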
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
from .mutables import InputChoice, LayerChoice, MutableScope
from .mutator import Mutator
from .utils import to_list
_logger = logging.getLogger(__name__)
class FixedArchitecture(Mutator):
"""
Fixed architecture mutator that always selects a certain graph.
Parameters
----------
model : nn.Module
A mutable network.
fixed_arc : dict
Preloaded architecture object.
strict : bool
Force everything that appears in ``fixed_arc`` to be used at least once.
verbose : bool
Print log messages if set to True
"""
def __init__(self, model, fixed_arc, strict=True, verbose=True):
super().__init__(model)
self._fixed_arc = fixed_arc
self.verbose = verbose
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
if fixed_arc_keys - mutable_keys:
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
self._fixed_arc = self._from_human_readable_architecture(self._fixed_arc)
def _from_human_readable_architecture(self, human_arc):
# convert from an exported architecture
result_arc = {k: to_list(v) for k, v in human_arc.items()} # there could be tensors, numpy arrays, etc.
# First, convert non-list to list, because there could be {"op1": 0} or {"op1": "conv"},
# which means {"op1": [0, ]} or {"op1": ["conv", ]}
result_arc = {k: v if isinstance(v, list) else [v] for k, v in result_arc.items()}
# Second, infer which ones are multi-hot arrays and which ones are in human-readable format.
# This is non-trivial, since if an array is [0, 1], we cannot know for sure whether it means [false, true] or [true, true].
# Here, we assume a multi-hot array has to be a boolean or float array whose length matches the number of candidates.
for mutable in self.mutables:
if mutable.key not in result_arc:
continue # skip silently
choice_arr = result_arc[mutable.key]
if all(isinstance(v, bool) for v in choice_arr) or all(isinstance(v, float) for v in choice_arr):
if (isinstance(mutable, LayerChoice) and len(mutable) == len(choice_arr)) or \
(isinstance(mutable, InputChoice) and mutable.n_candidates == len(choice_arr)):
# multihot, do nothing
continue
if isinstance(mutable, LayerChoice):
choice_arr = [mutable.names.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(len(mutable))]
elif isinstance(mutable, InputChoice):
choice_arr = [mutable.choose_from.index(val) if isinstance(val, str) else val for val in choice_arr]
choice_arr = [i in choice_arr for i in range(mutable.n_candidates)]
result_arc[mutable.key] = choice_arr
return result_arc
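    # For instance (hypothetical keys), a human-readable export such as
    #     {"op1": "conv3x3", "input1": [0, 2]}
    # is converted into multi-hot arrays like
    #     {"op1": [False, True, False], "input1": [True, False, True]}
    # assuming "op1" is a 3-way LayerChoice whose second candidate is named
    # "conv3x3" and "input1" is a 3-candidate InputChoice.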
def sample_search(self):
"""
Always returns the fixed architecture.
"""
return self._fixed_arc
def sample_final(self):
"""
Always returns the fixed architecture.
"""
return self._fixed_arc
def replace_layer_choice(self, module=None, prefix=""):
"""
Replace layer choices with the selected candidates. This is done on a best-effort basis.
In case of weighted or multiple choices, candidates weighted with zero are deleted.
If a single choice is selected, the module is replaced with that candidate.
Parameters
----------
module : nn.Module
Module to be processed.
prefix : str
Module name under global namespace.
"""
if module is None:
module = self.model
for name, mutable in module.named_children():
global_name = (prefix + "." if prefix else "") + name
if isinstance(mutable, LayerChoice):
chosen = self._fixed_arc[mutable.key]
if sum(chosen) == 1 and max(chosen) == 1 and not mutable.return_mask:
# sum is one, max is one, there has to be an only one
# this is compatible with both integer arrays, boolean arrays and float arrays
if self.verbose:
_logger.info("Replacing %s with candidate number %d.", global_name, chosen.index(1))
setattr(module, name, mutable[chosen.index(1)])
else:
if mutable.return_mask and self.verbose:
_logger.info("`return_mask` flag of %s is true. As it relies on the behavior of LayerChoice, " \
"LayerChoice will not be replaced.")
# remove unused parameters
for ch, n in zip(chosen, mutable.names):
if ch == 0 and not isinstance(ch, float):
setattr(mutable, n, None)
else:
self.replace_layer_choice(mutable, global_name)
def apply_fixed_architecture(model, fixed_arc, verbose=True):
"""
Load architecture from `fixed_arc` and apply to model.
Parameters
----------
model : torch.nn.Module
Model with mutables.
fixed_arc : str or dict
Path to the JSON that stores the architecture, or dict that stores the exported architecture.
verbose : bool
Print log messages if set to True
Returns
-------
FixedArchitecture
Mutator that is responsible for fixing the graph.
"""
if isinstance(fixed_arc, str):
with open(fixed_arc) as f:
fixed_arc = json.load(f)
architecture = FixedArchitecture(model, fixed_arc, verbose=verbose)
architecture.reset()
# for the convenience of parameters counting
architecture.replace_layer_choice()
return architecture
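# Usage sketch (the paths and ``MySuperNet`` are hypothetical): load an exported
# architecture and fix the supernet to it before retraining from scratch.
#
#     model = MySuperNet()                                   # network with mutables
#     apply_fixed_architecture(model, 'checkpoints/epoch_9.json')
#     # `model` now behaves as the single selected sub-network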
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import warnings
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting
logger = logging.getLogger(__name__)
class Mutable(nn.Module):
"""
Mutable is designed to function as a normal layer, with all necessary operators' weights.
States and weights of architectures should be included in mutator, instead of the layer itself.
Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
decisions among different mutables. In mutator's implementation, mutators should use the key to
distinguish different mutables. Mutables that share the same key should be "similar" to each other.
Currently the default scope for keys is global. By default, keys are produced by a global
counter starting from 1 to guarantee uniqueness.
Parameters
----------
key : str
The key of mutable.
Notes
-----
The counter is program level, but mutables are model level. In case multiple models are defined and
you want the counter to start from 1 in the second model, it's recommended to assign keys manually
instead of using automatic keys.
"""
def __init__(self, key=None):
super().__init__()
if key is not None:
if not isinstance(key, str):
key = str(key)
logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
self._key = key
else:
self._key = self.__class__.__name__ + str(global_mutable_counting())
self.init_hook = self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if "mutator" in self.__dict__:
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
"Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator
@property
def key(self):
"""
Read-only property of key.
"""
return self._key
@property
def name(self):
"""
After the search space is parsed, it will be the module name of the mutable.
"""
return self._name if hasattr(self, "_name") else self._key
@name.setter
def name(self, name):
self._name = name
def _check_built(self):
if not hasattr(self, "mutator"):
raise ValueError(
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
class MutableScope(Mutable):
"""
Mutable scope marks a subgraph/submodule to help mutators make better decisions.
If not annotated with mutable scope, search space will be flattened as a list. However, some mutators might
need to leverage the concept of a "cell". So if a module is defined as a mutable scope, everything in it will
look like "sub-search-space" in the scope. Scopes can be nested.
There are two ways mutators can use mutable scope. One is to traverse the search space as a tree during initialization
and reset. The other is to implement `enter_mutable_scope` and `exit_mutable_scope`. They are called before and after
the forward method of the class inheriting mutable scope.
Mutable scopes are also mutables that are listed in the mutator.mutables (search space), but they are not supposed
to appear in the dict of choices.
Parameters
----------
key : str
Key of mutable scope.
"""
def __init__(self, key):
super().__init__(key=key)
def _check_built(self):
return True # bypass the test because it's deprecated
def __call__(self, *args, **kwargs):
if not hasattr(self, 'mutator'):
return super().__call__(*args, **kwargs)
warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
"""
Layer choice selects one of the ``op_candidates``, applies it to the inputs, and returns the result.
In rare cases, it can also select zero or many.
Layer choice does not allow itself to be nested.
Parameters
----------
op_candidates : list of nn.Module or OrderedDict
A module list to be selected from.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. Policy if multiple candidates are selected.
If ``none``, a list is returned. ``mean`` returns the average. ``sum`` returns the sum.
``concat`` concatenates the list at dimension 1.
return_mask : bool
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
Attributes
----------
length : int
Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
names : list of str
Names of candidates.
choices : list of Module
Deprecated. A list of all candidate modules in the layer choice module.
``list(layer_choice)`` is recommended, which will serve the same purpose.
Notes
-----
``op_candidates`` can be a list of modules or an ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
    ("conv3x3", nn.Conv2d(3, 16, 3)),
    ("conv5x5", nn.Conv2d(3, 16, 5)),
    ("conv7x7", nn.Conv2d(3, 16, 7))
]))
Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.names = []
if isinstance(op_candidates, OrderedDict):
for name, module in op_candidates.items():
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.add_module(name, module)
self.names.append(name)
elif isinstance(op_candidates, list):
for i, module in enumerate(op_candidates):
self.add_module(str(i), module)
self.names.append(str(i))
else:
raise TypeError("Unsupported op_candidates type: {}".format(type(op_candidates)))
self.reduction = reduction
self.return_mask = return_mask
def __getitem__(self, idx):
if isinstance(idx, str):
return self._modules[idx]
return list(self)[idx]
def __setitem__(self, idx, module):
key = idx if isinstance(idx, str) else self.names[idx]
return setattr(self, key, module)
def __delitem__(self, idx):
if isinstance(idx, slice):
for key in self.names[idx]:
delattr(self, key)
else:
if isinstance(idx, str):
key, idx = idx, self.names.index(idx)
else:
key = self.names[idx]
delattr(self, key)
del self.names[idx]
@property
def length(self):
warnings.warn("layer_choice.length is deprecated. Use `len(layer_choice)` instead.", DeprecationWarning)
return len(self)
def __len__(self):
return len(self.names)
def __iter__(self):
return map(lambda name: self._modules[name], self.names)
@property
def choices(self):
warnings.warn("layer_choice.choices is deprecated. Use `list(layer_choice)` instead.", DeprecationWarning)
return list(self)
def forward(self, *args, **kwargs):
"""
Returns
-------
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
out, mask = self.mutator.on_forward_layer_choice(self, *args, **kwargs)
if self.return_mask:
return out, mask
return out
class InputChoice(Mutable):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (which contains ``n_candidates`` keys). For beginners,
using ``n_candidates`` instead of ``choose_from`` is a safe option. To get the most power out of it, you might want to
know about ``choose_from``.
The keys in ``choose_from`` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a :class:`MutableScope` is needed.
In the example below, ``input_choice`` is a 4-choose-any. The first three are semantically the outputs of cell1,
cell2 and op respectively. Notice that an extra max pooling follows cell1, indicating that x1 is not
"actually" the direct output of cell1.
.. code-block:: python
class Cell(MutableScope):
pass
class Net(nn.Module):
def __init__(self):
    super().__init__()
self.cell1 = Cell("cell1")
self.cell2 = Cell("cell2")
self.op = LayerChoice([conv3x3(), conv5x5()], key="op")
self.input_choice = InputChoice(choose_from=["cell1", "cell2", "op", InputChoice.NO_KEY])
def forward(self, x):
x1 = max_pooling(self.cell1(x))
x2 = self.cell2(x)
x3 = self.op(x)
x4 = torch.zeros_like(x)
return self.input_choice([x1, x2, x3, x4])
Parameters
----------
n_candidates : int
Number of inputs to choose from.
choose_from : list of str
List of source keys to choose from. At least one of ``choose_from`` and ``n_candidates`` must be provided.
If ``n_candidates`` has a value but ``choose_from`` is None, it will be automatically treated as ``n_candidates``
empty strings.
n_chosen : int
Number of inputs to choose. If None, the mutator is instructed to select any number.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``. See :class:`LayerChoice`.
return_mask : bool
If ``return_mask``, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
"""
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from` " \
    "must be not None."
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
assert n_candidates > 0, "Number of candidates must be greater than 0."
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
"than number of candidates."
self.n_candidates = n_candidates
self.choose_from = choose_from.copy()
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def forward(self, optional_inputs):
"""
Forward method of InputChoice.
Parameters
----------
optional_inputs : list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
``choose_from`` in initialization. As a list, inputs must follow the semantic order that is the same as
``choose_from``.
Returns
-------
tuple of tensors
Output and selection mask. If ``return_mask`` is ``False``, only output is returned.
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
"Optional input list must be a list, not a {}.".format(type(optional_input_list))
assert len(optional_input_list) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import defaultdict
import numpy as np
import torch
from .base_mutator import BaseMutator
from .mutables import LayerChoice, InputChoice
from .utils import to_list
logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = dict()
self._connect_all = False
def sample_search(self):
"""
Override this method to iterate over mutables and make decisions.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def sample_final(self):
"""
Override this method to iterate over mutables and make decisions that are final
for export and retraining.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def reset(self):
"""
Reset the mutator by calling `sample_search` to resample (for search). Stores the result in a local
variable so that `on_forward_layer_choice` and `on_forward_input_choice` can use the decision directly.
"""
self._cache = self.sample_search()
def export(self):
"""
Resample (for final) and return results.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
sampled = self.sample_final()
result = dict()
for mutable in self.mutables:
if not isinstance(mutable, (LayerChoice, InputChoice)):
# not supported as built-in
continue
result[mutable.key] = self._convert_mutable_decision_to_human_readable(mutable, sampled.pop(mutable.key))
if sampled:
raise ValueError("Unexpected keys returned from 'sample_final()': %s", list(sampled.keys()))
return result
def status(self):
"""
Return current selection status of mutator.
Returns
-------
dict
A mapping from key of mutables to decisions. All weights (boolean type and float type)
are converted into real number values. Numpy arrays and tensors are converted into lists.
"""
data = dict()
for k, v in self._cache.items():
if torch.is_tensor(v):
v = v.detach().cpu().numpy().tolist()
if isinstance(v, np.ndarray):
v = v.astype(np.float32).tolist()
data[k] = v
return data
def graph(self, inputs):
"""
Return model supernet graph.
Parameters
----------
inputs: tuple of tensor
Inputs that will be fed into the network.
Returns
-------
dict
Containing ``node``, in Tensorboard GraphDef format.
Additional key ``mutable`` is a map from key to list of modules.
"""
if not torch.__version__.startswith("1.4"):
logger.warning("Graph is only tested with PyTorch 1.4. Other versions might not work.")
from nni.common.graph_utils import build_graph
from google.protobuf import json_format
# protobuf should be installed as long as tensorboard is installed
try:
self._connect_all = True
graph_def, _ = build_graph(self.model, inputs, verbose=False)
result = json_format.MessageToDict(graph_def)
finally:
self._connect_all = False
# `mutable` is to map the keys to a list of corresponding modules.
# A key can be linked to multiple modules, use `dedup=False` to find them all.
result["mutable"] = defaultdict(list)
for mutable in self.mutables.traverse(deduplicate=False):
# A module will be represented in the format of
# [{"type": "Net", "name": ""}, {"type": "Cell", "name": "cell1"}, {"type": "Conv2d", "name": "conv"}]
# which will be concatenated into Net/Cell[cell1]/Conv2d[conv] in frontend.
# This format is aligned with the scope name jit gives.
modules = mutable.name.split(".")
path = [
{"type": self.model.__class__.__name__, "name": ""}
]
m = self.model
for module in modules:
m = getattr(m, module)
path.append({
"type": m.__class__.__name__,
"name": module
})
result["mutable"][mutable.key].append(path)
return result
def on_forward_layer_choice(self, mutable, *args, **kwargs):
"""
By default, this method retrieves the decision obtained previously, and selects certain operations.
Only operations with non-zero weight will be executed. The results are collected into a list.
Then it reduces the list of all tensor outputs with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : nni.nas.pytorch.mutables.LayerChoice
Layer choice module.
args : list of torch.Tensor
Inputs
kwargs : dict
Inputs
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction,
[op(*args, **kwargs) for op in mutable]), \
torch.ones(len(mutable)).bool()
def _map_fn(op, args, kwargs):
return op(*args, **kwargs)
mask = self._get_decision(mutable)
assert len(mask) == len(mutable), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable))
out, mask = self._select_with_mask(_map_fn, [(choice, args, kwargs) for choice in mutable], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
"""
By default, this method retrieves the decision obtained previously, and selects certain tensors.
Then it reduces the list of all tensor outputs with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : nni.nas.pytorch.mutables.InputChoice
Input choice module.
tensor_list : list of torch.Tensor
Tensor list to apply the decision on.
Returns
-------
tuple of torch.Tensor and torch.Tensor
Output and mask.
"""
if self._connect_all:
return self._all_connect_tensor_reduction(mutable.reduction, tensor_list), \
torch.ones(mutable.n_candidates).bool()
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
out, mask = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def _select_with_mask(self, map_fn, candidates, mask):
"""
Select masked tensors and return a list of tensors.
Parameters
----------
map_fn : function
Convert candidates to target candidates. Can be simply identity.
candidates : list of torch.Tensor
Tensor list to apply the decision on.
mask : list-like object
Can be a list, a numpy array or a tensor (recommended). Needs to
have the same length as ``candidates``.
Returns
-------
tuple of list of torch.Tensor and torch.Tensor
Output and mask.
"""
if (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], bool)) or \
(isinstance(mask, np.ndarray) and mask.dtype == np.bool_) or \
"BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif (isinstance(mask, list) and len(mask) >= 1 and isinstance(mask[0], (float, int))) or \
(isinstance(mask, np.ndarray) and mask.dtype in (np.float32, np.float64, np.int32, np.int64)) or \
"FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask) if m]
else:
raise ValueError("Unrecognized mask '%s'" % mask)
if not torch.is_tensor(mask):
mask = torch.tensor(mask) # pylint: disable=not-callable
return out, mask
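    # For example, with candidates [t0, t1, t2]: a boolean mask
    # [True, False, True] selects [t0, t2] unweighted, while a float mask
    # [0.2, 0.0, 0.8] selects [0.2 * t0, 0.8 * t2] (zero-weight entries are
    # dropped entirely).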
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _all_connect_tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
return torch.stack(tensor_list).sum(0)
def _get_decision(self, mutable):
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
and returns the result without double-checking.
Parameters
----------
mutable : Mutable
Returns
-------
object
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result
def _convert_mutable_decision_to_human_readable(self, mutable, sampled):
# Assert the existence of mutable.key in returned architecture.
# Also check if there is anything extra.
multihot_list = to_list(sampled)
converted = None
# If it's a boolean (0/1) array, we can convert it to a more readable form.
if all([t == 0 or t == 1 for t in multihot_list]):
if isinstance(mutable, LayerChoice):
assert len(multihot_list) == len(mutable), \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all modules have different names and they indeed have names
if len(set(mutable.names)) == len(mutable) and not all(d.isdigit() for d in mutable.names):
converted = [name for i, name in enumerate(mutable.names) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if isinstance(mutable, InputChoice):
assert len(multihot_list) == mutable.n_candidates, \
"Results returned from 'sample_final()' (%s: %s) either too short or too long." \
% (mutable.key, multihot_list)
# check if all input candidates have different names
if len(set(mutable.choose_from)) == mutable.n_candidates:
converted = [name for i, name in enumerate(mutable.choose_from) if multihot_list[i]]
else:
converted = [i for i in range(len(multihot_list)) if multihot_list[i]]
if converted is not None:
# if only one element, then remove the bracket
if len(converted) == 1:
converted = converted[0]
else:
# do nothing
converted = multihot_list
return converted
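# A minimal concrete mutator sketch (illustrative, not shipped in this file):
# uniformly sample a one-hot decision for every LayerChoice and a boolean mask
# for every InputChoice.
class _RandomMutator(Mutator):
    def sample_search(self):
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                index = torch.randint(high=len(mutable), size=(1,))
                result[mutable.key] = torch.nn.functional.one_hot(index, num_classes=len(mutable)).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                if mutable.n_chosen is None:
                    result[mutable.key] = torch.randint(high=2, size=(mutable.n_candidates,)).bool()
                else:
                    perm = torch.randperm(mutable.n_candidates)[:mutable.n_chosen]
                    result[mutable.key] = torch.tensor(  # pylint: disable=not-callable
                        [i in perm for i in range(mutable.n_candidates)], dtype=torch.bool)
        return result
    def sample_final(self):
        return self.sample_search()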
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from .nasbench201_ops import Pooling, ReLUConvBN, Zero, FactorizedReduce
class NASBench201Cell(nn.Module):
"""
Builtin cell structure of NAS Bench 201. One cell contains four nodes. The first node serves as an input node,
accepting the output of the previous cell. Each of the other nodes connects to all previous nodes with an edge that
represents an operation chosen from a set to transform the tensor from the source node to the target node.
Every node sums all its inputs as its output.
Parameters
---
cell_id: str
the name of this cell
C_in: int
the number of input channels of the cell
C_out: int
the number of output channels of the cell
stride: int
stride of all convolution operations in the cell
bn_affine: bool
If set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, all ``torch.nn.BatchNorm2d`` in this cell tracks the running mean and variance. Default: True
"""
def __init__(self, cell_id, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(NASBench201Cell, self).__init__()
self.NUM_NODES = 4
self.layers = nn.ModuleList()
OPS = lambda layer_idx: OrderedDict([
("none", Zero(C_in, C_out, stride)),
("avg_pool_3x3", Pooling(C_in, C_out, stride if layer_idx == 0 else 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("conv_3x3", ReLUConvBN(C_in, C_out, 3, stride if layer_idx == 0 else 1, 1, 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("conv_1x1", ReLUConvBN(C_in, C_out, 1, stride if layer_idx == 0 else 1, 0, 1, bn_affine, bn_momentum,
bn_track_running_stats)),
("skip_connect", nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride if layer_idx == 0 else 1, bn_affine, bn_momentum,
bn_track_running_stats))
])
for i in range(self.NUM_NODES):
node_ops = nn.ModuleList()
for j in range(0, i):
node_ops.append(LayerChoice(OPS(j), key="%d_%d" % (j, i), reduction="mean"))
self.layers.append(node_ops)
self.in_dim = C_in
self.out_dim = C_out
self.cell_id = cell_id
def forward(self, input): # pylint: disable=W0622
"""
Parameters
---
input: torch.Tensor
the output of the previous layer
"""
nodes = [input]
for i in range(1, self.NUM_NODES):
node_feature = sum(self.layers[i][k](nodes[k]) for k in range(i))
nodes.append(node_feature)
return nodes[-1]
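# Usage sketch: the cell is an ordinary ``nn.Module`` whose edges are
# ``LayerChoice``s, so after a mutator is applied (e.g., a fixed architecture)
# it runs like any other module. Shapes below are illustrative:
#
#     cell = NASBench201Cell(cell_id=0, C_in=16, C_out=16, stride=1)
#     out = cell(torch.randn(2, 16, 32, 32))   # stride 1 keeps the spatial size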
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class ReLUConvBN(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
dilation: int
spacing between kernel elements
bn_affine: bool
If set to ``True``, ``torch.nn.BatchNorm2d`` will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, ``torch.nn.BatchNorm2d`` tracks the running mean and variance. Default: True
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation,
bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out, affine=bn_affine, momentum=bn_momentum,
track_running_stats=bn_track_running_stats)
)
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
return self.op(x)
class Pooling(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride of the pooling operation
bn_affine: bool
If set to ``True``, ``torch.nn.BatchNorm2d`` will have learnable affine parameters. Default: True
bn_momentum: float
the value used for the running_mean and running_var computation. Default: 0.1
bn_track_running_stats: bool
When set to ``True``, ``torch.nn.BatchNorm2d`` tracks the running mean and variance. Default: True
"""
def __init__(self, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1, bn_track_running_stats=True):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 0,
bn_affine, bn_momentum, bn_track_running_stats)
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
if self.preprocess:
x = self.preprocess(x)
return self.op(x)
class Zero(nn.Module):
"""
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
stride: int
stride used to downsample the input
"""
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
"""
Parameters
---
x: torch.Tensor
input tensor
"""
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride, bn_affine=True, bn_momentum=0.1,
bn_track_running_stats=True):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError("Invalid stride : {:}".format(stride))
self.bn = nn.BatchNorm2d(C_out, affine=bn_affine, momentum=bn_momentum,
track_running_stats=bn_track_running_stats)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .darts_cell import DartsCell
from .enas_cell import ENASMicroLayer
from .enas_cell import ENASMacroLayer
from .enas_cell import ENASMacroGeneralModel
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch
import torch.nn as nn
from nni.nas.pytorch import mutables
from .darts_ops import PoolBN, SepConv, DilConv, FactorizedReduce, DropPath, StdConv
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
"""
builtin Darts Node structure
Parameters
---
node_id: str
num_prev_nodes: int
the number of previous nodes in this cell
channels: int
output channels
num_downsample_connect: int
    number of predecessor inputs to downsample when this cell is a reduction cell
"""
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(OrderedDict([
("maxpool", PoolBN('max', channels, 3, stride, 1, affine=False)),
("avgpool", PoolBN('avg', channels, 3, stride, 1, affine=False)),
("skipconnect",
nn.Identity() if stride == 1 else FactorizedReduce(channels, channels, affine=False)),
("sepconv3x3", SepConv(channels, channels, 3, stride, 1, affine=False)),
("sepconv5x5", SepConv(channels, channels, 5, stride, 2, affine=False)),
("dilconv3x3", DilConv(channels, channels, 3, stride, 2, 2, affine=False)),
("dilconv5x5", DilConv(channels, channels, 5, stride, 4, 2, affine=False))
]), key=choice_keys[-1]))
self.drop_path = DropPath()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
class DartsCell(nn.Module):
"""
Builtin Darts Cell structure. There are ``n_nodes`` nodes in one cell, in which the first two nodes' values are
fixed to the results of previous previous cell and previous cell respectively. One node will connect all
the nodes after with predefined operations in a mutable way. The last node accepts five inputs from nodes
before and it concats all inputs in channels as the output of the current cell, and the number of output
channels is ``n_nodes`` times ``channels``.
Parameters
---
n_nodes: int
the number of nodes contained in this cell
channels_pp: int
the number of previous previous cell's output channels
channels_p: int
the number of previous cell's output channels
channels: int
the number of output channels for each node
reduction_p: bool
Is previous cell a reduction cell
reduction: bool
is current cell a reduction cell
"""
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, pprev, prev):
"""
Parameters
        ----------
pprev: torch.Tensor
the output of the previous previous layer
prev: torch.Tensor
the output of the previous layer
"""
tensors = [self.preproc0(pprev), self.preproc1(prev)]
for node in self.mutable_ops:
cur_tensor = node(tensors)
tensors.append(cur_tensor)
output = torch.cat(tensors[2:], dim=1)
return output
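# Illustrative sketch (not part of the original file; runs once a mutator is
# bound to the model). The cell output has ``n_nodes * channels`` channels:
#
#   cell = DartsCell(n_nodes=4, channels_pp=48, channels_p=48, channels=16,
#                    reduction_p=False, reduction=False)
#   out = cell(torch.randn(1, 48, 32, 32), torch.randn(1, 48, 32, 32))
#   # out.shape == (1, 64, 32, 32), i.e. 4 nodes x 16 channels each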
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.
Parameters
----------
p : float
            Probability of a path being zeroed.
"""
super().__init__()
self.p = p
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask
return x
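# Illustrative note (not part of the original file): dividing by ``keep_prob``
# keeps the expected activation unchanged. With p = 0.2, a surviving sample is
# scaled by 1 / 0.8 = 1.25 while a dropped sample becomes all zeros:
#
#   drop = DropPath(p=0.2)
#   drop.train()                      # paths are only dropped in training mode
#   y = drop(torch.ones(8, 3, 4, 4))  # each sample is either all 0 or all 1.25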
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. ``pool_type`` must be ``max`` or ``avg``.
Parameters
---
pool_type: str
choose operation
C: int
number of channels
kernal_size: int
size of the convolving kernel
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
affine: bool
is using affine in BatchNorm
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C, affine=affine)
def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out
class StdConv(nn.Sequential):
"""
Standard conv: ReLU - Conv - BN
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
kernel_size: int
size of the convolution kernel
padding:
zero-padding added to both sides of the input
affine: bool
is using affine in BatchNorm
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
        # StdConv subclasses nn.Sequential, so submodules are registered directly
        # on self; the stray ``self.net = nn.Sequential`` assignment is removed.
        for idx, op in enumerate((nn.ReLU(), nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
                                  nn.BatchNorm2d(C_out, affine=affine))):
            self.add_module(str(idx), op)
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the convolving kernel
    stride: int
        stride of the depthwise convolution
    padding: int
        zero-padding added to both sides of the input
    dilation: int
        spacing between kernel elements
    affine: bool
        whether to use affine parameters in BatchNorm
    """
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
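# Illustrative note (not part of the original file): with dilation d, a k x k
# kernel spans d * (k - 1) + 1 input pixels per side, so dilation 2 turns a 3x3
# conv into an effective 5x5 (2 * 2 + 1) and a 5x5 into an effective 9x9
# (2 * 4 + 1), matching the receptive fields quoted in the docstring above.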
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
    Parameters
    ----------
    C_in: int
        the number of input channels
    C_out: int
        the number of output channels
    kernel_size: int
        size of the convolving kernel
    stride: int
        stride of the first of the two stacked DilConv blocks
    padding: int
        zero-padding added to both sides of the input
    affine: bool
        whether to use affine parameters in BatchNorm
    """
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
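# Illustrative note (not part of the original file): this variant assumes even
# spatial input (and even C_out). ``conv1`` sees H x W while ``conv2`` sees
# (H-1) x (W-1); their stride-2 outputs only match spatially (as torch.cat
# requires) when H and W are even, e.g.
# FactorizedReduce(32, 64)(torch.randn(1, 32, 8, 8)) -> shape (1, 64, 4, 4).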
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch import mutables
from .enas_ops import FactorizedReduce, StdConv, SepConvBN, Pool, ConvBranch, PoolBranch
class Cell(nn.Module):
def __init__(self, cell_name, prev_labels, channels):
super().__init__()
self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = mutables.LayerChoice([
SepConvBN(channels, channels, 3, 1),
SepConvBN(channels, channels, 5, 2),
Pool("avg", 3, 1, 1),
Pool("max", 3, 1, 1),
nn.Identity()
], key=cell_name + "_op")
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(mutables.MutableScope):
def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y
class Calibration(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.process = None
if in_channels != out_channels:
self.process = StdConv(in_channels, out_channels)
def forward(self, x):
if self.process is None:
return x
return self.process(x)
class ENASMicroLayer(nn.Module):
"""
Builtin EnasMicroLayer. Micro search designs only one building block whose architecture is repeated
throughout the final architecture. A cell has ``num_nodes`` nodes and searches the topology and
operations among them in RL way. The first two nodes in a layer stand for the outputs from previous
previous layer and previous layer respectively. For the following nodes, the controller chooses
two previous nodes and applies two operations respectively for each node. Nodes that are not served
as input for any other node are viewed as the output of the layer. If there are multiple output nodes,
the model will calculate the average of these nodes as the layer output. Every node's output has ``out_channels``
channels so the result of the layer has the same number of channels as each node.
Parameters
---
num_nodes: int
the number of nodes contained in this layer
in_channles_pp: int
the number of previous previous layer's output channels
in_channels_p: int
the number of previous layer's output channels
out_channels: int
output channels of this layer
reduction: bool
is reduction operation empolyed before this layer
"""
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
super().__init__()
self.reduction = reduction
if self.reduction:
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)
in_channels_pp = in_channels_p = out_channels
self.preproc0 = Calibration(in_channels_pp, out_channels)
self.preproc1 = Calibration(in_channels_p, out_channels)
self.num_nodes = num_nodes
name_prefix = "reduce" if reduction else "normal"
self.nodes = nn.ModuleList()
node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
for i in range(num_nodes):
node_labels.append("{}_node_{}".format(name_prefix, i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1),
requires_grad=True)
self.bn = nn.BatchNorm2d(out_channels, affine=False)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_normal_(self.final_conv_w)
def forward(self, pprev, prev):
"""
Parameters
        ----------
pprev: torch.Tensor
the output of the previous previous layer
prev: torch.Tensor
the output of the previous layer
"""
if self.reduction:
pprev, prev = self.reduce0(pprev), self.reduce1(prev)
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
prev_nodes_out = [pprev_, prev_]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask.to(node_out.device)
prev_nodes_out.append(node_out)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
out = F.conv2d(unused_nodes, conv_weight)
return prev, self.bn(out)
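# Illustrative sketch (not part of the original file; requires a mutator to be
# bound before forward). A layer maps two incoming feature maps to a
# ``(prev, out)`` pair so that layers can be chained:
#
#   layer = ENASMicroLayer(num_nodes=5, in_channels_pp=32, in_channels_p=32,
#                          out_channels=32, reduction=False)
#   prev, out = layer(pprev, prev)   # pprev/prev: placeholder 32-channel maps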
class ENASMacroLayer(mutables.MutableScope):
"""
Builtin ENAS Marco Layer. With search space changing to layer level, the controller decides
what operation is employed and the previous layer to connect to for skip connections. The model
is made up of the same layers but the choice of each layer may be different.
Parameters
---
key: str
the name of this layer
prev_labels: str
names of all previous layers
in_filters: int
the number of input channels
out_filters:
the number of output channels
"""
def __init__(self, key, prev_labels, in_filters, out_filters):
super().__init__(key)
self.in_filters = in_filters
self.out_filters = out_filters
self.mutable = mutables.LayerChoice([
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if prev_labels:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None)
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
def forward(self, prev_list):
"""
Parameters
        ----------
prev_list: list
The cell selects the last element of the list as input and applies an operation on it.
The cell chooses none/one/multiple tensor(s) as SkipConnect(s) from the list excluding
the last element.
"""
out = self.mutable(prev_list[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_list[:-1])
if connection is not None:
out += connection
return self.batch_norm(out)
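# Illustrative sketch (not part of the original file; requires a bound mutator).
# The layer applies the chosen branch to the most recent feature map and adds
# any chosen skip connections before batch norm:
#
#   layer = ENASMacroLayer("layer_3", ["layer_0", "layer_1", "layer_2"],
#                          in_filters=24, out_filters=24)
#   out = layer([t0, t1, t2, t3])   # t0..t3 are placeholder 24-channel maps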
class ENASMacroGeneralModel(nn.Module):
"""
The network is made up by stacking ENASMacroLayer. The Macro search space contains these layers.
Each layer chooses an operation from predefined ones and SkipConnect then forms a network.
Parameters
---
num_layers: int
The number of layers contained in the network.
out_filters: int
The number of each layer's output channels.
in_channel: int
The number of input's channels.
num_classes: int
The number of classes for classification.
dropout_rate: float
Dropout layer's dropout rate before the final dense layer.
"""
def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
dropout_rate=0.0):
super().__init__()
self.num_layers = num_layers
self.num_classes = num_classes
self.out_filters = out_filters
self.stem = nn.Sequential(
nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_filters)
)
pool_distance = self.num_layers // 3
self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
self.dropout_rate = dropout_rate
self.dropout = nn.Dropout(self.dropout_rate)
self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList()
labels = []
for layer_id in range(self.num_layers):
labels.append("layer_{}".format(layer_id))
if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASMacroLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
def forward(self, x):
"""
Parameters
        ----------
x: torch.Tensor
the input of the network
"""
bs = x.size(0)
cur = self.stem(x)
layers = [cur]
for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers)
layers.append(cur)
if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
cur = layers[-1]
cur = self.gap(cur).view(bs, -1)
cur = self.dropout(cur)
logits = self.dense(cur)
return logits
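# Illustrative sketch (not part of the original file; requires a bound mutator).
# End-to-end usage on CIFAR-10-shaped input:
#
#   model = ENASMacroGeneralModel(num_layers=12, out_filters=24, num_classes=10)
#   logits = model(torch.randn(4, 3, 32, 32))   # logits.shape == (4, 10)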
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
"""
Pooling structure for Macro search. First pass through a 1x1 Conv, then pooling operation followed by BatchNorm2d.
Parameters
---
pool_type: str
only accept ``max`` for MaxPool and ``avg`` for AvgPool
C_in: int
the number of input channels
C_out: int
the number of output channels
kernal_size: int
size of the convolving kernel
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
"""
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
self.pool = Pool(pool_type, kernel_size, stride, padding)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
"""
Conv structure for Macro search. First pass through a 1x1 Conv,
then Conv operation with kernal_size equals 3 or 5 followed by BatchNorm and ReLU.
Parameters
---
C_in: int
the number of input channels
C_out: int
the number of output channels
kernal_size: int
size of the convolving kernel
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
separable: True
is separable Conv is used
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class Pool(nn.Module):
"""
Pooling structure
Parameters
---
pool_type: str
only accept ``max`` for MaxPool and ``avg`` for AvgPool
kernal_size: int
size of the convolving kernel
stride: int
stride of the convolution
padding: int
zero-padding added to both sides of the input
"""
def __init__(self, pool_type, kernel_size, stride, padding):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
def forward(self, x):
return self.pool(x)
class SepConvBN(nn.Module):
"""
Implement SepConv followed by BatchNorm. The structure is ReLU ==> SepConv ==> BN.
Parameters
---
C_in: int
the number of imput channels
C_out: int
the number of output channels
kernal_size: int
size of the convolving kernel
padding: int
zero-padding added to both sides of the input
"""
def __init__(self, C_in, C_out, kernel_size, padding):
super().__init__()
self.relu = nn.ReLU()
self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding)
self.bn = nn.BatchNorm2d(C_out, affine=True)
def forward(self, x):
x = self.relu(x)
x = self.conv(x)
x = self.bn(x)
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import logging
import os
import time
from abc import abstractmethod
import torch
from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__)
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
olist = o.tolist()
if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
                _logger.warning("Every element in %s is either 0 or 1. "
                                "You might consider converting it to bool.", olist)
return olist
return super().default(o)
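# Illustrative sketch (not part of the original file): tensors are serialized as
# nested lists; a warning is logged when every element is 0/1 but the dtype is not bool.
def _encoder_demo():
    payload = {"mask": torch.tensor([1, 0, 1]), "weight": torch.tensor([0.5, 1.5])}
    return json.dumps(payload, cls=TorchTensorEncoder)  # '{"mask": [1, 0, 1], "weight": [0.5, 1.5]}'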
class Trainer(BaseTrainer):
"""
A trainer with some helper functions implemented. To implement a new trainer,
users need to implement :meth:`train_one_epoch`, :meth:`validate_one_epoch` and :meth:`checkpoint`.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : BaseMutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
See `PyTorch loss functions`_ for examples.
metrics : callable
Called with logits and targets. Returns a dict that maps metrics keys to metrics data. For example,
.. code-block:: python
def metrics_fn(output, target):
return {"acc1": accuracy(output, target, topk=1), "acc5": accuracy(output, target, topk=5)}
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : torch.utils.data.Dataset
Dataset of training. If not otherwise specified, ``dataset_train`` and ``dataset_valid`` should be standard
PyTorch Dataset. See `torch.utils.data`_ for examples.
dataset_valid : torch.utils.data.Dataset
Dataset of validation/testing.
batch_size : int
Batch size.
workers : int
Number of workers used in data preprocessing.
device : torch.device
        Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
        automatically detects whether a GPU is available and prefers GPU over CPU.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
.. _`PyTorch loss functions`: https://pytorch.org/docs/stable/nn.html#loss-functions
.. _`torch.utils.data`: https://pytorch.org/docs/stable/data.html
"""
def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model
self.mutator = mutator
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.model.to(self.device)
self.mutator.to(self.device)
self.loss.to(self.device)
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.batch_size = batch_size
self.workers = workers
self.log_frequency = log_frequency
self.log_dir = os.path.join("logs", str(time.time()))
os.makedirs(self.log_dir, exist_ok=True)
self.status_writer = open(os.path.join(self.log_dir, "log"), "w")
self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks:
callback.build(self.model, self.mutator, self)
@abstractmethod
def train_one_epoch(self, epoch):
"""
Train one epoch.
Parameters
----------
epoch : int
Epoch number starting from 0.
"""
pass
@abstractmethod
def validate_one_epoch(self, epoch):
"""
Validate one epoch.
Parameters
----------
epoch : int
Epoch number starting from 0.
"""
pass
def train(self, validate=True):
"""
Train ``num_epochs``.
Trigger callbacks at the start and the end of each epoch.
Parameters
----------
validate : bool
            If ``True``, validation is performed after every epoch.
"""
for epoch in range(self.num_epochs):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
# training
_logger.info("Epoch %d Training", epoch + 1)
self.train_one_epoch(epoch)
if validate:
# validation
_logger.info("Epoch %d Validating", epoch + 1)
self.validate_one_epoch(epoch)
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def validate(self):
"""
Do one validation.
"""
self.validate_one_epoch(-1)
def export(self, file):
"""
Call ``mutator.export()`` and dump the architecture to ``file``.
Parameters
----------
file : str
A file path. Expected to be a JSON.
"""
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
"""
Return trainer checkpoint.
"""
raise NotImplementedError("Not implemented yet")
def enable_visualization(self):
"""
Enable visualization. Write graph and training log to folder ``logs/<timestamp>``.
"""
sample = None
for x, _ in self.train_loader:
sample = x.to(self.device)[:2]
break
if sample is None:
_logger.warning("Sample is %s.", sample)
_logger.info("Creating graph json, writing to %s. Visualization enabled.", self.log_dir)
with open(os.path.join(self.log_dir, "graph.json"), "w") as f:
json.dump(self.mutator.graph(sample), f)
self.visualization_enabled = True
def _write_graph_status(self):
if hasattr(self, "visualization_enabled") and self.visualization_enabled:
print(json.dumps(self.mutator.status()), file=self.status_writer, flush=True)
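# Illustrative sketch (not part of the original file): a concrete trainer only
# needs the two abstract epoch hooks; the base class drives the epoch and
# callback loop. Subclasses are also expected to build data loaders, e.g. the
# ``train_loader`` used by ``enable_visualization``:
#
#   class MyTrainer(Trainer):
#       def train_one_epoch(self, epoch):
#           ...  # sample architectures via self.mutator, step self.optimizer
#       def validate_one_epoch(self, epoch):
#           ...  # evaluate self.model on self.dataset_valid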
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
import numpy as np
import torch
_counter = 0
_logger = logging.getLogger(__name__)
def global_mutable_counting():
"""
A program level counter starting from 1.
"""
global _counter
_counter += 1
return _counter
def _reset_global_mutable_counting():
"""
Reset the global mutable counting to count from 1. Useful when defining multiple models with default keys.
"""
global _counter
_counter = 0
def to_device(obj, device):
"""
Move a tensor, tuple, list, or dict onto device.
"""
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, tuple):
return tuple(to_device(t, device) for t in obj)
if isinstance(obj, list):
return [to_device(t, device) for t in obj]
if isinstance(obj, dict):
return {k: to_device(v, device) for k, v in obj.items()}
if isinstance(obj, (int, float, str)):
return obj
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
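# Illustrative sketch (not part of the original file): nested containers are
# handled recursively, so a mixed batch moves in one call.
def _to_device_demo():
    batch = {"x": torch.zeros(2, 3), "meta": [torch.ones(2), "label"]}
    return to_device(batch, torch.device("cpu"))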
def to_list(arr):
if torch.is_tensor(arr):
return arr.cpu().numpy().tolist()
if isinstance(arr, np.ndarray):
return arr.tolist()
if isinstance(arr, (list, tuple)):
return list(arr)
return arr
class AverageMeterGroup:
"""
Average meter group for multiple average meters.
"""
def __init__(self):
self.meters = OrderedDict()
def update(self, data):
"""
Update the meter group with a dict of metrics.
        Average meters that do not exist yet will be created automatically.
"""
for k, v in data.items():
if k not in self.meters:
self.meters[k] = AverageMeter(k, ":4f")
self.meters[k].update(v)
def __getattr__(self, item):
return self.meters[item]
def __getitem__(self, item):
return self.meters[item]
def __str__(self):
return " ".join(str(v) for v in self.meters.values())
def summary(self):
"""
Return a summary string of group data.
"""
return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
"""
Computes and stores the average and current value.
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
"""
Reset the meter.
"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""
Update with value and weight.
Parameters
----------
val : float or int
            The new value to be accounted for.
n : int
The weight of the new value.
"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = '{name}: {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
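# Illustrative sketch (not part of the original file): typical per-batch usage,
# where the group creates one meter per metric key on first update.
def _meter_demo():
    meters = AverageMeterGroup()
    meters.update({"loss": 0.9, "acc": 0.50})
    meters.update({"loss": 0.7, "acc": 0.60})
    return meters.summary()   # "loss: 0.800000 acc: 0.550000"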
class StructuredMutableTreeNode:
"""
A structured representation of a search space.
A search space comes with a root (with `None` stored in its `mutable`), and a bunch of children in its `children`.
    This tree can be seen as a "flattened" version of the module tree. Since nested mutable entities are not supported yet,
the following must be true: each subtree corresponds to a ``MutableScope`` and each leaf corresponds to a
``Mutable`` (other than ``MutableScope``).
Parameters
----------
mutable : nni.nas.pytorch.mutables.Mutable
The mutable that current node is linked with.
"""
def __init__(self, mutable):
self.mutable = mutable
self.children = []
def add_child(self, mutable):
"""
Add a tree node to the children list of current node.
"""
self.children.append(StructuredMutableTreeNode(mutable))
return self.children[-1]
def type(self):
"""
Return the ``type`` of mutable content.
"""
return type(self.mutable)
def __iter__(self):
return self.traverse()
def traverse(self, order="pre", deduplicate=True, memo=None):
"""
Return a generator that generates a list of mutables in this tree.
Parameters
----------
        order : str
            ``pre`` or ``post``. If ``pre``, the current mutable is yielded before its children; otherwise after.
        deduplicate : bool
            If true, mutables with the same key will not appear after their first appearance.
        memo : set
            An auxiliary set that memorizes keys seen before, so that deduplication is possible.
Returns
-------
generator of Mutable
"""
if memo is None:
memo = set()
assert order in ["pre", "post"]
if order == "pre":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
for child in self.children:
for m in child.traverse(order=order, deduplicate=deduplicate, memo=memo):
yield m
if order == "post":
if self.mutable is not None:
if not deduplicate or self.mutable.key not in memo:
memo.add(self.mutable.key)
yield self.mutable
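# Illustrative sketch (not part of the original file; ``scope`` and ``leaf`` are
# placeholder mutables): pre-order traversal yields each mutable once, keyed by
# ``mutable.key``.
def _tree_demo(scope, leaf):
    root = StructuredMutableTreeNode(None)   # the tree root holds None
    scope_node = root.add_child(scope)
    scope_node.add_child(leaf)
    return [m.key for m in root.traverse(order="pre", deduplicate=True)]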
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .base import BaseStrategy
from .bruteforce import Random, GridSearch
from .evolution import RegularizedEvolution
from .hpo import TPEStrategy, TPE
from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot
@@ -21,9 +21,9 @@ from tianshou.env.worker import EnvWorker
 from typing_extensions import TypedDict
+from nni.nas.execution import submit_models, wait_models
+from nni.nas.execution.common import ModelStatus
 from .utils import get_targeted_model
-from ..graph import ModelStatus
-from ..execution import submit_models, wait_models
 _logger = logging.getLogger(__name__)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import abc
from typing import List, Any
from nni.nas.execution.common import Model
from nni.nas.mutable import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
def export_top_models(self, top_k: int) -> List[Any]:
raise NotImplementedError('"export_top_models" is not implemented.')
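# Illustrative sketch (not part of the original file): a minimal strategy only
# has to implement ``run``; ``export_top_models`` raises unless overridden.
#
#   class MyStrategy(BaseStrategy):
#       def run(self, base_model, applied_mutators):
#           ...  # apply mutators to base_model and submit the sampled models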