Unverified Commit a0fd0036 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Optional, Tuple, Union
import torch.nn as nn
from nni.nas.utils import NoContextError, ModelNamespace, get_current_context
class Mutable(nn.Module):
    """
    Base class (currently an implementation trick) for PyTorch mutable modules.

    In future, this could become the base class of all PyTorch mutables,
    including layer choice, input choice, etc. It is not considered an
    interface, but rather a base class consisting of commonly used
    class/instance methods. Until the design is finalized, API developers
    should not rely on ``isinstance(module, Mutable)`` to detect mutables.
    """

    def __new__(cls, *args, **kwargs):
        # An entirely empty call signature most likely means copy/deepcopy,
        # where attributes are restored afterwards through __dict__;
        # skip the fixed-module shortcut in that case.
        if args or kwargs:
            try:
                # In a fixed (trial) context, build the concrete module directly.
                return cls.create_fixed_module(*args, **kwargs)
            except NoContextError:
                pass
        return super().__new__(cls)

    @classmethod
    def create_fixed_module(cls, *args, **kwargs) -> Union[nn.Module, Any]:
        """
        Try to create a fixed module from the fixed dict.

        If the code is running in a trial, this method succeeds and a concrete
        module (instead of a mutable) is created. Raises ``NoContextError``
        when the creation failed (i.e., not in a fixed context).
        """
        raise NotImplementedError
def generate_new_label(label: Optional[str]):
    """Return ``label`` unchanged, or draw the next auto-generated label when it is None."""
    if label is not None:
        return label
    return ModelNamespace.next_label()
def get_fixed_value(label: Optional[str]) -> Any:
    """Look up the frozen value for ``label`` in the current fixed context.

    Parameters
    ----------
    label
        Label of the choice. When None, the next auto-generated label is used.

    Raises
    ------
    KeyError
        If no value is recorded under the (possibly generated) label.
    """
    ret = get_current_context('fixed')
    # Resolve the label first so the error message below reports the label
    # actually used for lookup (previously it could print ``None``).
    label = generate_new_label(label)
    try:
        return ret[label]
    except KeyError:
        # ``from None`` suppresses the confusing chained-exception traceback.
        raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}') from None
def get_fixed_dict(label_prefix: Optional[str]) -> Tuple[str, Any]:
    """Collect all values in the fixed context whose keys start with ``label_prefix + '/'``.

    Parameters
    ----------
    label_prefix
        Prefix of the labels to collect. When None, the next auto-generated label is used.

    Returns
    -------
    tuple
        The (possibly generated) label prefix, and the dict of matched entries.

    Raises
    ------
    KeyError
        If no entry in the fixed context matches the prefix.
    """
    full_context = get_current_context('fixed')
    label_prefix = generate_new_label(label_prefix)
    matched = {k: v for k, v in full_context.items() if k.startswith(label_prefix + '/')}
    if not matched:
        # Report the full context: the previous implementation rebound ``ret``
        # to the (empty) filtered dict before raising, so the message always
        # printed ``{}`` instead of the existing values.
        raise KeyError(f'Fixed context with prefix {label_prefix} not found. Existing values are: {full_context}')
    return label_prefix, matched
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
from typing import Any, List, Optional, Tuple, Dict, Iterator, Iterable, cast
import torch.nn as nn
from nni.common.serializer import is_traceable, is_wrapped_with_trace
from nni.nas.execution.common.graph import Graph, Model, ModelStatus, Node, Evaluator
from nni.nas.execution.common.graph_op import Cell
from nni.nas.hub.pytorch.modules import NasBench101Cell, NasBench101Mutator
from nni.nas.mutable import Mutator
from nni.nas.utils import is_basic_unit, is_model_wrapped, ModelNamespace, uid
from .choice import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
class LayerChoiceMutator(Mutator):
    """Graph-space mutator that freezes a group of ``LayerChoice`` cells sharing one label.

    Each node in ``nodes`` points to a ``_cell`` graph that holds all candidate
    layers as unconnected nodes; mutation wires up the chosen candidate and
    prunes the others.
    """

    def __init__(self, nodes: List[Node]):
        # All grouped nodes share the same label; use the first as representative.
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        candidates = self.nodes[0].operation.parameters['candidates']
        chosen = self.choice(candidates)
        for node in self.nodes:
            # Each layer choice corresponds to a cell, which is unconnected in the base graph.
            # We add the connections here in the mutation logic.
            # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
            target = model.graphs[cast(Cell, node.operation).cell_name]
            chosen_node = target.get_node_by_name(chosen)
            assert chosen_node is not None
            # Connect input -> chosen candidate -> output inside the cell graph.
            target.add_edge((target.input_node, 0), (chosen_node, None))
            target.add_edge((chosen_node, None), (target.output_node, None))
            operation = cast(Cell, node.operation)
            target_node = cast(Node, model.get_node_by_name(node.name))
            # Replace with a plain cell operation so choice-specific parameters
            # (candidates/label) do not linger on the frozen model.
            target_node.update_operation(Cell(operation.cell_name))
            # remove redundant nodes
            for rm_node in list(target.hidden_nodes):  # remove from a list on the fly will cause issues
                if rm_node.name != chosen_node.name:
                    rm_node.remove()
class InputChoiceMutator(Mutator):
    """Graph-space mutator that freezes a group of ``InputChoice`` nodes sharing one label."""

    def __init__(self, nodes: List[Node]):
        # All grouped nodes share the same label; use the first as representative.
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def mutate(self, model):
        n_candidates = self.nodes[0].operation.parameters['n_candidates']
        n_chosen = self.nodes[0].operation.parameters['n_chosen']
        candidates = list(range(n_candidates))
        if n_chosen is None:
            # "Choose any": one independent binary decision per candidate index.
            chosen = [i for i in candidates if self.choice([False, True])]
            # FIXME This is a hack to make choice align with the previous format
            self._cur_samples = chosen
        else:
            chosen = [self.choice(candidates) for _ in range(n_chosen)]
        for node in self.nodes:
            target = cast(Node, model.get_node_by_name(node.name))
            # Replace the input choice with a fixed module that selects and
            # reduces the chosen inputs.
            target.update_operation('__torch__.nni.nas.nn.pytorch.ChosenInputs',
                                    {'chosen': chosen, 'reduction': node.operation.parameters['reduction']})
class ValueChoiceMutator(Mutator):
    """Graph-space mutator that freezes a group of ``ValueChoice`` nodes sharing one label."""

    def __init__(self, nodes: List[Node], candidates: List[Any]):
        # use nodes[0] as an example to get label
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes
        self.candidates = candidates

    def mutate(self, model):
        picked = self.choice(self.candidates)
        # No transformation support is needed here, because transformations
        # happen naturally in the forward loop.
        for value_node in self.nodes:
            graph_node = cast(Node, model.get_node_by_name(value_node.name))
            graph_node.update_operation('prim::Constant', {'type': type(picked).__name__, 'value': picked})
class ParameterChoiceLeafMutator(Mutator):
    """Samples one leaf ``ValueChoice`` that appears inside a parameter expression.

    It only records a decision in the mutation history; the actual rewrite of
    operation parameters is carried out later by ``ParameterChoiceMutator``.
    """

    def __init__(self, candidates: List[Any], label: str):
        super().__init__(label=label)
        self.candidates = candidates

    def mutate(self, model: Model) -> None:
        # Record the sample only; no graph modification happens here.
        self.choice(self.candidates)
class ParameterChoiceMutator(Mutator):
    # To deal with ValueChoice used as a parameter of a basic unit
    # should be used together with ParameterChoiceLeafMutator
    # parameter choice mutator is an empty-shell-mutator
    # calculate all the parameter values based on previous mutations of value choice mutator
    def __init__(self, nodes: List[Tuple[Node, str]]):
        super().__init__()
        self.nodes = nodes

    def mutate(self, model: Model) -> None:
        # Collect the decisions previously recorded by ParameterChoiceLeafMutator.
        # looks like {"label1": "cat", "label2": 123}
        value_choice_decisions = {}
        for mutation in model.history:
            if isinstance(mutation.mutator, ParameterChoiceLeafMutator):
                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
        for node, argname in self.nodes:
            # argname is the location of the argument
            # e.g., Conv2d(out_channels=nn.ValueChoice([1, 2, 3])) => argname = "out_channels"
            value_choice: ValueChoiceX = node.operation.parameters[argname]
            # calculate all the values on the leaf node of ValueChoiceX computation graph
            leaf_node_values = []
            for choice in value_choice.inner_choices():
                leaf_node_values.append(value_choice_decisions[choice.label])
            # Evaluate the (possibly composite) expression with the sampled leaves.
            result_value = value_choice.evaluate(leaf_node_values)
            # update model with graph mutation primitives
            target = cast(Node, model.get_node_by_name(node.name))
            target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
class RepeatMutator(Mutator):
    """Graph-space mutator that freezes a group of ``Repeat`` cells sharing one label.

    The chosen depth itself is sampled earlier (via a ParameterChoiceMutator on the
    ``depth`` parameter); this mutator truncates the chain of repeated blocks.
    """

    def __init__(self, nodes: List[Node]):
        # nodes is a subgraph consisting of repeated blocks.
        super().__init__(label=nodes[0].operation.parameters['label'])
        self.nodes = nodes

    def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
        # Walk from input to output, collecting the single-successor chain of blocks.
        u = graph.input_node
        chain = []
        while u != graph.output_node:
            if u != graph.input_node:
                chain.append(u)
            assert len(u.successors) == 1, f'This graph is an illegal chain. {u} has output {u.successors}.'
            u = u.successors[0]
        return chain

    def mutate(self, model):
        for node in self.nodes:
            # the logic here is similar to layer choice. We find cell attached to each node.
            target: Graph = model.graphs[cast(Cell, node.operation).cell_name]
            chain = self._retrieve_chain_from_graph(target)
            # and we get the chosen depth (by value choice)
            node_in_model = cast(Node, model.get_node_by_name(node.name))
            # depth is a value choice in base model
            # but it's already mutated by a ParameterChoiceMutator here
            chosen_depth: int = node_in_model.operation.parameters['depth']
            # Cut the chain right after the last kept block and wire it to the output.
            for edge in chain[chosen_depth - 1].outgoing_edges:
                edge.remove()
            target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
            # Drop the now-unreachable tail of the chain.
            for rm_node in chain[chosen_depth:]:
                for edge in rm_node.outgoing_edges:
                    edge.remove()
                rm_node.remove()
            # to delete the unused parameters.
            target_node = cast(Node, model.get_node_by_name(node.name))
            cell_operation = cast(Cell, node.operation)
            target_node.update_operation(Cell(cell_operation.cell_name))
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
    """Scan the graph-space ``model`` for inline mutation primitives and build the matching mutators.

    Returns None when the model contains no mutable primitive at all.
    Mutators are appended in a deliberate order: input/value choices first,
    parameter (leaf) choices next, then layer choices and repeats last,
    because the latter delete nodes from the graph.
    """
    applied_mutators = []
    ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.InputChoice'))
    for node_list in ic_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
            _is_all_equal(map(lambda node: node.operation.parameters['n_chosen'], node_list)), \
            'Input choice with the same label must have the same number of candidates.'
        mutator = InputChoiceMutator(node_list)
        applied_mutators.append(mutator)
    vc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.nas.nn.pytorch.choice.ValueChoice'))
    for node_list in vc_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['candidates'], node_list)), \
            'Value choice with the same label must have the same candidates.'
        mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
        applied_mutators.append(mutator)
    # `pc_nodes` are arguments of basic units. They can be compositions.
    pc_nodes: List[Tuple[Node, str, ValueChoiceX]] = []
    for node in model.get_nodes():
        # arguments used in operators like Conv2d
        # argument `valuechoice` used in generated repeat cell
        for name, choice in node.operation.parameters.items():
            if isinstance(choice, ValueChoiceX):
                # e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
                pc_nodes.append((node, name, choice))
    # Break `pc_nodes` down to leaf value choices. They should be what we want to sample.
    leaf_value_choices: Dict[str, List[Any]] = {}
    for _, __, choice in pc_nodes:
        for inner_choice in choice.inner_choices():
            if inner_choice.label not in leaf_value_choices:
                leaf_value_choices[inner_choice.label] = inner_choice.candidates
            else:
                assert leaf_value_choices[inner_choice.label] == inner_choice.candidates, \
                    'Value choice with the same label must have the same candidates, but found ' \
                    f'{leaf_value_choices[inner_choice.label]} vs. {inner_choice.candidates}'
    for label, candidates in leaf_value_choices.items():
        applied_mutators.append(ParameterChoiceLeafMutator(candidates, label))
    # in the end, add another parameter choice mutator for "real" mutations
    if pc_nodes:
        applied_mutators.append(ParameterChoiceMutator([(node, name) for node, name, _ in pc_nodes]))
    # apply layer choice at last as it will delete some nodes
    lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
                                      model.get_nodes_by_type('_cell')))
    for node_list in lc_nodes:
        assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
            'Layer choice with the same label must have the same number of candidates.'
        mutator = LayerChoiceMutator(node_list)
        applied_mutators.append(mutator)
    repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
                                          model.get_nodes_by_type('_cell')))
    for node_list in repeat_nodes:
        # this check is not completely reliable, because it only checks max and min
        assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
            _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
            'Repeat with the same label must have the same candidates.'
        mutator = RepeatMutator(node_list)
        applied_mutators.append(mutator)
    if applied_mutators:
        return applied_mutators
    return None
# The following are written for pure-python mode
class ManyChooseManyMutator(Mutator):
    """
    Choose based on labels. Will not affect the model itself.
    """

    def __init__(self, label: str):
        super().__init__(label=label)

    @staticmethod
    def candidates(node):
        # InputChoice-style nodes only carry a count, so candidates are indices;
        # other nodes carry an explicit candidate list.
        if 'n_candidates' in node.operation.parameters:
            return list(range(node.operation.parameters['n_candidates']))
        else:
            return node.operation.parameters['candidates']

    @staticmethod
    def number_of_chosen(node):
        # None means "choose any number"; nodes without `n_chosen` choose exactly one.
        if 'n_chosen' in node.operation.parameters:
            return node.operation.parameters['n_chosen']
        return 1

    def mutate(self, model: Model) -> None:
        # this mutate does not have any effect, but it is recorded in the mutation history
        for node in model.get_nodes_by_label(self.label):
            n_chosen = self.number_of_chosen(node)
            if n_chosen is None:
                candidates = [i for i in self.candidates(node) if self.choice([False, True])]
                # FIXME This is a hack to make choice align with the previous format
                # For example, it will convert [False, True, True] into [1, 2].
                self._cur_samples = candidates
            else:
                for _ in range(n_chosen):
                    self.choice(self.candidates(node))
            # All nodes sharing this label look the same for sampling purposes;
            # one representative suffices.
            break
def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Optional[List[Mutator]]]:
    """Build a (pure-python mode) graph ``Model`` from a PyTorch module and collect its mutators.

    The returned graph only records choice metadata as placeholder nodes; it is
    not an executable graph. Returns ``(model, None)`` when the module contains
    no mutable at all.
    """
    model = Model(_internal=True)
    graph = Graph(model, uid(), '_model', _internal=True)._register()
    model.python_class = pytorch_model.__class__
    # A signature with more than just `self` means the model has init parameters,
    # which must have been captured by @model_wrapper tracing.
    if len(inspect.signature(model.python_class.__init__).parameters) > 1:
        if not is_model_wrapped(pytorch_model):
            raise ValueError('Please annotate the model with @model_wrapper decorator in python execution mode '
                             'if your model has init parameters.')
        model.python_init_params = cast(dict, pytorch_model.trace_kwargs)
    else:
        model.python_init_params = {}
    # hyper-parameter choice
    namespace: ModelNamespace = cast(ModelNamespace, pytorch_model._model_namespace)
    for param_spec in namespace.parameter_specs:
        assert param_spec.categorical and param_spec.type == 'choice'
        node = graph.add_node(f'param_spec_{param_spec.name}', 'ModelParameterChoice', {'candidates': param_spec.values})
        node.label = param_spec.name
    for name, module in pytorch_model.named_modules():
        # tricky case: value choice that serves as parameters are stored in traced arguments
        if is_basic_unit(module):
            trace_kwargs = cast(Dict[str, Any], module.trace_kwargs)
            for key, value in trace_kwargs.items():
                if isinstance(value, ValueChoiceX):
                    for i, choice in enumerate(value.inner_choices()):
                        node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
                        node.label = choice.label
        if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
            # TODO: check the label of module and warn if it's auto-generated
            pass
        if isinstance(module, LayerChoice):
            node = graph.add_node(name, 'LayerChoice', {'candidates': module.names})
            node.label = module.label
        if isinstance(module, InputChoice):
            node = graph.add_node(name, 'InputChoice',
                                  {'n_candidates': module.n_candidates, 'n_chosen': module.n_chosen})
            node.label = module.label
        if isinstance(module, ValueChoiceX):
            # A composite value choice: one placeholder node per inner leaf choice.
            for i, choice in enumerate(module.inner_choices()):
                node = graph.add_node(f'{name}.{i}', 'ValueChoice', {'candidates': choice.candidates})
                node.label = choice.label
        if isinstance(module, NasBench101Cell):
            node = graph.add_node(name, 'NasBench101Cell', {
                'max_num_edges': module.max_num_edges
            })
            node.label = module.label
        if isinstance(module, Placeholder):
            raise NotImplementedError('Placeholder is not supported in python execution mode.')
    model.status = ModelStatus.Frozen
    if not graph.hidden_nodes:
        return model, None
    mutators = []
    mutators_final = []
    for nodes in _group_by_label_and_type(graph.hidden_nodes):
        label = nodes[0].label
        assert label is not None, f'label of {nodes[0]} can not be None.'
        assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
            f'Node with label "{label}" does not all have the same type.'
        assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
            f'Node with label "{label}" does not agree on parameters.'
        if nodes[0].operation.type == 'NasBench101Cell':
            # The mutation of Nas-bench-101 is special, and has to be done lastly.
            mutators_final.append(NasBench101Mutator(label))
        else:
            mutators.append(ManyChooseManyMutator(label))
    return model, mutators + mutators_final
# mutations for evaluator
class EvaluatorValueChoiceLeafMutator(Mutator):
    # see "ParameterChoiceLeafMutator"
    # works in the same way
    def __init__(self, candidates: List[Any], label: str):
        super().__init__(label=label)
        self.candidates = candidates

    def mutate(self, model: Model) -> None:
        # Only record the decision in the mutation history; the evaluator is
        # actually rewritten later by EvaluatorValueChoiceMutator.
        self.choice(self.candidates)
class EvaluatorValueChoiceMutator(Mutator):
    # works in the same way as `ParameterChoiceMutator`
    # we only need one such mutator for one model/evaluator
    def _mutate_traceable_object(self, obj: Any, value_choice_decisions: Dict[str, Any]) -> Any:
        # Returns ``obj`` itself (not a copy) when nothing needs to change,
        # which lets the caller detect "mutated" via identity.
        if not _is_traceable_object(obj):
            return obj
        updates = {}
        # For each argument that is a composition of value choice
        # we find all the leaf-value-choice in the mutation
        # and compute the final updates
        for key, param in obj.trace_kwargs.items():
            if isinstance(param, ValueChoiceX):
                leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
                updates[key] = param.evaluate(leaf_node_values)
            elif is_traceable(param):
                # Recursively
                sub_update = self._mutate_traceable_object(param, value_choice_decisions)
                if sub_update is not param:  # if mutated
                    updates[key] = sub_update
        if updates:
            mutated_obj = obj.trace_copy()  # Make a copy
            mutated_obj.trace_kwargs.update(updates)  # Mutate
            mutated_obj = mutated_obj.get()  # Instantiate the full mutated object
            return mutated_obj
        return obj

    def mutate(self, model: Model) -> None:
        # Collect decisions recorded earlier by EvaluatorValueChoiceLeafMutator,
        # then rebuild the evaluator with the chosen values substituted in.
        value_choice_decisions = {}
        for mutation in model.history:
            if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
        model.evaluator = self._mutate_traceable_object(model.evaluator, value_choice_decisions)
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
    """Collect mutators for value choices appearing in the traceable evaluator's arguments.

    Parameters
    ----------
    evaluator
        The evaluator whose traced keyword arguments are scanned (recursively) for value choices.
    existing_mutators
        Mutators already generated from the model; their labels must not collide
        with the evaluator's, since sharing choices between model and evaluator
        is not supported.

    Returns
    -------
    list[Mutator]
        One leaf mutator per distinct label, plus one final mutator that applies
        the sampled values. Empty if the evaluator is not traceable or contains
        no value choice.

    Raises
    ------
    ValueError
        If an evaluator label duplicates a model label, or two evaluator value
        choices share a label but disagree on candidates.
    """
    # take all the value choices in the kwargs of the evaluator into a list
    if not _is_traceable_object(evaluator):
        return []
    mutator_candidates = {}
    for param in _expand_nested_trace_kwargs(evaluator):
        if isinstance(param, ValueChoiceX):
            for choice in param.inner_choices():
                # merge duplicate labels
                for mutator in existing_mutators:
                    if mutator.label == choice.label:
                        raise ValueError(
                            f'Found duplicated labels “{choice.label}”. When two value choices have the same name, '
                            'they would share choices. However, sharing choices between model and evaluator is not supported.'
                        )
                if choice.label in mutator_candidates and mutator_candidates[choice.label] != choice.candidates:
                    # Previously this printed `...[1]` (only the second candidate)
                    # and was missing a space between the two sentences.
                    raise ValueError(
                        f'Duplicate labels for evaluator ValueChoice {choice.label}. They should share choices. '
                        f'But their candidate list is not equal: {mutator_candidates[choice.label]} vs. {choice.candidates}'
                    )
                mutator_candidates[choice.label] = choice.candidates
    mutators = []
    for label, candidates in mutator_candidates.items():
        mutators.append(EvaluatorValueChoiceLeafMutator(candidates, label))
    if mutators:
        # one last mutator to actually apply the mutations
        mutators.append(EvaluatorValueChoiceMutator())
    return mutators
# the following are written for one-shot mode
# they shouldn't technically belong here, but all other engines are written here
# let's refactor later
def process_oneshot_mutations(base_model: nn.Module, evaluator: Evaluator):
    """Wrap ``base_model`` into a graph-space ``Model`` for the strategy interface.

    One-shot strategies operate on the module directly, so no mutators are
    returned. The evaluator is attached by the caller after this returns.
    """
    # Admittedly hacky: strategies require a graph-space Model even though
    # one-shot NAS works directly on the PyTorch module.
    wrapped = Model(_internal=True)
    wrapped.python_object = base_model
    return wrapped, []
# utility functions
def _is_all_equal(lst):
last = None
for x in lst:
if last is not None and last != x:
return False
last = x
return True
def _group_by_label_and_type(nodes: Iterable[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
key = (node.label, node.operation.type)
if key not in result:
result[key] = []
result[key].append(node)
return list(result.values())
def _group_by_label(nodes: Iterable[Node]) -> List[List[Node]]:
result = {}
for node in nodes:
label = node.operation.parameters['label']
if label not in result:
result[label] = []
result[label].append(node)
return list(result.values())
def _expand_nested_trace_kwargs(obj: Any) -> Iterator[Any]:
    """Yield every value found in ``obj.trace_kwargs``, recursing into traceable values."""
    if not _is_traceable_object(obj):
        return
    for value in obj.trace_kwargs.values():
        yield value
        # Traceable values may hold further traced kwargs; _is_traceable_object
        # inside the recursive call guards non-traceable values.
        yield from _expand_nested_trace_kwargs(value)
def _is_traceable_object(obj: Any) -> bool:
    # Is it a traceable "object" (as opposed to a traced class/function)?
    if not is_traceable(obj):
        return False
    return not is_wrapped_with_trace(obj)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from typing import Callable, List, Union, Tuple, Optional
import torch.nn as nn
from nni.nas.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
from .choice import ValueChoice, ValueChoiceX, ChoiceOf
from .mutation_utils import Mutable, get_fixed_value
__all__ = ['Repeat']
class Repeat(Mutable):
    """
    Repeat a block by a variable number of times.

    Parameters
    ----------
    blocks : function, list of function, module or list of module
        The block to be repeated. If not a list, it will be replicated (**deep-copied**) into a list.
        If a list, it should be of length ``max_depth``, the modules will be instantiated in order and a prefix will be taken.
        If a function, it will be called (the argument is the index) to instantiate a module.
        Otherwise the module will be deep-copied.
    depth : int or tuple of int
        If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
        meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
        If a ValueChoice, it should choose from a series of positive integers.

        .. versionadded:: 2.8

           Minimum depth can be 0. But this feature is NOT supported on graph engine.

    Examples
    --------
    Block() will be deep copied and repeated 3 times. ::

        self.blocks = nn.Repeat(Block(), 3)

    Block() will be repeated 1, 2, or 3 times. ::

        self.blocks = nn.Repeat(Block(), (1, 3))

    Can be used together with layer choice.
    With deep copy, the 3 layers will have the same label, thus share the choice. ::

        self.blocks = nn.Repeat(nn.LayerChoice([...]), (1, 3))

    To make the three layer choices independent,
    we need a factory function that accepts index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::

        self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))

    Depth can be a ValueChoice to support arbitrary depth candidate list. ::

        self.blocks = nn.Repeat(Block(), nn.ValueChoice([1, 3, 5]))
    """

    @classmethod
    def create_fixed_module(cls,
                            blocks: Union[Callable[[int], nn.Module],
                                          List[Callable[[int], nn.Module]],
                                          nn.Module,
                                          List[nn.Module]],
                            depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
        """Build a plain ``nn.Sequential`` of the chosen depth when running in fixed (trial) mode.

        Invoked through ``Mutable.__new__``; raises ``NoContextError`` when not in fixed mode.
        """
        if isinstance(depth, tuple):
            # we can't create a value choice here,
            # otherwise we will have two value choices, one created here, another in init.
            depth = get_fixed_value(label)
        if isinstance(depth, int):
            # if depth is a valuechoice, it should be already an int
            result = nn.Sequential(*cls._replicate_and_instantiate(blocks, depth))
            # The fixed module is a flat Sequential, while the mutable version
            # stores modules under ``blocks``; record the mapping so that state
            # dicts stay loadable across the two layouts.
            if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
                # already has a mapping, will merge with it
                prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
                setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
            else:
                setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
            return result
        raise NoContextError(f'Not in fixed mode, or {depth} not an integer.')

    def __init__(self,
                 blocks: Union[Callable[[int], nn.Module],
                               List[Callable[[int], nn.Module]],
                               nn.Module,
                               List[nn.Module]],
                 depth: Union[int, Tuple[int, int], ChoiceOf[int]], *, label: Optional[str] = None):
        super().__init__()
        self._label = None  # by default, no label
        if isinstance(depth, ValueChoiceX):
            if label is not None:
                warnings.warn(
                    'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.',
                    RuntimeWarning
                )
            self.depth_choice: Union[int, ChoiceOf[int]] = depth
            all_values = list(self.depth_choice.all_options())
            self.min_depth = min(all_values)
            self.max_depth = max(all_values)
            if isinstance(depth, ValueChoice):
                self._label = depth.label  # if a leaf node
        elif isinstance(depth, tuple):
            # NOTE(review): depth is always a tuple in this branch, so the
            # isinstance(depth, int) guards below never fire — kept for safety.
            self.min_depth = depth if isinstance(depth, int) else depth[0]
            self.max_depth = depth if isinstance(depth, int) else depth[1]
            self.depth_choice: Union[int, ChoiceOf[int]] = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
            self._label = self.depth_choice.label
        elif isinstance(depth, int):
            self.min_depth = self.max_depth = depth
            self.depth_choice: Union[int, ChoiceOf[int]] = depth
        else:
            raise TypeError(f'Unsupported "depth" type: {type(depth)}')
        assert self.max_depth >= self.min_depth >= 0 and self.max_depth >= 1, f'Depth of {self.min_depth} to {self.max_depth} is invalid.'
        # Always instantiate max_depth blocks; a prefix is selected at freeze time.
        self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))

    @property
    def label(self) -> Optional[str]:
        # Label of the underlying depth choice. None when depth is a fixed int,
        # or a composite (non-leaf) value-choice expression.
        return self._label

    def forward(self, x):
        # Apply the full max_depth stack sequentially; depth selection happens
        # at freeze/mutation time, not here.
        for block in self.blocks:
            x = block(x)
        return x

    @staticmethod
    def _replicate_and_instantiate(blocks, repeat):
        # Normalize ``blocks`` into a list of exactly ``repeat`` nn.Module instances.
        if not isinstance(blocks, list):
            if isinstance(blocks, nn.Module):
                # Deep-copy all but the first replica so parameters are independent.
                blocks = [blocks if i == 0 else copy.deepcopy(blocks) for i in range(repeat)]
            else:
                blocks = [blocks for _ in range(repeat)]
        assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
        if repeat < len(blocks):
            blocks = blocks[:repeat]
        if len(blocks) > 0 and not isinstance(blocks[0], nn.Module):
            # Factory functions: call each with its index to build the module.
            blocks = [b(i) for i, b in enumerate(blocks)]
        return blocks

    def __getitem__(self, index):
        # shortcut for blocks[index]
        return self.blocks[index]

    def __len__(self):
        return self.max_depth
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from tensorflow.keras import Layer
class LayerChoice(Layer):
    # FIXME: This is only a draft to test multi-framework support; it's not implemented at all.
    pass
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Re-export the implementation matching the currently selected NNI framework
# into this module's namespace, then drop the helper so it doesn't leak.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import warnings
from itertools import chain
from typing import Callable, Any, Dict, Union, Tuple, List, cast
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule
__all__ = [
    'MutationHook',
    'BaseSuperNetModule',
    'BaseOneShotLightningModule',
    'traverse_and_mutate_submodules',
    'no_default_hook'
]

# Signature of a mutation hook: ``hook(module, name, memo, mutate_kwargs)``.
# It may return a replacement module, a bool (``suppress``), or a
# ``(module, suppress)`` tuple; see ``traverse_and_mutate_submodules``.
MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
def traverse_and_mutate_submodules(
    root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
) -> list[BaseSuperNetModule]:
    """
    Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.

    Parameters
    ----------
    root_module : nn.Module
        User-defined model space.
        Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
        it's usually a ``pytorch_lightning.LightningModule``.
        The mutation will be in-place on ``root_module``.
    hooks : list[MutationHook]
        List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
        When a hook returns an module, the module will be replaced (mutated) to the new module.
    mutate_kwargs : dict
        Extra keyword arguments passed to hooks.
    topdown : bool, default = True
        If topdown is true, hooks are first called, before traversing its sub-module (i.e., pre-order DFS).
        Otherwise, sub-modules are first traversed, before calling hooks on this node (i.e., post-order DFS).

    Returns
    -------
    modules : list[BaseSuperNetModule]
        The mutated super-net modules, in traversal order.
    """
    memo = {}
    module_list = []
    def apply(m):
        # Need to call list() here because the loop body might replace some children in-place.
        for name, child in list(m.named_children()):
            # post-order DFS
            if not topdown:
                apply(child)
            mutate_result = None
            for hook in hooks:
                hook_suggest = hook(child, name, memo, mutate_kwargs)
                # parse the mutate result: normalize to (module-or-None, suppress)
                if isinstance(hook_suggest, tuple):
                    hook_suggest, suppress = hook_suggest
                elif hook_suggest is True:
                    hook_suggest, suppress = None, True
                elif not hook_suggest:  # none / false
                    hook_suggest, suppress = None, False
                elif isinstance(hook_suggest, nn.Module):
                    suppress = True
                else:
                    raise TypeError(f'Mutation hook returned {hook_suggest} of unsupported type: {type(hook_suggest)}.')
                if hook_suggest is not None:
                    if not isinstance(hook_suggest, BaseSuperNetModule):
                        warnings.warn("Mutation hook didn't return a BaseSuperNetModule. It will be ignored in hooked module list.",
                                      RuntimeWarning)
                    # Replace the child on its parent in-place.
                    setattr(m, name, hook_suggest)
                    mutate_result = hook_suggest
                # if suppress, no further mutation hooks are called
                if suppress:
                    break
            if isinstance(mutate_result, BaseSuperNetModule):
                # Replace child with the mutate result, and DFS this one
                child = mutate_result
                module_list.append(mutate_result)
            # pre-order DFS
            if topdown:
                apply(child)
    apply(root_module)
    return module_list
def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
    """Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
    # These primitives' forward IS NOT a supernet forward: if they reach this
    # hook, no earlier hook has handled them, which is an error.
    unsupported_primitives = (
        nas_nn.LayerChoice,
        nas_nn.InputChoice,
        nas_nn.Repeat,
        # nas_nn.NasBench101Cell, # FIXME: nasbench101 is moved to hub, can't check any more.
        # nas_nn.ValueChoice,     # could be false positive
        # nas_nn.Cell,            # later
        # nas_nn.NasBench201Cell, # forward = supernet
    )
    if isinstance(module, unsupported_primitives):
        raise TypeError(f'{type(module).__name__} is not supported')

    if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
        # need output_node_indices, which depends on super-net
        raise TypeError(f'Cell with merge_op `{module.merge_op}` is not supported')

    if is_traceable(module):
        # Reject a basic unit whose traced arguments contain a value choice.
        traced_args = chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values())
        if any(isinstance(arg, ValueChoiceX) for arg in traced_args):
            raise TypeError(f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
                            'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.')

    return True  # suppress all other hooks
class BaseOneShotLightningModule(pl.LightningModule):

    # Reusable docstring fragment for the ``mutation_hooks`` parameter;
    # concatenated into ``__doc__`` below and reused by subclass docstrings.
    _mutation_hooks_note = """mutation_hooks : list[MutationHook]
        Extra mutation hooks to support customized mutation on primitives other than built-ins.

        Mutation hooks are callable that inputs an Module and returns a
        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
        They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
        For each submodule, the hook list are invoked subsequently,
        the later hooks can see the result from previous hooks.
        The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
        stored in :attr:`nas_modules`, and be the focus of the NAS algorithm.

        The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.

        To be more specific, the input arguments are four arguments:

        1. a module that might be processed,
        2. name of the module in its parent module,
        3. a memo dict whose usage depends on the particular algorithm.
        4. keyword arguments (configurations).

        Note that the memo should be read/written by hooks.
        There won't be any hooks called on root module.

        The returned arguments can be also one of the three kinds:

        1. tuple of: :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
        2. boolean,
        3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.

        The boolean value is ``suppress`` indicates whether the following hooks should be called.
        When it's true, it suppresses the subsequent hooks, and they will never be invoked.
        Without boolean value specified, it's assumed to be false.
        If a none value appears on the place of
        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        it means the hook suggests to
        keep the module unchanged, and nothing will happen.

        An example of mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
        However it's recommended to implement mutation hooks by deriving
        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
        and add its classmethod ``mutate`` to this list.
    """

    # Reusable docstring fragment for the ``inner_module`` parameter.
    _inner_module_note = """inner_module : pytorch_lightning.LightningModule
        It's a `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html>`__
        that defines computations, train/val loops, optimizers in a single class.
        When used in NNI, the ``inner_module`` is the combination of instances of evaluator + base model
        (to be precise, a base model wrapped with LightningModule in evaluator).
    """

    __doc__ = """
    The base class for all one-shot NAS modules.

    In NNI, we try to separate the "search" part and "training" part in one-shot NAS.
    The "training" part is defined with evaluator interface (has to be lightning evaluator interface to work with oneshot).
    Since the lightning evaluator has already broken down the training into minimal building blocks,
    we can re-assemble them after combining them with the "search" part of a particular algorithm.

    After the re-assembling, this module has defined all the search + training. The experiment can use a lightning trainer
    (which is another part in the evaluator) to train this module, so as to complete the search process.

    Essential function such as preprocessing user's model, redirecting lightning hooks for user's model,
    configuring optimizers and exporting NAS result are implemented in this class.

    Attributes
    ----------
    nas_modules : list[BaseSuperNetModule]
        Modules that have been mutated, which the search algorithms should care about.
    model : pl.LightningModule
        PyTorch lightning module. A model space with training recipe defined (wrapped by LightningModule in evaluator).

    Parameters
    ----------
    """ + _inner_module_note + _mutation_hooks_note

    # Set by lightning when the module is attached to a trainer.
    trainer: pl.Trainer

    @property
    def automatic_optimization(self) -> bool:
        # One-shot algorithms drive backward/step manually in training_step.
        return False

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Override this to define class-default mutation hooks."""
        return [no_default_hook]

    def mutate_kwargs(self) -> dict[str, Any]:
        """Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
        return {}

    def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
        super().__init__()
        assert isinstance(model, pl.LightningModule)
        self.model = model

        # append the default hooks
        mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()

        # traverse the model, calling hooks on every submodule
        self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
            self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """Get the search space specification from :attr:`nas_modules`.

        Returns
        -------
        dict
            Key is the name of the choice, value is the corresponding :class:`ParameterSpec`.
        """
        result = {}
        for module in self.nas_modules:
            result.update(module.search_space_spec())
        return result

    def resample(self) -> dict[str, Any]:
        """Trigger the resample for each :attr:`nas_modules`.
        Sometimes (e.g., in differentiable cases), it does nothing.

        Returns
        -------
        dict
            Sampled architecture.
        """
        result = {}
        for module in self.nas_modules:
            # memo=result lets later modules see samples of earlier ones (shared labels).
            result.update(module.resample(memo=result))
        return result

    def export(self) -> dict[str, Any]:
        """
        Export the NAS result, ideally the best choice of each :attr:`nas_modules`.
        You may implement an ``export`` method for your customized :attr:`nas_modules`.

        Returns
        --------
        dict
            Keys are names of ``nas_modules``, and values are the choice indices of them.
        """
        result = {}
        for module in self.nas_modules:
            result.update(module.export(memo=result))
        return result

    def forward(self, x):
        # Delegate straight to the (mutated) user model.
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """This is the implementation of what happens in training loops of one-shot algos.
        It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
        """
        return self.model.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """
        Combine architecture optimizers and user's model optimizers.
        You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.

        For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
        and it only returns 1 optimizer.
        But for extendibility, codes for other return value types are also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            # No architecture optimizers: behave exactly as the inner module would.
            return self.model.configure_optimizers()

        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        # Number of leading optimizers that are architecture optimizers; used by
        # architecture_optimizers()/weight_optimizers() to split the combined list.
        self.arc_optim_count = len(arc_optimizers)

        # FIXME: this part uses non-official lightning API.
        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backward is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
        try:
            # above v1.6
            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
                _configure_optimizers,  # type: ignore
                _configure_schedulers_automatic_opt,  # type: ignore
                _configure_schedulers_manual_opt  # type: ignore
            )
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = (
                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
                if self.automatic_optimization
                else _configure_schedulers_manual_opt(lr_schedulers)
            )
        except ImportError:
            # under v1.5
            w_optimizers, lr_schedulers, self.frequencies, monitor = \
                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)  # type: ignore
        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):  # type: ignore
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        # Architecture optimizers come first by convention (see arc_optim_count).
        return arc_optimizers + w_optimizers, lr_schedulers

    # The hooks below simply forward lightning callbacks to the wrapped user model.
    def on_train_start(self):
        return self.model.on_train_start()

    def on_train_end(self):
        return self.model.on_train_end()

    def on_fit_start(self):
        # redirect the access to trainer/log to this module
        # but note that we might be missing other attributes,
        # which could potentially be a problem
        self.model.trainer = self.trainer  # type: ignore
        self.model.log = self.log
        return self.model.on_fit_start()

    def on_fit_end(self):
        return self.model.on_fit_end()

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        return self.model.on_train_batch_start(batch, batch_idx, unused)

    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

    # Deprecated hooks in pytorch-lightning
    def on_epoch_start(self):
        return self.model.on_epoch_start()

    def on_epoch_end(self):
        return self.model.on_epoch_end()

    def on_train_epoch_start(self):
        return self.model.on_train_epoch_start()

    def on_train_epoch_end(self):
        return self.model.on_train_epoch_end()

    def on_before_backward(self, loss):
        return self.model.on_before_backward(loss)

    def on_after_backward(self):
        return self.model.on_after_backward()

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

    def configure_architecture_optimizers(self):
        """
        Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
        if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.

        Returns
        ----------
        arc_optimizers : list[Optimizer], Optimizer
            Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

    def call_lr_schedulers(self, batch_index):
        """
        Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
        by this class, you can use this function to make schedulers behave as they were automatically handled by the lightning trainer.

        Parameters
        ----------
        batch_idx : int
            batch index
        """
        def apply(lr_scheduler):
            # single scheduler is called every epoch
            if isinstance(lr_scheduler, _LRScheduler):
                if self.trainer.is_last_batch:
                    lr_scheduler.step()
            # lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']
                frequency = lr_scheduler['frequency']
                if (
                    interval == 'step' and
                    batch_index % frequency == 0
                ) or \
                (
                    interval == 'epoch' and
                    self.trainer.is_last_batch and
                    (self.trainer.current_epoch + 1) % frequency == 0
                ):
                    lr_scheduler['scheduler'].step()

        lr_schedulers = self.lr_schedulers()

        if isinstance(lr_schedulers, list):
            for lr_scheduler in lr_schedulers:
                apply(lr_scheduler)
        else:
            apply(lr_schedulers)

    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
        """
        Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
        class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

        Parameters
        ----------
        method : str
            Method to call. Only ``step`` and ``zero_grad`` are supported now.
        """
        def apply_method(optimizer, method):
            if method == 'step':
                optimizer.step()
            elif method == 'zero_grad':
                optimizer.zero_grad()

        optimizers = self.weight_optimizers()
        if optimizers is None:
            return
        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

        if len(self.frequencies) > 0:
            # Round-robin among optimizers according to their configured frequencies:
            # advance to the next optimizer once the current one has run its quota of steps.
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
                self.cur_optimizer_step = 0
                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
                    if self.cur_optimizer_index + 1 < len(optimizers) \
                    else 0
            apply_method(optimizers[self.cur_optimizer_index], method)
        else:
            for optimizer in optimizers:
                apply_method(optimizer, method)

    def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.

        Returns
        ----------
        opts : list[Optimizer], Optimizer, None
            Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there is no
            architecture optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            # Architecture optimizers are the first ``arc_optim_count`` entries.
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
                return cast(Optimizer, arc_opts[0])
            return cast(List[Optimizer], arc_opts)
        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None

    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.

        Returns
        ----------
        opts : list[Optimizer], Optimizer, None
            Optimizers defined by user's model. This will be None if there is no user optimizers.
        """
        # Since use_pl_optimizer is set true (by default) here.
        # opts always return a list
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            return cast(List[Optimizer], opts[self.arc_optim_count:])
        # FIXME: this case is actually not correctly handled
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
            return cast(Union[List[Optimizer], Optimizer], opts)
        return None
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any
from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator
__all__ = ['ConcatLoader']
class ConcatLoader(CombinedLoader):
    """This loader is same as CombinedLoader in PyTorch-Lightning, but concatenate sub-loaders
    instead of loading them in parallel.

    Parameters
    ----------
    loaders
        For example, ::

            {
                "train": DataLoader(train_dataset),
                "val": DataLoader(val_dataset)
            }

        In this example, the loader will first produce the batches from "train", then "val".
    mode
        Only support "min_size" for now.
    """

    def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
        # FIXME: max_cycle will make dataloaders cycle iterators,
        # causing extra problems.
        if mode != 'min_size':
            raise ValueError('Only min_size mode is supported now.')
        super().__init__(loaders, mode)

    def __iter__(self) -> Any:
        """Replace the super-class iterator with ours."""
        self._try_to_patch_pytorch_dataloader()
        # ConcatLoaderIterator exhausts one sub-loader before moving to the next.
        iterator = ConcatLoaderIterator(self.loaders)
        # handle fault tolerant restart.
        self.on_restart(iterator)
        self._iterator = iterator
        return iterator

    @staticmethod
    def _try_to_patch_pytorch_dataloader():
        """Copied from CombinedLoader."""
        from torch.utils.data.dataloader import _BaseDataLoaderIter

        # prevent `NotImplementedError` from PyTorch:
        # https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
        def __getstate__patch__(*_):
            return {}

        _BaseDataLoaderIter.__getstate__ = __getstate__patch__  # type: ignore

    def __len__(self) -> int:
        # Concatenation semantics: total length is the SUM over all sub-loaders
        # (parent CombinedLoader would compute a parallel length instead).
        return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))
class ConcatLoaderIterator(CombinedLoaderIterator):
    """Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""

    def __next__(self) -> Any:
        """Fetches the next batch from multiple data loaders,
        by looking for the first iterator that isn't exhausted yet.
        """
        if len(self.loader_iters) != len(self.loaders):
            raise RuntimeError('loader_iters must have the same length as loaders.')
        last_index = len(self.loader_iters) - 1
        for index, (name, sub_iterator) in enumerate(self.loader_iters.items()):
            try:
                batch = self.request_next_batch(sub_iterator)
            except StopIteration:
                # Current sub-loader exhausted: re-raise only when it was the last one,
                # otherwise fall through to the next sub-loader.
                if index == last_index:
                    raise
                continue
            return (batch, name)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Experimental version of differentiable one-shot implementation."""
from __future__ import annotations
import pytorch_lightning as pl
import torch
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
DifferentiableMixedLayer, DifferentiableMixedInput,
MixedOpDifferentiablePolicy, GumbelSoftmax,
DifferentiableMixedCell, DifferentiableMixedRepeat
)
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
class DartsLightningModule(BaseOneShotLightningModule):
    _darts_note = """
    Continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.
    `Reference <https://arxiv.org/abs/1806.09055>`__.

    DARTS algorithm is one of the most fundamental one-shot algorithm.
    DARTS repeats iterations, where each iteration consists of 2 training phases.
    The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
    The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.

    The current implementation corresponds to DARTS (1st order) in paper.
    Second order (unrolled 2nd-order derivatives) is not supported yet.

    .. versionadded:: 2.8

       Supports searching for ValueChoices on operations, with the technique described in
       `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
       One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.

    The supported mutation primitives of DARTS are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _darts_note.format(
        module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with differentiable versions"""
        hooks = [
            DifferentiableMixedLayer.mutate,
            DifferentiableMixedInput.mutate,
            DifferentiableMixedCell.mutate,
            DifferentiableMixedRepeat.mutate,
        ]
        hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
        # Fallback: reject anything the hooks above could not handle.
        hooks.append(no_default_hook)
        return hooks

    def mutate_kwargs(self):
        """Use differentiable strategy for mixed operations."""
        return {
            'mixed_op_sampling': MixedOpDifferentiablePolicy
        }

    def __init__(self, inner_module: pl.LightningModule,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0E-4):
        # Set before super().__init__() so it is available to any hook/optimizer
        # configuration triggered during base-class initialization.
        self.arc_learning_rate = arc_learning_rate
        super().__init__(inner_module, mutation_hooks=mutation_hooks)

    def training_step(self, batch, batch_idx):
        """One DARTS iteration: architecture step on the val batch, then model step on the train batch."""
        # grad manually
        arc_optim = self.architecture_optimizers()
        if not isinstance(arc_optim, optim.Optimizer):
            raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')

        # DARTS strategy makes sure that ``train`` and ``val`` must be in the batch
        trn_batch = batch['train']
        val_batch = batch['val']

        # phase 1: architecture step
        # The _resample hook is kept for some darts-based NAS methods like proxyless.
        # See code of those methods for details.
        self.resample()
        arc_optim.zero_grad()
        arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self.resample()
        self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
        self.call_weight_optimizers('step')

        self.call_lr_schedulers(batch_idx)

        return loss_and_metrics

    def finalize_grad(self):
        # Note: This hook is currently kept for Proxyless NAS.
        pass

    def configure_architecture_optimizers(self):
        """Build one shared Adam optimizer over all architecture parameters (the DARTS alphas)."""
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
            ctrl_params += list(m.parameters(arch=True))  # type: ignore
        # Bug fix: honor the user-configured ``arc_learning_rate`` (documented default 3.0e-4)
        # instead of a hard-coded ``3.e-4`` literal that silently ignored the parameter.
        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), self.arc_learning_rate, betas=(0.5, 0.999),
                                      weight_decay=1.0E-3)
        return ctrl_optim
class ProxylessLightningModule(DartsLightningModule):
    _proxyless_note = """
    A low-memory-consuming optimized version of differentiable architecture search. See `reference <https://arxiv.org/abs/1812.00332>`__.

    This is a DARTS-based method that resamples the architecture to reduce memory consumption.
    Essentially, it samples one path on forward,
    and implements its own backward to update the architecture parameters based on only one path.

    The supported mutation primitives of Proxyless are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)

    __doc__ = _proxyless_note.format(
        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace modules with gumbel-differentiable versions"""
        # FIXME: no support for mixed operation currently
        return [
            ProxylessMixedLayer.mutate,
            ProxylessMixedInput.mutate,
            no_default_hook,
        ]

    def finalize_grad(self):
        # Each supernet module derives its architecture gradients from the
        # single sampled path before the architecture optimizer steps.
        for supernet_module in self.nas_modules:
            supernet_module.finalize_grad()  # type: ignore
class GumbelDartsLightningModule(DartsLightningModule):
    # NOTE(review): unlike DartsLightningModule/ProxylessLightningModule, no
    # ``__doc__ = _gumbel_darts_note.format(...)`` appears here — presumably
    # assembled elsewhere; verify.
    _gumbel_darts_note = """
    Choose the best block by using Gumbel Softmax random sampling and differentiable training.
    See `FBNet <https://arxiv.org/abs/1812.03443>`__ and `SNAS <https://arxiv.org/abs/1812.09926>`__.

    This is a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
    Essentially, it tries to mimick the behavior of sampling one path on forward by gradually
    cool down the temperature, aiming to bridge the gap between differentiable architecture weights and
    discretization of architectures.

    .. versionadded:: 2.8

       Supports searching for ValueChoices on operations, with the technique described in
       `FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.

    The supported mutation primitives of GumbelDARTS are:

    * :class:`nni.retiarii.nn.pytorch.LayerChoice`.
    * :class:`nni.retiarii.nn.pytorch.InputChoice`.
    * :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
    * :class:`nni.retiarii.nn.pytorch.Repeat`.
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

    {{module_notes}}

    Parameters
    ----------
    {{module_params}}
    {base_params}
    gumbel_temperature : float
        The initial temperature used in gumbel-softmax.
    use_temp_anneal : bool
        If true, a linear annealing will be applied to ``gumbel_temperature``.
        Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    def mutate_kwargs(self):
        """Use gumbel softmax."""
        return {
            'mixed_op_sampling': MixedOpDifferentiablePolicy,
            'softmax': GumbelSoftmax(),
        }

    def __init__(self, inner_module,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,
                 min_temp: float = .33):
        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
        # current temperature; annealed per epoch in on_train_epoch_end when use_temp_anneal is set
        self.temp = gumbel_temperature
        # starting temperature, kept to compute the linear anneal
        self.init_temp = gumbel_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

    def on_train_epoch_end(self):
        if self.use_temp_anneal:
            # Linear anneal from init_temp towards min_temp across max_epochs,
            # clamped so it never drops below min_temp.
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)

        # Propagate the new temperature to every mutated module that owns a gumbel softmax.
        for module in self.nas_modules:
            if hasattr(module, '_softmax'):
                module._softmax.temp = self.temp  # type: ignore

        return self.model.on_train_epoch_end()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
......@@ -29,101 +28,80 @@ class StackedLSTMCell(nn.Module):
return next_h, next_c
class EnasMutator(Mutator):
class ReinforceField:
    """
    One categorical decision exposed to the RL controller.

    A field with ``name``, with ``total`` choices. ``choose_one`` is true if one and only one is meant to be
    selected. Otherwise, any number of choices can be chosen.
    """

    def __init__(self, name, total, choose_one):
        self.name, self.total, self.choose_one = name, total, choose_one

    def __repr__(self):
        return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A mutator that mutates the graph with RL.
A controller that mutates the graph with RL.
Parameters
----------
model : nn.Module
PyTorch model.
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
cell_exit_extra_step : bool
If true, RL controller will perform an extra step at the exit of each MutableScope, dump the hidden state
and mark it as the hidden state of this MutableScope. This is to align with the original implementation of paper.
skip_target : float
Target probability that skipconnect will appear.
Target probability that skipconnect (chosen by InputChoice) will appear.
If the chosen number of inputs is away from the ``skip_connect``, there will be
a sample skip penalty which is a KL divergence added.
temperature : float
Temperature constant that divides the logits.
branch_bias : float
Manual bias applied to make some operations more likely to be chosen.
Currently this is implemented with a hardcoded match rule that aligns with original repo.
If a mutable has a ``reduce`` in its key, all its op choices
that contains `conv` in their typename will receive a bias of ``+self.branch_bias`` initially; while others
receive a bias of ``-self.branch_bias``.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, temperature=None, branch_bias=0.25, entropy_reduction="sum"):
super().__init__(model)
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ["sum", "mean"], "Entropy reduction must be one of sum and mean."
self.entropy_reduction = torch.sum if entropy_reduction == "sum" else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = len(mutable)
assert self.max_layer_choice == len(mutable), \
"ENAS mutator requires all layer choice have the same number of candidates."
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
for choice in mutable])
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def sample_search(self):
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
......@@ -131,67 +109,42 @@ class EnasMutator(Mutator):
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _mark_anchor(self, key):
    """Record the current top-layer hidden state as the attention anchor for ``key``."""
    self._anchors_hid[key] = self._h[-1]
def _sample_single(self, field):
    """Sample one decision for ``field`` with the controller LSTM.

    For a one-hot field (``field.choose_one``), a single index is drawn from
    the softmax over candidates. Otherwise each candidate is independently
    kept or skipped via a two-way softmax, and a KL skip-penalty is
    accumulated. Log-probability and entropy of the sample are accumulated
    into the controller's RL statistics.

    Returns
    -------
    int or list[int]
        The sampled index, or list of kept indices for multi-choice fields.
    """
    # NOTE(review): reconstructed from a mis-merged diff that interleaved the
    # removed ``_sample_layer_choice``/``_sample_input_choice`` with this
    # method; only the new-version lines are kept.
    self._lstm_next_step()
    logit = self.soft[field.name](self._h[-1])
    if self.temperature is not None:
        logit /= self.temperature
    if self.tanh_constant is not None:
        logit = self.tanh_constant * torch.tanh(logit)

    if field.choose_one:
        sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, sampled)
        self._inputs = self.embedding[field.name](sampled)
    else:
        # independent keep/skip decision per candidate
        logit = logit.view(-1, 1)
        logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
        sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        skip_prob = torch.sigmoid(logit)
        kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
        self.sample_skip_penalty += kl
        log_prob = self.cross_entropy_loss(logit, sampled)
        sampled = sampled.nonzero().view(-1)
        if sampled.sum().item():
            # NOTE(review): the denominator sums the sampled *indices*, not
            # their count — mirrors upstream code; verify intent.
            self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
        else:
            self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)  # type: ignore

    sampled = sampled.detach().cpu().numpy().tolist()
    self.sample_log_prob += self.entropy_reduction(log_prob)
    entropy = (log_prob * torch.exp(-log_prob)).detach()  # pylint: disable=invalid-unary-operand-type
    self.sample_entropy += self.entropy_reduction(entropy)
    if len(sampled) == 1:
        sampled = sampled[0]
    return sampled
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Experimental version of sampling-based one-shot implementation."""
from __future__ import annotations
import warnings
from typing import Any
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
PathSamplingCell, PathSamplingRepeat
)
from .enas import ReinforceController, ReinforceField
class RandomSamplingLightningModule(BaseOneShotLightningModule):
    _random_note = """
Train a super-net with uniform path sampling. See `reference <https://arxiv.org/abs/1904.00420>`__.
In each epoch, model parameters are trained after a uniformly random sampling of each choice.
Notably, the exporting result is **also a random sample** of the search space.
The supported mutation primitives of RandomOneShot are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
Parameters
----------
{{module_params}}
{base_params}
""".format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _random_note.format(
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    @property
    def automatic_optimization(self) -> bool:
        # Nothing interesting happens besides the inner module's own step,
        # so let Lightning drive the optimization loop automatically.
        return True

    def default_mutation_hooks(self) -> list[MutationHook]:
        """Replace supported modules with their path-sampling counterparts."""
        return [
            PathSamplingLayer.mutate,
            PathSamplingInput.mutate,
            PathSamplingRepeat.mutate,
            PathSamplingCell.mutate,
            *(operation.mutate for operation in NATIVE_MIXED_OPERATIONS),
            no_default_hook,
        ]

    def mutate_kwargs(self):
        """Use path sampling strategy for mixed-operations."""
        return dict(mixed_op_sampling=MixedOpPathSamplingPolicy)

    def training_step(self, batch, batch_idx):
        """Draw a random architecture, then run the inner module's training step on it."""
        self.resample()
        return self.model.training_step(batch, batch_idx)

    def export(self) -> dict[str, Any]:
        """
        Export of Random one-shot. It will return an arbitrary architecture.
        """
        warnings.warn(
            'Direct export from RandomOneShot returns an arbitrary architecture. '
            'Sampling the best architecture from this trained supernet is another search process. '
            'Users need to do another search based on the checkpoint of the one-shot strategy.',
            UserWarning
        )
        return super().export()
class EnasLightningModule(RandomSamplingLightningModule):
    _enas_note = """
RL controller learns to generate the best network on a super-net. See `ENAS paper <https://arxiv.org/abs/1802.03268>`__.
There are 2 steps in an epoch.
- Firstly, training model parameters.
- Secondly, training ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
.. note::
ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
See explanation of ``reward_metric_name`` for details.
The supported mutation primitives of ENAS are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
{{module_notes}}
Parameters
----------
{{module_params}}
{base_params}
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController`.
entropy_weight : float
Weight of sample entropy loss in RL.
skip_weight : float
Weight of skip penalty loss. See :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController` for details.
baseline_decay : float
Decay factor of reward baseline, which is used to normalize the reward in RL.
At each step, the new reward baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_steps_aggregate : int
Number of steps for which the gradients will be accumulated,
before updating the weights of RL controller.
ctrl_grad_clip : float
Gradient clipping value of controller.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This will be not effective when there's only one metric returned from evaluator.
If there are multiple, by default, it will find the metric with key name ``default``.
If reward_metric_name is specified, it will find reward_metric_name.
Otherwise it raises an exception indicating multiple metrics are found.
""".format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
    )

    __doc__ = _enas_note.format(
        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )

    @property
    def automatic_optimization(self) -> bool:
        # Weight and controller phases are stepped by hand in training_step.
        return False

    def __init__(self,
                 inner_module: pl.LightningModule,
                 *,
                 ctrl_kwargs: dict[str, Any] | None = None,
                 entropy_weight: float = 1e-4,
                 skip_weight: float = .8,
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: int = 20,
                 ctrl_grad_clip: float = 0,
                 reward_metric_name: str | None = None,
                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)

        # convert parameter spec to legacy ReinforceField
        # this part will be refactored
        self.nas_fields: list[ReinforceField] = []
        for name, param_spec in self.search_space_spec().items():
            if param_spec.chosen_size not in (1, None):
                raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
            self.nas_fields.append(ReinforceField(name, param_spec.size, param_spec.chosen_size == 1))
        self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))

        # RL hyper-parameters; see the class docstring for their meanings.
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
        self.reward_metric_name = reward_metric_name

    def configure_architecture_optimizers(self):
        """Adam optimizer over the RL controller's parameters."""
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    def training_step(self, batch_packed, batch_idx):
        """Alternate between two phases based on the batch's ``mode`` tag.

        ``'train'`` batches update the super-net weights with a freshly
        sampled architecture; other batches update the ENAS RL agent using
        the evaluator's validation metric as reward.
        """
        batch, mode = batch_packed
        if mode == 'train':
            # train model params
            with torch.no_grad():
                # sampling itself must not be recorded on the autograd graph
                self.resample()
            self.call_weight_optimizers('zero_grad')
            step_output = self.model.training_step(batch, batch_idx)
            w_step_loss = step_output['loss'] \
                if isinstance(step_output, dict) else step_output
            self.manual_backward(w_step_loss)
            self.call_weight_optimizers('step')
        else:
            # train ENAS agent
            arc_opt = self.architecture_optimizers()
            if not isinstance(arc_opt, optim.Optimizer):
                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
            arc_opt.zero_grad()
            # here sampling must stay differentiable: controller log-probs feed the loss
            self.resample()
            step_output = self.model.validation_step(batch, batch_idx)
            # use the default metric of self.model as reward function
            if len(self.trainer.callback_metrics) == 1:
                _, metric = next(iter(self.trainer.callback_metrics.items()))
            else:
                metric_name = self.reward_metric_name or 'default'
                if metric_name not in self.trainer.callback_metrics:
                    raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                                   f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
                                   f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
                                   'and remember to set on_step=True.')
                metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()
            if self.entropy_weight:
                # encourage exploration by rewarding high-entropy samples
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            # exponential moving average baseline for variance reduction
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                rnn_step_loss = rnn_step_loss + self.skip_weight * self.controller.sample_skip_penalty

            # scale so that gradients accumulated over ctrl_steps_aggregate steps average out
            rnn_step_loss = rnn_step_loss / self.ctrl_steps_aggregate
            self.manual_backward(rnn_step_loss)

            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.ctrl_grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
                arc_opt.step()
                arc_opt.zero_grad()

        return step_output

    def resample(self):
        """Resample the architecture with ENAS controller."""
        sample = self.controller.resample()
        result = self._interpret_controller_sampling_result(sample)
        for module in self.nas_modules:
            module.resample(memo=result)
        return result

    def export(self):
        """Run one more inference of ENAS controller."""
        self.controller.eval()
        with torch.no_grad():
            return self._interpret_controller_sampling_result(self.controller.resample())

    def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
        """Convert ``{label: index}`` to ``{label: name}``"""
        space_spec = self.search_space_spec()
        for key in list(sample.keys()):
            sample[key] = space_spec[key].values[sample[key]]
        return sample
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Strategy integration of one-shot.
This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.nas.strategy``.
For example, ``nni.nas.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""
from __future__ import annotations
import warnings
from typing import Any, Type
import torch.nn as nn
from nni.nas.execution.common import Model
from nni.nas.strategy.base import BaseStrategy
from nni.nas.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
class OneShotStrategy(BaseStrategy):
    """Wrap an one-shot lightning module as a one-shot strategy."""

    def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
        # The lightning-module class is instantiated lazily in ``run``.
        self.oneshot_module = oneshot_module
        self.oneshot_kwargs = kwargs

        self.model: BaseOneShotLightningModule | None = None

    def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
        """
        One-shot strategy typically requires fusing train and validation dataloader in an ad-hoc way.
        As one-shot strategy doesn't try to open the blackbox of a batch,
        theoretically, these dataloader can be
        `any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
        Returns
        -------
        A tuple of preprocessed train dataloaders and validation dataloaders.
        """
        return train_dataloaders, val_dataloaders

    def run(self, base_model: Model, applied_mutators):
        """Wrap the evaluator's inner module with the one-shot module and fit it."""
        # one-shot strategy doesn't use ``applied_mutators``
        # but get the "mutators" on their own
        hint = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        if not isinstance(base_model.python_object, nn.Module):
            raise TypeError('Model is not a nn.Module. ' + hint)
        py_model: nn.Module = base_model.python_object
        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + hint)

        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

        evaluator_module: LightningModule = base_model.evaluator.module
        evaluator_module.running_mode = 'oneshot'
        evaluator_module.set_model(py_model)

        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
        evaluator: Lightning = base_model.evaluator
        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
            raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
        fused_train, fused_val = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
        evaluator.trainer.fit(self.model, fused_train, fused_val)

    def export_top_models(self, top_k: int = 1) -> list[Any]:
        """The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
        if self.model is None:
            raise RuntimeError('One-shot strategy needs to be run before export.')
        if top_k != 1:
            warnings.warn('One-shot strategy currently only supports exporting top-1 model.', RuntimeWarning)
        return [self.model.export()]
class DARTS(OneShotStrategy):
    __doc__ = DartsLightningModule._darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(DartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Returning a dict makes Lightning wrap both loaders in a CombinedLoader.
        fused = {'train': train_dataloaders, 'val': val_dataloaders}
        return fused, None
class Proxyless(OneShotStrategy):
    __doc__ = ProxylessLightningModule._proxyless_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(ProxylessLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Fuse both loaders into one dict-based CombinedLoader (in Lightning).
        fused = {'train': train_dataloaders, 'val': val_dataloaders}
        return fused, None
class GumbelDARTS(OneShotStrategy):
    __doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(GumbelDartsLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Fuse both loaders into one dict-based CombinedLoader (in Lightning).
        fused = {'train': train_dataloaders, 'val': val_dataloaders}
        return fused, None
class ENAS(OneShotStrategy):
    __doc__ = EnasLightningModule._enas_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        super().__init__(EnasLightningModule, **kwargs)

    def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
        # Import locally to avoid import error on legacy PL version
        from .dataloader import ConcatLoader
        fused = ConcatLoader({'train': train_dataloaders, 'val': val_dataloaders})
        return fused, None
class RandomOneShot(OneShotStrategy):
    __doc__ = RandomSamplingLightningModule._random_note.format(module_notes='', module_params='')

    def __init__(self, **kwargs):
        # Dataloaders need no fusing here; the base class passes them through.
        super().__init__(RandomSamplingLightningModule, **kwargs)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Thie file handles "slice" commonly used in mixed-operation.
The ``slice_type`` we support here, is "slice" or "list of slice".
The reason is that sometimes (e.g., in multi-head attention),
the tensor slice could be from multiple parts. This type is extensible.
We can support arbitrary masks in future if we need them.
To slice a tensor, we need ``multidim_slice``,
which is simply a tuple consists of ``slice_type``.
Usually in python programs, the variable put into slice's start, stop and step
should be integers (or NoneType).
But in our case, it could also be a dict from integer to float,
representing a distribution of integer. When that happens,
we convert a "slice with some weighted values", to a "weighted slice".
To this end, we track the computation with ``MaybeWeighted``,
and replay the computation with each possible value.
Meanwhile, we record their weights.
Note that ``MaybeWeighted`` is also extensible.
We can support more types of objects on slice in future.
The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
"""
from __future__ import annotations
import operator
from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast
import numpy as np
import torch
__all__ = [
    'slice_type',
    'multidim_slice',
    'scalar_or_scalar_dict',
    'int_or_int_dict',
    'zeros_like',
    'Slicable',
    'MaybeWeighted',
]

# Generic payload type (a numpy array or a torch tensor in practice).
T = TypeVar('T')

# A "slice" here is either a plain slice or a list of slices (multi-part slice).
slice_type = Union[slice, List[slice]]
# One slice_type per tensor dimension.
multidim_slice = Tuple[slice_type, ...]

# Either a plain scalar, or a distribution over scalars: {value: weight}.
scalar_or_scalar_dict = Union[T, Dict[T, float]]
int_or_int_dict = scalar_or_scalar_dict[int]

# Optional function used to replace leaf values when evaluating a MaybeWeighted.
_value_fn_type = Optional[Callable[[int_or_int_dict], int]]
def zeros_like(arr: T) -> T:
    """Return an all-zero array/tensor with the same shape and dtype as ``arr``.

    Supports numpy arrays and torch tensors; anything else raises ``TypeError``.
    """
    if isinstance(arr, torch.Tensor):
        return torch.zeros_like(arr)
    if isinstance(arr, np.ndarray):
        return np.zeros_like(arr)
    raise TypeError(f'Unsupported type for {arr}: {type(arr)}')
def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slice:
# get rid of list of slice
result = []
for i in range(len(slice_)):
if isinstance(slice_[i], list):
# convert list of slices to mask
mask = np.zeros(shape[i], dtype=np.bool) # type: ignore
for sl in cast(List[slice], slice_[i]):
mask[sl] = 1
result.append(mask)
else:
result.append(slice_[i])
return tuple(result)
def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
    """Apply ``slice_`` to ``weight``.

    ``slice_`` is either a plain multi-dimensional slice (e.g. ``([3:6], [2:4])``),
    or a list of ``(multidim_slice, weight)`` pairs representing a distribution
    over slices. In the weighted case the result is ``weight`` multiplied by a
    soft mask built from each candidate slice's mask; in the plain case the
    tensor is sliced directly (skipping the op entirely when it is a no-op).
    """
    if not isinstance(slice_, list):
        # Unweighted case: slice directly.
        def apply_slice(arr, sl):
            return arr[_eliminate_list_slice(arr.shape, sl)]  # type: ignore

        # Probe with a cheap boolean dummy: if slicing keeps the full shape,
        # skip the real op to save a node on the computational graph, which
        # will hopefully make training faster.
        probe = np.zeros(weight.shape, dtype=bool)  # type: ignore
        if cast(Any, apply_slice(probe, slice_)).shape == probe.shape:
            return weight
        return apply_slice(weight, slice_)

    # Weighted case: build one soft mask from the candidate masks, e.g.
    # {([3:6],): 0.6, ([2:4],): 0.3} => [0, 0, 0.3, 0.9, 0.6, 0.6] (length 6),
    # then broadcast-multiply it onto the weight.
    # The argument is a list of pairs because a slice cannot be a dict key.
    soft_mask: Any = 0
    for candidate, wt in slice_:
        # create a 0/1 mask for this candidate without tracking gradients
        with torch.no_grad():
            one_mask = zeros_like(weight)
            one_mask[_eliminate_list_slice(weight.shape, candidate)] = 1  # type: ignore
        # gradients are tracked from here on
        soft_mask = soft_mask + one_mask * wt  # type: ignore
    return soft_mask * weight  # type: ignore
class Slicable(Generic[T]):
    """Wraps the weight so that it can be sliced with a ``multidim_slice``.

    The values inside the slice may be :class:`MaybeWeighted` instances.

    Examples
    --------
    >>> weight = conv2d.weight
    >>> Slicable(weight)[:MaybeWeighted({32: 0.4, 64: 0.6})]
    Tensor of shape (64, 64, 3, 3)
    """

    def __init__(self, weight: T):
        # Only numpy arrays and torch tensors are supported payloads.
        if not isinstance(weight, np.ndarray) and not torch.is_tensor(weight):
            raise TypeError(f'Unsuppoted weight type: {type(weight)}')
        self.weight = weight

    def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
        # Normalize a lone slice into a 1-tuple.
        if not isinstance(index, tuple):
            index = (index, )
        index = cast(multidim_slice, index)

        # Look for a dict among the index's leaf values (a weighted value).
        # At most one distinct dict may appear.
        found_dict: dict[int, float] | None = None
        for node in _iterate_over_multidim_slice(index):
            for leaf in node.leaf_values():
                if isinstance(leaf, dict):
                    if found_dict is None:
                        found_dict = leaf
                    elif found_dict is not leaf:
                        raise ValueError('There can be at most one distinct dict in leaf values.')

        if found_dict is None:
            # Simple case: every leaf is a plain int.
            evaluated = _evaluate_multidim_slice(index)
        else:
            # Weighted case: replay the slice computation once per dict entry.
            evaluated = []
            for val, wt in found_dict.items():
                evaluated.append((_evaluate_multidim_slice(index, lambda _: val), wt))

        return _slice_weight(self.weight, evaluated)
class MaybeWeighted:
    """Wrap a value (an int, or a dict with int keys) so that arithmetic on it
    can be recorded and replayed later.

    Instances form a binary expression tree: a node with ``value`` set is a
    leaf; otherwise the node holds ``lhs``, ``rhs`` and a binary ``operation``.
    Only the basic arithmetic operations ``+``, ``-``, ``*``, ``//`` are supported.
    """

    def __init__(self,
                 value: int_or_int_dict | None = None, *,
                 lhs: 'MaybeWeighted' | int | None = None,
                 rhs: 'MaybeWeighted' | int | None = None,
                 operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
        # A leaf (no operation) must carry an int or a dict.
        if operation is None and not isinstance(value, (int, dict)):
            raise TypeError(f'Unsupported value type: {type(value)}')
        self.value = value
        self.lhs = lhs
        self.rhs = rhs
        self.operation = operation

    def leaf_values(self) -> Iterator[int_or_int_dict]:
        """Iterate over values on leaf nodes (left subtree first)."""
        if self.value is not None:
            yield self.value
            return
        for child in (self.lhs, self.rhs):
            if isinstance(child, MaybeWeighted):
                yield from child.leaf_values()

    def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
        """Evaluate the value at the root, mapping each leaf through ``value_fn``.

        When ``value_fn`` is None, leaf values are used unchanged.
        """
        if self.value is not None:
            return value_fn(self.value) if value_fn is not None else self.value

        def resolve(child):
            # A child is either a sub-tree or a plain int operand.
            if isinstance(child, MaybeWeighted):
                return child.evaluate(value_fn)
            return cast(int, child)

        assert self.operation is not None
        return self.operation(resolve(self.lhs), resolve(self.rhs))

    def __repr__(self):
        if self.value is None:
            return f'{self.__class__.__name__}(lhs={self.lhs}, rhs={self.rhs}, op={self.operation})'
        return f'{self.__class__.__name__}({self.value})'

    def __add__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.add)

    def __radd__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.add)

    def __sub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.sub)

    def __rsub__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.sub)

    def __mul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.mul)

    def __rmul__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.mul)

    def __floordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.floordiv)

    def __rfloordiv__(self, other: Any) -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.floordiv)
def _iterate_over_slice_type(s: slice_type):
    """Yield every :class:`MaybeWeighted` occurring in ``s`` (a slice or list of slices)."""
    if isinstance(s, list):
        for element in s:
            yield from _iterate_over_slice_type(element)
        return
    # ``s`` must be a plain slice here; check start, stop, step in order.
    for bound in (s.start, s.stop, s.step):
        if isinstance(bound, MaybeWeighted):
            yield bound
def _iterate_over_multidim_slice(ms: multidim_slice):
"""Get :class:`MaybeWeighted` instances in ``ms``."""
for s in ms:
if s is not None and s is not Ellipsis:
yield from _iterate_over_slice_type(s)
def _evaluate_slice_type(s: slice_type, value_fn: _value_fn_type = None):
if isinstance(s, list):
return [_evaluate_slice_type(se, value_fn) for se in s]
else:
return slice(
s.start.evaluate(value_fn) if isinstance(s.start, MaybeWeighted) else s.start,
s.stop.evaluate(value_fn) if isinstance(s.stop, MaybeWeighted) else s.stop,
s.step.evaluate(value_fn) if isinstance(s.step, MaybeWeighted) else s.step
)
def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None):
"""Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
res = []
for s in ms:
if s is not None and s is not Ellipsis:
res.append(_evaluate_slice_type(s, value_fn))
else:
res.append(s)
return tuple(res)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. The complete support is an undergoing work item.
Keep this file here so that it can be "blamed".
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.nn.pytorch import ValueChoice
class DifferentiableSuperConv2d(nn.Conv2d):
"""
Only ``kernel_size`` ``in_channels`` and ``out_channels`` are supported. Kernel size candidates should be larger or smaller
than each other in both candidates. See examples below:
the following example is not allowed:
>>> ValueChoice(candidates = [(5, 3), (3, 5)])
□ ■ ■ ■ □ □ □ □ □ □
□ ■ ■ ■ □ ■ ■ ■ ■ ■ # candidates are not bigger or smaller on both dimension
□ ■ ■ ■ □ ■ ■ ■ ■ ■
□ ■ ■ ■ □ ■ ■ ■ ■ ■
□ ■ ■ ■ □ □ □ □ □ □
the following 3 examples are valid:
>>> ValueChoice(candidates = [5, 3, 1])
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ ■ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
>>> ValueChoice(candidates = [(5, 7), (3, 5), (1, 3)])
■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ ■ ■ ■ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
>>> # when the difference between any two candidates is not even, the left upper will be picked:
>>> ValueChoice(candidates = [(5, 5), (4, 4), (3, 3)])
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
"""
def __init__(self, module, name):
    """Build a super-kernel Conv2d from a traced ``Conv2d`` whose constructor
    arguments may contain :class:`ValueChoice`.

    Parameters
    ----------
    module
        The traced module; constructor arguments are read (and mutated) from
        ``module.trace_kwargs``.
    name
        Label of this mutable.
    """
    self.label = name
    args = module.trace_kwargs

    # compulsory params
    # in_channels: take the widest candidate; the effective width is decided
    # by the previous layer's out_channels choice.
    if isinstance(args['in_channels'], ValueChoice):
        args['in_channels'] = max(args['in_channels'].candidates)

    # out-channel candidates sorted descending, so index 0 is the super-kernel width
    self.out_channel_candidates = None
    if isinstance(args['out_channels'], ValueChoice):
        self.out_channel_candidates = sorted(args['out_channels'].candidates, reverse=True)
        args['out_channels'] = self.out_channel_candidates[0]

    # kernel_size may be an int or tuple, we turn it into a tuple for simplicity
    self.kernel_size_candidates = None
    if isinstance(args['kernel_size'], ValueChoice):
        # unify kernel size as tuple
        candidates = args['kernel_size'].candidates
        if not isinstance(candidates[0], tuple):
            candidates = [(k, k) for k in candidates]

        # sort kernel size in descending order; candidates must be nested
        # (each strictly containing the next) — see the class docstring
        self.kernel_size_candidates = sorted(candidates, key=lambda t: t[0], reverse=True)
        for i in range(0, len(self.kernel_size_candidates) - 1):
            bigger = self.kernel_size_candidates[i]
            smaller = self.kernel_size_candidates[i + 1]
            assert bigger[1] > smaller[1] or (bigger[1] == smaller[1] and bigger[0] > smaller[0]), f'Kernel_size candidates ' \
                f'should be larger or smaller than each other on both dimensions, but found {bigger} and {smaller}.'
        args['kernel_size'] = self.kernel_size_candidates[0]

    # build the underlying Conv2d with the largest candidate of everything
    super().__init__(**args)
    self.generate_architecture_params()
def forward(self, input):
    """Convolve ``input`` with the soft-masked super-kernel.

    The stored (largest) weight is scaled by the architecture indicators
    before the ordinary conv forward is applied.
    """
    # Note that there is no need to handle ``in_channels`` here since it is already handled by the ``out_channels`` in the
    # previous module. If we multiply alpha with refer to ``in_channels`` here again, the alpha will indeed be considered
    # twice, which is not what we expect.
    weight = self.weight

    def sum_weight(input_weight, masks, thresholds, indicator):
        """
        This is to get the weighted sum of weight.
        Parameters
        ----------
        input_weight : Tensor
            the weight to be weighted summed
        masks : list[Tensor]
            weight masks.
        thresholds : list[float]
            thresholds, should have a length of ``len(masks) - 1``
        indicator : Callable[[Tensor, float], float]
            take a tensor and a threshold as input, and output the weight
        Returns
        ----------
        weight : Tensor
            weighted sum of ``input_weight``. this is of the same shape as ``input_sum``
        """
        # Note that ``masks`` and ``thresholds`` have different lengths. Their alignment is shown below:
        # self.xxx_candidates = [ c_0 , c_1 , ... , c_n-2 , c_n-1 ] # descending order
        # self.xxx_mask = [ mask_0 , mask_1 , ... , mask_n-2, mask_n-1]
        # self.t_xxx = [ t_0 , t_2 , ... , t_n-2 ]
        # So we zip the first n-1 items, and multiply masks[-1] in the end.
        weight = torch.zeros_like(input_weight)
        for mask, t in zip(masks[:-1], thresholds):
            cur_part = input_weight * mask
            alpha = indicator(cur_part, t)
            weight = (weight + cur_part) * alpha
        # we do not consider skip-op here for out_channel/expansion candidates, which means at least the smallest channel
        # candidate is included
        weight += input_weight * masks[-1]
        return weight

    if self.kernel_size_candidates is not None:
        weight = sum_weight(weight, self.kernel_masks, self.t_kernel, self.Lasso_sigmoid)
    if self.out_channel_candidates is not None:
        weight = sum_weight(weight, self.channel_masks, self.t_expansion, self.Lasso_sigmoid)

    output = self._conv_forward(input, weight, self.bias)
    return output
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super().named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
"""
result = {
'kernel_size': i,
'out_channels': j
}
which means the best candidate for an argument is the i-th one if candidates are sorted in descending order
"""
result = {}
eps = 1e-5
with torch.no_grad():
if self.kernel_size_candidates is not None:
weight = torch.zeros_like(self.weight)
# ascending order
for i in range(len(self.kernel_size_candidates) - 2, -1, -1):
mask = self.kernel_masks[i]
t = self.t_kernel[i]
cur_part = self.weight * mask
alpha = self.Lasso_sigmoid(cur_part, t)
if alpha <= eps: # takes the smaller one
result['kernel_size'] = self.kernel_size_candidates[i + 1]
break
weight = (weight + cur_part) * alpha
if 'kernel_size' not in result:
result['kernel_size'] = self.kernel_size_candidates[0]
else:
weight = self.weight
if self.out_channel_candidates is not None:
for i in range(len(self.out_channel_candidates) - 2, -1, -1):
mask = self.channel_masks[i]
t = self.t_expansion[i]
alpha = self.Lasso_sigmoid(weight * mask, t)
if alpha <= eps:
result['out_channels'] = self.out_channel_candidates[i + 1]
if 'out_channels' not in result:
result['out_channels'] = self.out_channel_candidates[0]
return result
@staticmethod
def Lasso_sigmoid(matrix, t):
"""
A trick that can make use of both the value of bool(lasso > t) and the gradient of sigmoid(lasso - t)
Parameters
----------
matrix : Tensor
the matrix to calculate lasso norm
t : float
the threshold
"""
lasso = torch.norm(matrix) - t
indicator = (lasso > 0).float() # torch.sign(lasso)
with torch.no_grad():
# indicator = indicator / 2 + .5 # realign indicator from (-1, 1) to (0, 1)
indicator -= F.sigmoid(lasso)
indicator += F.sigmoid(lasso)
return indicator
    def generate_architecture_params(self):
        """Create architecture parameters (thresholds) and the weight masks they gate.

        For every pair of adjacent candidates (candidates are sorted in descending order)
        a "ring" mask covers the region belonging to the bigger candidate but not the
        smaller one; a final mask covers the smallest candidate itself.
        ``self.alpha`` maps argument names to their threshold parameters.
        """
        self.alpha = {}
        if self.kernel_size_candidates is not None:
            # kernel size arch params: one threshold per adjacent candidate pair
            self.t_kernel = nn.Parameter(torch.rand(len(self.kernel_size_candidates) - 1))
            self.alpha['kernel_size'] = self.t_kernel
            # kernel size masks: mask_i selects the ring between candidate i and i+1.
            # e.g. if self.weight.shape = (out, in, 7, 7), big_size = (5, 5) and small_size = (3, 3),
            # the mask looks like:
            #   0 0 0 0 0 0 0
            #   0 1 1 1 1 1 0
            #   0 1 0 0 0 1 0
            #   0 1 0 0 0 1 0
            #   0 1 0 0 0 1 0
            #   0 1 1 1 1 1 0
            #   0 0 0 0 0 0 0
            self.kernel_masks = []
            for i in range(0, len(self.kernel_size_candidates) - 1):
                big_size = self.kernel_size_candidates[i]
                small_size = self.kernel_size_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                mask[:, :, :big_size[0], :big_size[1]] = 1
                mask[:, :, :small_size[0], :small_size[1]] = 0
                self.kernel_masks.append(mask)
            # the last mask covers the smallest kernel candidate entirely
            mask = torch.zeros_like(self.weight)
            mask[:, :, :self.kernel_size_candidates[-1][0], :self.kernel_size_candidates[-1][1]] = 1
            self.kernel_masks.append(mask)
        if self.out_channel_candidates is not None:
            # out_channel (or expansion) arch params. we do not consider skip-op here, so we
            # only generate ``len(self.out_channel_candidates) - 1`` thresholds
            self.t_expansion = nn.Parameter(torch.rand(len(self.out_channel_candidates) - 1))
            self.alpha['out_channels'] = self.t_expansion
            self.channel_masks = []
            for i in range(0, len(self.out_channel_candidates) - 1):
                big_channel, small_channel = self.out_channel_candidates[i], self.out_channel_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                # channel mask selects output rows in [small_channel, big_channel).
                # e.g. if self.weight.shape = (32, in, W, H), big_channel = 16 and small_channel = 8:
                #   0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
                mask[:big_channel] = 1
                mask[:small_channel] = 0
                self.channel_masks.append(mask)
            # the last mask covers the smallest channel candidate entirely
            mask = torch.zeros_like(self.weight)
            mask[:self.out_channel_candidates[-1]] = 1
            self.channel_masks.append(mask)
class DifferentiableBatchNorm2d(nn.BatchNorm2d):
    """Differentiable super-net version of ``BatchNorm2d``.

    The only mutable argument is ``num_features``; the layer is simply built with the
    largest candidate so every sliced sub-network fits inside it.
    """

    def __init__(self, module, name):
        self.label = name
        args = module.trace_kwargs
        num_features = args['num_features']
        if isinstance(num_features, ValueChoice):
            # build the layer big enough to host the largest candidate
            args['num_features'] = max(num_features.candidates)
        super().__init__(**args)
        # no architecture parameter is needed for BatchNorm2d layers;
        # keep an empty placeholder so it looks like the other differentiable layers
        self.alpha = nn.Parameter(torch.tensor([]))

    def export(self):
        """
        No need to export ``BatchNorm2d``. Refer to the ``Conv2d`` layer that has the ``ValueChoice`` as ``out_channels``.
        """
        return -1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""
from __future__ import annotations
import itertools
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
import numpy as np
import torch
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch.choice import ChoiceOf, ValueChoiceX
Choice = Any
T = TypeVar('T')
__all__ = [
'dedup_inner_choices',
'evaluate_value_choice_with_dict',
'traverse_all_options',
'weighted_sum',
'evaluate_constant',
]
def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
    """Collect all leaf choices reachable from ``value_choices``,
    returning them in the format of ``{label: parameter_spec}``.

    Two leaves sharing a label must describe exactly the same spec,
    otherwise a ``ValueError`` is raised.
    """
    specs: dict[str, ParameterSpec] = {}
    leaves = (choice for vc in value_choices for choice in vc.inner_choices())
    for leaf in leaves:
        spec = ParameterSpec(leaf.label, 'choice', leaf.candidates, (leaf.label, ), True, size=len(leaf.candidates))
        existing = specs.get(leaf.label)
        if existing is None:
            specs[leaf.label] = spec
        elif spec != existing:
            raise ValueError('Value choice conflict: same label with different candidates: '
                             f'{spec} vs. {existing}')
    return specs
def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
    """Evaluate a composition of value-choice against ``chosen``,
    a dict in the format of ``{label: chosen_value}``.

    The implementation is two-pass: first gather the chosen value of every inner
    choice, then hand the list to ``value_choice.evaluate``.
    This can be potentially optimized in terms of speed.

    Examples
    --------
    >>> chosen = {"exp_ratio": 3}
    >>> evaluate_value_choice_with_dict(value_choice_in, chosen)
    48
    >>> evaluate_value_choice_with_dict(value_choice_out, chosen)
    96
    """
    inner_values = []
    for inner in value_choice.inner_choices():
        if inner.label not in chosen:
            raise KeyError(f'{value_choice} depends on a value with key {inner.label}, but not found in {chosen}')
        inner_values.append(chosen[inner.label])
    return value_choice.evaluate(inner_values)
def traverse_all_options(
    value_choice: ChoiceOf[T],
    weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
    """Traverse all possible computation outcome of a value choice.
    If ``weights`` is not None, it will also compute the probability of each possible outcome.

    Parameters
    ----------
    value_choice : ValueChoiceX
        The value choice to traverse.
    weights : Optional[dict[str, list[float]]], default = None
        If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
        weights can be provided. The key is label, value are list of float indicating probability.
        Normally, they should sum up to 1, but we will not check them in this function.

    Returns
    -------
    list[Union[tuple[Any, float], Any]]
        Results will be sorted and duplicates will be eliminated.
        If weights is provided, the return value will be a list of tuple, with option and its weight.
        Otherwise, it will be a list of options.
    """
    # For each leaf choice: a list of (candidate, weight) pairs.
    leaf_options: dict[str, list[tuple[T, float]]] = {}
    for label, param_spec in dedup_inner_choices([value_choice]).items():
        if weights is None:
            # dummy zero weights, in case that weights are not provided
            leaf_options[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))
        else:
            if label not in weights:
                raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
            if len(weights[label]) != param_spec.size:
                raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
            leaf_options[label] = list(zip(param_spec.values, cast(List[float], weights[label])))

    if not leaf_options:
        raise ValueError(f'There expects at least one leaf value choice in {value_choice}, but nothing found')

    labels = list(leaf_options.keys())
    # maps each distinct outcome to its accumulated weight (or None when no weights given)
    outcome: dict[T, float | None] = {}
    for assignment in itertools.product(*leaf_options.values()):
        # e.g. assignment = ((3, 0.1), ("cat", 0.3), ({"in": 5}, 0.5));
        # the first element of each pair is the chosen value, the second its probability
        chosen = {label: pair[0] for label, pair in zip(labels, assignment)}
        eval_res = evaluate_value_choice_with_dict(value_choice, chosen)
        if weights is None:
            outcome[eval_res] = None
        else:
            # multiply step by step instead of reduce / inplace product,
            # because weight can sometimes be tensors
            joint = None
            for _, prob in assignment:
                joint = prob if joint is None else joint * prob
            if eval_res in outcome:
                outcome[eval_res] = outcome[eval_res] + joint
            else:
                outcome[eval_res] = joint

    if weights is None:
        return sorted(outcome.keys())  # type: ignore
    return sorted(outcome.items())  # type: ignore
def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
    options = traverse_all_options(expr)
    if len(options) > 1:
        raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {options}')
    return options[0]
def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
    """Return a weighted sum of items.

    Items can be list of tensors, numpy arrays, or nested lists / dicts.
    If ``weights`` is None, this is simply an unweighted sum.

    Parameters
    ----------
    items : list
        Elements to sum. All of them must share the same type and structure.
    weights : Sequence[float | None]
        One weight per item; ``None`` (the whole sequence or a single entry)
        means the corresponding item is summed unweighted.

    Returns
    -------
    The weighted sum, with the same structure as each item.

    Raises
    ------
    TypeError
        If the element type is unsupported (e.g., str), or items have mismatched types.
    ValueError
        If item structures (e.g., sequence lengths) do not match.
    """
    if weights is None:
        weights = [None] * len(items)
    assert len(items) == len(weights) > 0
    elem = items[0]
    unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
        raise TypeError(unsupported_msg)

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
            # Leaf case: plain numeric / tensor summation.
            if weights[0] is None:
                res = elem
            else:
                res = elem * weights[0]
            for it, weight in zip(items[1:], weights[1:]):
                if type(it) != type(elem):
                    raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
                if weight is None:
                    res = res + it  # type: ignore
                else:
                    res = res + it * weight  # type: ignore
            return cast(T, res)
        if isinstance(elem, Mapping):
            # Recurse into each key; all mappings must share the same key set.
            for item in items:
                if not isinstance(item, Mapping):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if set(item) != set(elem):
                    raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
            return cast(T, {
                key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
            })
        if isinstance(elem, Sequence):
            # Recurse column-wise; all sequences must have the same length.
            for item in items:
                if not isinstance(item, Sequence):
                    raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
                if len(item) != len(elem):
                    # BUGFIX: the expected length is that of ``elem`` (the first item) and the
                    # found length is that of the offending ``item`` -- they were swapped,
                    # inconsistent with the type-mismatch message above.
                    raise ValueError(f'Expect length {len(elem)} but found {len(item)}')
            transposed = cast(Iterable[list], zip(*items))  # type: ignore
            return cast(T, [weighted_sum(column, weights) for column in transposed])
    except (TypeError, ValueError, RuntimeError, KeyError):
        raise ValueError(
            'Error when summing items. Value format / shape does not match. See full traceback for details.' +
            ''.join([
                f'\n {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
            ])
        )

    # Dealing with all unexpected types.
    raise TypeError(unsupported_msg)
def _summarize_elem_format(elem: Any) -> Any:
# Get a summary of one elem
# Helps generate human-readable error messages
class _repr_object:
# empty object is only repr
def __init__(self, representation):
self.representation = representation
def __repr__(self):
return self.representation
if isinstance(elem, torch.Tensor):
return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, np.ndarray):
return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, Mapping):
return {key: _summarize_elem_format(value) for key, value in elem.items()}
if isinstance(elem, Sequence):
return [_summarize_elem_format(value) for value in elem]
# fallback to original, for cases like float, int, ...
return elem
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from typing import Any
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
__all__ = ['BaseSuperNetModule']
class BaseSuperNetModule(nn.Module):
    """
    Base class for a mutated module living inside a super-net.

    The feed-forward of the module itself is usually undefined until ``resample()``
    selects a specific path. (Differentiable super-nets are an exception: they need
    no resampling at all.)

    A super-net module usually corresponds to one sample, with two exceptions:

    * One module can own several parameter specs — e.g., a convolution-2d sampling
      kernel size and channels at the same time.
    * Several modules can share one parameter spec — e.g., layer choices with the
      same label.

    For value choice compositions, the parameter specs are bound to the underlying
    (original) value choices, rather than their compositions.
    """

    def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
        """
        Resample the super-net module.

        Parameters
        ----------
        memo : dict[str, Any]
            Used to ensure the consistency of samples with the same label.

        Returns
        -------
        dict
            Sampled result. If nothing new is sampled, it should return an empty dict.
        """
        raise NotImplementedError()

    def export(self, memo: dict[str, Any]) -> dict[str, Any]:
        """
        Export the final architecture within this module.
        It should have the same keys as ``search_space_spec()``.

        Parameters
        ----------
        memo : dict[str, Any]
            Use memo to avoid the same label gets exported multiple times.
        """
        raise NotImplementedError()

    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """
        Space specification (sample points).
        Mapping from spec name to ParameterSpec. The names in choices should be in the same format of export.

        For example: ::

            {"layer1": ParameterSpec(values=["conv", "pool"])}
        """
        raise NotImplementedError()

    @classmethod
    def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
            'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
        """This is a mutation hook that creates a :class:`BaseSuperNetModule`.
        The method should be implemented in each specific super-net module,
        because they usually have specific rules about what kind of modules to operate on.

        Parameters
        ----------
        module : nn.Module
            The module to be mutated (replaced).
        name : str
            Name of this module. With full prefix. For example, ``module1.block1.conv``.
        memo : dict
            Memo to enable sharing parameters among mutated modules. It should be read and written by
            mutate functions themselves.
        mutate_kwargs : dict
            Algo-related hyper-parameters, and some auxiliary information.

        Returns
        -------
        Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
            The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
            See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
        """
        raise NotImplementedError()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
import functools
import logging
import warnings
from typing import Any, Dict, Sequence, List, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.nas.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
_logger = logging.getLogger(__name__)
__all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy', 'GumbelSoftmax',
]
class GumbelSoftmax(nn.Softmax):
    """Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""

    dim: int

    def __init__(self, dim: int = -1) -> None:
        super().__init__(dim)
        # temperature and straight-through flag; both may be tweaked after construction
        self.tau = 1
        self.hard = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
class DifferentiableMixedLayer(BaseSuperNetModule):
    """
    Mixed layer whose forward pass is a weighted sum over all candidate layers.
    Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.

    The weight ``alpha`` is usually learnable, and optimized on validation dataset.

    All candidate operators must return the same shape for one input, because
    their outputs are weighted-summed into the final output.

    Parameters
    ----------
    paths : list[tuple[str, nn.Module]]
        Layers to choose from. Each is a tuple of name, and its module.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    op_names : str
        Operator names.
    label : str
        Name of the choice.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self,
                 paths: list[tuple[str, nn.Module]],
                 alpha: torch.Tensor,
                 softmax: nn.Module,
                 label: str):
        super().__init__()
        self.op_names = []
        if len(alpha) != len(paths):
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
        for op_name, op_module in paths:
            self.add_module(op_name, op_module)
            self.op_names.append(op_name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        self.label = label
        self._arch_alpha = alpha
        self._softmax = softmax

    def resample(self, memo):
        """Differentiable layers are never sampled; this is a no-op."""
        return {}

    def export(self, memo):
        """Choose the operator with the maximum logit."""
        if self.label in memo:
            # already exported by a module sharing the same label
            return {}
        best_index = int(torch.argmax(self._arch_alpha).item())
        return {self.label: self.op_names[best_index]}

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if not isinstance(module, LayerChoice):
            return None
        size = len(module)
        if module.label in memo:
            # share alpha among modules with the same label; sizes must agree
            alpha = memo[module.label]
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(list(module.named_children()), alpha, softmax, module.label)

    def reduction(self, items: list[Any], weights: list[float]) -> Any:
        """Override this for customized reduction."""
        # weighted_sum handles outputs that are not a single tensor (tuples, dicts, ...)
        return weighted_sum(items, weights)

    def forward(self, *args, **kwargs):
        """The forward of mixed layer accepts same arguments as its sub-layer."""
        candidate_outputs = [getattr(self, op_name)(*args, **kwargs) for op_name in self.op_names]
        return self.reduction(candidate_outputs, self._softmax(self._arch_alpha))

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters (pass ``arch=True`` to invert)."""
        arch = kwargs.pop('arch', False)
        for param_name, param in super().named_parameters(*args, **kwargs):
            # an exact name match marks an architecture parameter
            is_arch_param = param_name in self._arch_parameter_names
            if is_arch_param == arch:
                yield param_name, param
class DifferentiableMixedInput(BaseSuperNetModule):
    """
    Mixed input. Forward returns a weighted sum of candidates.
    Implementation is very similar to :class:`DifferentiableMixedLayer`.

    Parameters
    ----------
    n_candidates : int
        Expect number of input candidates.
    n_chosen : int
        Expect number of inputs finally chosen.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    label : str
        Name of the choice.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self,
                 n_candidates: int,
                 n_chosen: int | None,
                 alpha: torch.Tensor,
                 softmax: nn.Module,
                 label: str):
        super().__init__()
        self.n_candidates = n_candidates
        if len(alpha) != n_candidates:
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
        if n_chosen is None:
            warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
                          RuntimeWarning)
            # BUGFIX: previously ``self.n_chosen`` was unconditionally overwritten with the
            # incoming ``None`` after this branch, so the fallback of 1 never took effect
            # and ``export`` sliced with ``[:None]`` (i.e., kept all candidates).
            n_chosen = 1
        self.n_chosen = n_chosen
        self.label = label
        self._softmax = softmax
        self._arch_alpha = alpha

    def resample(self, memo):
        """Do nothing. Differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the operator with the top ``n_chosen`` logits."""
        if self.label in memo:
            return {}  # nothing new to export
        chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
        if len(chosen) == 1:
            # a single input is exported as a scalar, not a list
            chosen = chosen[0]
        return {self.label: chosen}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean']:
                raise ValueError('Only input choice of sum/mean reduction is supported.')
            size = module.n_candidates
            if module.label in memo:
                # share alpha among modules with the same label; sizes must agree
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)

    def reduction(self, items: list[Any], weights: list[float]) -> Any:
        """Override this for customized reduction."""
        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
        return weighted_sum(items, weights)

    def forward(self, inputs):
        """Forward takes a list of input candidates."""
        return self.reduction(inputs, self._softmax(self._arch_alpha))

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters (pass ``arch=True`` to invert)."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
    """Implements the differentiable sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
    filters out multiple parameters with ``_arch_alpha`` as its prefix.

    When this class is asked for ``forward_argument``, it returns a distribution,
    i.e., a dict from int to float based on its weights.

    All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
    saved as attributes of ``operation``, rather than ``self``,
    because this class itself is not a ``nn.Module``, and saved parameters here
    won't be optimized.
    """

    # attribute-name prefixes (on ``operation``) that mark architecture parameters
    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
        operation._arch_alpha = nn.ParameterDict()
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                # share the same alpha among modules carrying the same label
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            operation._arch_alpha[name] = alpha
        # Monkey-patch the operation's parameter iterators so architecture parameters
        # can be included/excluded via the ``arch`` keyword. ``partial`` binds the
        # operation itself as the ``module`` argument of the static methods below.
        operation.parameters = functools.partial(self.parameters, module=operation)  # bind self
        operation.named_parameters = functools.partial(self.named_parameters, module=operation)
        operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

    @staticmethod
    def parameters(module, *args, **kwargs):
        # bound to ``operation.parameters`` in __init__; drops the names
        for _, p in module.named_parameters(*args, **kwargs):
            yield p

    @staticmethod
    def named_parameters(module, *args, **kwargs):
        # ``arch=True`` yields only architecture parameters; default yields everything else.
        # The explicit two-argument super() is required because this runs outside the class
        # whose method is being patched.
        arch = kwargs.pop('arch', False)
        for name, p in super(module.__class__, module).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Differentiable. Do nothing in resample."""
        return {}

    def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
        """Export is argmax for each leaf value choice."""
        result = {}
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                # already exported under the same label
                continue
            chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
            result[name] = spec.values[chosen_index]
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
        # Mutable arguments produce a distribution {candidate: probability};
        # fixed arguments are returned as-is.
        if name in operation.mutable_arguments:
            weights: dict[str, torch.Tensor] = {
                label: cast(nn.Module, operation._softmax)(alpha) for label, alpha in cast(dict, operation._arch_alpha).items()
            }
            return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
        return operation.init_arguments[name]
class DifferentiableMixedRepeat(BaseSuperNetModule):
    """
    Implementation of Repeat in a differentiable supernet.
    Result is a weighted sum of possible prefixes, sliced by possible depths.

    If the output is not a single tensor, it will be summed at every independent dimension.
    See :func:`weighted_sum` for details.
    """

    _arch_parameter_names: list[str] = ['_arch_alpha']

    def __init__(self,
                 blocks: list[nn.Module],
                 depth: ChoiceOf[int],
                 softmax: nn.Module,
                 memo: dict[str, Any]):
        super().__init__()
        # NOTE(review): ``blocks`` is a plain list, so the blocks are NOT registered as
        # submodules here -- presumably they stay registered through their original owner;
        # verify before relying on ``self.parameters()`` to cover them.
        self.blocks = blocks
        self.depth = depth
        self._softmax = softmax
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
        self._arch_alpha = nn.ParameterDict()
        for name, spec in self._space_spec.items():
            if name in memo:
                # share alpha among modules carrying the same label; sizes must agree
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            self._arch_alpha[name] = alpha

    def resample(self, memo):
        """Do nothing."""
        return {}

    def export(self, memo):
        """Choose argmax for each leaf value choice."""
        result = {}
        for name, spec in self._space_spec.items():
            if name in memo:
                continue
            chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
            result[name] = spec.values[chosen_index]
        return result

    def search_space_spec(self):
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
            # Only interesting when depth is mutable
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters excluding architecture parameters (pass ``arch=True`` to invert)."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            # CONSISTENCY FIX: use this class's own ``_arch_parameter_names`` (identical content)
            # instead of reaching into ``MixedOpDifferentiablePolicy``, matching the sibling
            # differentiable modules. ``startswith`` is needed because ``_arch_alpha`` is a
            # ParameterDict whose entries are named ``_arch_alpha.<label>``.
            if any(name.startswith(par_name) for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
        """Override this for customized reduction. ``depths`` lists the depth of each item."""
        # Use weighted_sum to handle complex cases where sequential output is not a single tensor
        return weighted_sum(items, weights)

    def forward(self, x):
        weights: dict[str, torch.Tensor] = {
            label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
        }
        # distribution over all feasible depths
        depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
        res: list[torch.Tensor] = []
        weight_list: list[float] = []
        depths: list[int] = []
        for i, block in enumerate(self.blocks, start=1):  # start=1 because depths are 1, 2, 3, 4...
            x = block(x)
            if i in depth_weights:
                # this prefix is a depth candidate; remember its output and weight
                weight_list.append(depth_weights[i])
                res.append(x)
                depths.append(i)
        return self.reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context.
An architecture parameter is created on each edge of the full-connected graph.
"""
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
memo, mutate_kwargs, label
):
super().__init__(
op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor,
concat_dim, memo, mutate_kwargs, label
)
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
def resample(self, memo):
"""Differentiable doesn't need to resample."""
return {}
def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
))
all_weights.sort(reverse=True)
# We first prefer inputs from different input_index.
# If we have got no other choices, we start to accept duplicates.
# Therefore we gather first occurrences of distinct input_index to the front.
first_occurrence_index: list[int] = [
all_weights.index( # The index of
next(filter(lambda t: t[1] == j, all_weights)) # First occurence of j
)
for j in range(i) # For j < i
]
first_occurrence_index.sort() # Keep them ordered too.
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
return exported
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
    """Differentiable (soft) forward of the cell.

    Each edge ``(i, j)`` evaluates every candidate op on state ``j`` and blends the
    results with the softmax of that edge's architecture parameters; node ``i`` is
    the sum over all of its incoming edges. Intermediate node outputs are then
    concatenated along ``self.concat_dim`` and handed to the postprocessor.
    """
    processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
    states: list[torch.Tensor] = self.preprocessor(processed_inputs)
    for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
        current_state = []
        for j in range(i):  # for every previous tensors
            # Stack candidate outputs: shape (num_ops, *feature_shape).
            op_results = torch.stack([op(states[j]) for op in ops[j].values()])
            # Reshape alpha to (num_ops, 1, 1, ...) so it broadcasts over features.
            alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
            edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
            current_state.append(edge_sum)
        states.append(sum(current_state))  # type: ignore
    # Always merge all
    this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
    return self.postprocessor(this_cell, processed_inputs)
def parameters(self, *args, **kwargs):
    """Yield parameter values, delegating all filtering (including the ``arch``
    keyword) to :meth:`named_parameters` and dropping the names."""
    yield from (param for _, param in self.named_parameters(*args, **kwargs))
def named_parameters(self, *args, **kwargs):
    """Yield ``(name, parameter)`` pairs, partitioned by the ``arch`` keyword.

    With ``arch=False`` (default) only non-architecture parameters are yielded;
    with ``arch=True`` only architecture parameters (those whose name starts with
    one of ``MixedOpDifferentiablePolicy._arch_parameter_names``) are yielded.
    """
    want_arch = kwargs.pop('arch', False)
    prefixes = MixedOpDifferentiablePolicy._arch_parameter_names
    for name, param in super().named_parameters(*args, **kwargs):
        is_arch = any(name.startswith(prefix) for prefix in prefixes)
        # Yield exactly when the parameter's kind matches what was requested.
        if is_arch == want_arch:
            yield name, param
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""
from __future__ import annotations
import inspect
import itertools
import warnings
from typing import Any, Type, TypeVar, cast, Union, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX
from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
# Generic type variable used by helpers in this module (usage may be further
# down the file — not visible in this chunk).
T = TypeVar('T')

# Public API of this module.
__all__ = [
    'MixedOperationSamplingPolicy',
    'MixedOperation',
    'MixedLinear',
    'MixedConv2d',
    'MixedBatchNorm2d',
    'MixedLayerNorm',
    'MixedMultiHeadAttention',
    'NATIVE_MIXED_OPERATIONS',
]

# Error template raised when a differentiable one-shot strategy receives a
# ValueChoice on an argument that only path sampling supports.
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
class MixedOperationSamplingPolicy:
    """
    Algo-related part for mixed Operation.

    :class:`MixedOperation` delegates its resample and export to this policy (or its subclass),
    so that one Operation can be easily combined with different kinds of sampling.

    One SamplingStrategy corresponds to one mixed operation.

    All methods here are abstract; concrete algorithms subclass this and fill them in.
    """

    def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
        """At init, the sampling policy can prepare basic parameters,
        and store them in operation if they need back propagation.

        This init is called in :meth:`BaseSuperNetModule.mutate`, after the mixed operation is created.
        So similar to :meth:`BaseSuperNetModule.mutate`,
        memo should also be managed (read and written) by the policy itself.
        """
        pass

    def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
        """The handler of :meth:`MixedOperation.resample`.

        Should return the newly sampled architecture decisions (label -> value).
        """
        raise NotImplementedError()

    def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
        """The handler of :meth:`MixedOperation.export`.

        Should return the final chosen architecture decisions (label -> value).
        """
        raise NotImplementedError()

    def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
        """Computing the argument with ``name`` used in operation's forward.
        Usually a value, or a distribution of value.
        """
        raise NotImplementedError()
class MixedOperation(BaseSuperNetModule):
    """This is the base class for all mixed operations.

    It's what you should inherit to support a new operation with ValueChoice.

    It contains commonly used utilities that will ease the effort to write customized mixed operations,
    i.e., operations with ValueChoice in its arguments.
    To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy.

    By design, for a mixed operation to work in a specific algorithm,
    at least two classes are needed.

    1. One class needs to inherit this class, to control operation-related behavior,
       such as how to initialize the operation such that the sampled operation can be its sub-operation.
    2. The other one needs to inherit :class:`MixedOperationSamplingPolicy`,
       which controls algo-related behavior, such as sampling.

    The two classes are linked with ``sampling_policy`` attribute in :class:`MixedOperation`,
    whose type is set via ``mixed_op_sampling`` in ``mutate_kwargs`` when
    :meth:`MixedOperation.mutate` is called.

    With this design, one mixed-operation (e.g., MixedConv2d) can work in multiple algorithms
    (e.g., both DARTS and ENAS), saving the engineering effort to rewrite all operations for
    each specific algo.

    This class should also define a ``bound_type``, to control the matching type in mutate,
    an ``argument_list``, to control which arguments can be dynamically used in ``forward``.
    This list will also be used in mutate for sanity check.
    """

    bound_type: Type[nn.Module]                    # defined in subclass: the module type this op replaces
    argument_list: list[str]                       # defined in subclass: arguments that may carry ValueChoice
    sampling_policy: MixedOperationSamplingPolicy  # attached in mutate()

    def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
        """Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``.
        This is often related to specific operator, rather than algo.

        For example::

            def super_init_argument(self, name, value_choice):
                return max(value_choice.candidates)
        """
        raise NotImplementedError()

    def __post_init__(self) -> None:
        """Can be used to validate, or to do extra processing after calling ``__init__``."""
        pass

    def forward_with_args(self, *args, **kwargs):
        """To control real fprop. The accepted arguments are ``argument_list``,
        appended by forward arguments in the ``bound_type``."""
        raise NotImplementedError()

    def __init__(self, module_kwargs: dict[str, Any]) -> None:
        # Concerned arguments
        self.mutable_arguments: dict[str, ValueChoiceX] = {}
        # Useful when retrieving arguments without ValueChoice
        self.init_arguments: dict[str, Any] = {**module_kwargs}
        self._fill_missing_init_arguments()

        # get init default
        # NOTE(review): plain-dict attributes are assigned before super().__init__();
        # this appears intentional so super_init_argument can read them — confirm
        # BaseSuperNetModule/nn.Module tolerates pre-init attribute assignment.
        super_init_kwargs = {}
        for key, value in module_kwargs.items():
            if isinstance(value, ValueChoiceX):
                if key not in self.argument_list:
                    raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
                # Super-kernel gets the subclass-computed "largest" form of the argument.
                super_init_kwargs[key] = self.super_init_argument(key, value)
                self.mutable_arguments[key] = value
            else:
                super_init_kwargs[key] = value

        # get all inner leaf value choices
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))

        super().__init__(**super_init_kwargs)
        self.__post_init__()

    def resample(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
        return self.sampling_policy.resample(self, memo)

    def export(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
        return self.sampling_policy.export(self, memo)

    def search_space_spec(self):
        """Return the deduplicated spec of all inner leaf value choices."""
        return self._space_spec

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Find value choice in module's arguments and replace the whole module.

        Returns the replacement :class:`MixedOperation` when a ValueChoice is found;
        implicitly returns ``None`` (no mutation) otherwise.
        """
        has_valuechoice = False
        if isinstance(module, cls.bound_type) and is_traceable(module):
            for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
                if isinstance(arg, ValueChoiceX):
                    has_valuechoice = True

        if has_valuechoice:
            if module.trace_args:
                raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
                                 'Please enable ``kw_only`` on nni.trace.')
            # save type and kwargs
            mixed_op = cls(cast(dict, module.trace_kwargs))

            if 'mixed_op_sampling' not in mutate_kwargs:
                raise ValueError("Need a sampling policy for mixed op, but it's not found in `mutate_kwargs`.")
            policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs['mixed_op_sampling']
            # initialize policy class
            # this is put in mutate because we need to access memo
            mixed_op.sampling_policy = policy_cls(mixed_op, memo, mutate_kwargs)

            return mixed_op

    def forward_argument(self, name: str) -> Any:
        """Get the argument used in forward.
        This if often related to algo. We redirect this to sampling policy.
        """
        return self.sampling_policy.forward_argument(self, name)

    def forward(self, *args, **kwargs):
        """First get sampled arguments, then forward with the sampled arguments (by calling ``forward_with_args``)."""
        sampled_args = [self.forward_argument(name) for name in self.argument_list]
        return self.forward_with_args(*sampled_args, *args, **kwargs)

    def _fill_missing_init_arguments(self) -> None:
        """Set the unspecified init arguments in ``self.init_arguments``.

        For example, in the case of Conv2d, when user didn't specify argument ``stride``,
        this method adds ``stride = 1`` in ``self.init_arguments``.

        This is implemented by inspecting the init signature of ``bound_type``.
        Arguments in complex cases like ``__new__`` or in super-class is not supported.
        """
        def unwrap(cls):
            # Peel off decorator wrappers (e.g. nni.trace) to reach the real __init__.
            if not hasattr(cls, '__wrapped__'):
                return cls
            return unwrap(cls.__wrapped__)

        for param in inspect.signature(unwrap(self.bound_type).__init__).parameters.values():
            if param.default is not param.empty and param.name not in self.init_arguments:
                self.init_arguments[param.name] = param.default
class MixedLinear(MixedOperation, nn.Linear):
    """Mixed linear operation.

    Supported arguments are:

    - ``in_features``
    - ``out_features``

    Prefix of weight and bias will be sliced.
    """

    bound_type = nas_nn.Linear
    argument_list = ['in_features', 'out_features']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        # The super-kernel is always built with the largest candidate.
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_features: int_or_int_dict,
                          out_features: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:
        # Take the leading slice of the super-kernel on both dimensions,
        # then run a plain linear with the (possibly weighted) sub-weight.
        out_features_ = _W(out_features)
        sliced_weight = _S(_S(self.weight)[:out_features_])[:, :_W(in_features)]
        sliced_bias = None if self.bias is None else _S(self.bias)[:out_features_]
        return F.linear(inputs, sliced_weight, sliced_bias)
# Conv2d-style spatial argument: a scalar applied to both dimensions, or an explicit (H, W) pair.
_int_or_tuple = Union[int, Tuple[int, int]]
class MixedConv2d(MixedOperation, nn.Conv2d):
    """Mixed conv2d op.

    Supported arguments are:

    - ``in_channels``
    - ``out_channels``
    - ``groups``
    - ``stride`` (only supported in path sampling)
    - ``kernel_size``
    - ``padding``
    - ``dilation`` (only supported in path sampling)

    ``padding`` will be the "max" padding in differentiable mode.

    Mutable ``groups`` is NOT supported in most cases of differentiable mode.
    However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
    This is often the case of depth-wise convolutions.

    For channels, prefix will be sliced.
    For kernels, we take the small kernel from the center and round it to floor (left top). For example ::

        max_kernel = 5*5, sampled_kernel = 3*3, then we take [1: 4]
        max_kernel = 5*5, sampled_kernel = 2*2, then we take [1: 3]
        □ □ □ □ □   □ □ □ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ ■ ■ □ □
        □ ■ ■ ■ □   □ □ □ □ □
        □ □ □ □ □   □ □ □ □ □
    """

    bound_type = nas_nn.Conv2d
    argument_list = [
        'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
    ]

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        # Normalize a scalar into a pair; tuples pass through unchanged.
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        """Compute the argument used to build the super-kernel.

        For spatial arguments the per-dimension maximum is taken; for ``groups``
        a special constant-ratio (depth-wise) case is handled.
        """
        if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')

        # BUGFIX: this previously read ``if name == ['kernel_size', 'padding']:``,
        # which compares a string against a list and is always False. Tuple-valued
        # kernel sizes / paddings therefore fell through to the final ``max`` over
        # tuples (lexicographic), instead of the intended per-dimension maximum.
        if name in ['kernel_size', 'padding']:
            all_sizes = set(traverse_all_options(value_choice))
            if any(isinstance(sz, tuple) for sz in all_sizes):
                # maximum kernel should be calculated on every dimension
                return (
                    max(self._to_tuple(sz)[0] for sz in all_sizes),
                    max(self._to_tuple(sz)[1] for sz in all_sizes)
                )
            else:
                return max(all_sizes)

        elif name == 'groups':
            if 'in_channels' in self.mutable_arguments:
                # If the ratio is constant, we don't need to try the maximum groups.
                try:
                    constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
                    # NOTE(review): derives super-kernel groups from the max groups
                    # candidate and the constant in_channels/groups ratio — verify
                    # consistency with the in_channels super argument.
                    return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
                except ValueError:
                    warnings.warn(
                        'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
                        'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
                        RuntimeWarning
                    )

            # minimum groups, maximum kernel
            return min(traverse_all_options(value_choice))

        else:
            return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          in_channels: int_or_int_dict,
                          out_channels: int_or_int_dict,
                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
                          stride: _int_or_tuple,
                          padding: scalar_or_scalar_dict[_int_or_tuple],
                          dilation: int,
                          groups: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:
        """Convolve with the sampled (or weight-dict) arguments, slicing the super-kernel.

        Raises ``ValueError`` when stride/dilation carry weight dicts (differentiable
        mode does not support them), or when a mutable ``groups`` is used without the
        constant-ratio precondition.
        """
        if any(isinstance(arg, dict) for arg in [stride, dilation]):
            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))

        in_channels_ = _W(in_channels)
        out_channels_ = _W(out_channels)

        # slice prefix
        # For groups > 1, we use groups to slice input weights
        weight = _S(self.weight)[:out_channels_]
        if not isinstance(groups, dict):
            weight = _S(weight)[:, :in_channels_ // groups]
        else:
            # Differentiable mode with mutable groups: only the constant
            # in_channels/groups ratio (depth-wise style) case is supported.
            assert 'groups' in self.mutable_arguments
            err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
                'in_channels and out_channels should also be a ValueChoice. ' \
                'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
                'should be constants.'
            if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
                raise ValueError(err_message)
            try:
                in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
            except ValueError:
                raise ValueError(err_message)
            if in_channels_per_group != int(in_channels_per_group):
                raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')

            if inputs.size(1) % in_channels_per_group != 0:
                raise RuntimeError(
                    f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
                    f'while in_channels_per_group = {in_channels_per_group}'
                )

            # Compute sliced weights and groups (as an integer)
            weight = _S(weight)[:, :int(in_channels_per_group)]
            groups = inputs.size(1) // int(in_channels_per_group)

        # slice center
        if isinstance(kernel_size, dict):
            # If kernel size is a dict, ignore choices in padding.
            if isinstance(self.padding, str):
                raise ValueError(f'Use "{self.padding}" in padding is not supported.')
            padding = self.padding  # max padding, must be a tuple

        # Center crop of the super-kernel (floor-rounded toward top-left).
        kernel_a, kernel_b = self._to_tuple(kernel_size)
        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]

        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None

        # The rest parameters only need to be converted to tuple
        stride_ = self._to_tuple(stride)
        dilation_ = self._to_tuple(dilation)

        if self.padding_mode != 'zeros':
            # Non-zero padding modes pad explicitly, then convolve with zero padding.
            return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
                            weight, bias, stride_, (0, 0), dilation_, groups)
        return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)
class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
    """
    Mixed BatchNorm2d operation.

    Supported arguments are:

    - ``num_features``
    - ``eps`` (only supported in path sampling)
    - ``momentum`` (only supported in path sampling)

    For path-sampling, prefix of ``weight``, ``bias``, ``running_mean`` and ``running_var``
    are sliced. For weighted cases, the maximum ``num_features`` is used directly.

    Momentum is required to be float.
    PyTorch BatchNorm supports a case where momentum can be none, which is not supported here.
    """

    bound_type = nas_nn.BatchNorm2d
    argument_list = ['num_features', 'eps', 'momentum']

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        # Super-kernel is built with the largest candidate.
        return max(traverse_all_options(value_choice))

    def forward_with_args(self,
                          num_features: int_or_int_dict,
                          eps: float,
                          momentum: float,
                          inputs: torch.Tensor) -> torch.Tensor:
        # eps / momentum only make sense as concrete scalars (path sampling).
        if isinstance(eps, dict) or isinstance(momentum, dict):
            raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))

        if isinstance(num_features, dict):
            # Weighted (differentiable) case: fall back to the full feature dimension.
            num_features = self.num_features

        weight, bias = self.weight, self.bias
        running_mean, running_var = self.running_mean, self.running_var

        if num_features < self.num_features:
            # Path-sampling case: take the leading slice of every buffer/parameter.
            weight, bias = weight[:num_features], bias[:num_features]
            running_mean = None if running_mean is None else running_mean[:num_features]
            running_var = None if running_var is None else running_var[:num_features]

        # Train-mode normalization also applies in eval when no stats are tracked.
        bn_training = self.training or (running_mean is None and running_var is None)

        # If buffers are not to be tracked, ensure that they won't be updated.
        use_stats = (not self.training) or self.track_running_stats
        return F.batch_norm(
            inputs,
            running_mean if use_stats else None,
            running_var if use_stats else None,
            weight,
            bias,
            bn_training,
            momentum,  # originally exponential_average_factor in pytorch code
            eps,
        )
class MixedLayerNorm(MixedOperation, nn.LayerNorm):
    """
    Mixed LayerNorm operation.

    Supported arguments are:

    - ``normalized_shape``
    - ``eps`` (only supported in path sampling)

    For path-sampling, prefix of ``weight`` and ``bias`` are sliced.
    For weighted cases, the maximum ``normalized_shape`` is used directly.

    eps is required to be float.
    """

    bound_type = nas_nn.LayerNorm
    argument_list = ['normalized_shape', 'eps']

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        # Normalize a scalar into a pair; tuples pass through unchanged.
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        """Per-dimension maximum of all ``normalized_shape`` candidates."""
        if name not in ['normalized_shape']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')
        all_sizes = set(traverse_all_options(value_choice))
        if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
            # transpose: one entry per dimension, holding that dimension's candidates
            all_sizes = list(zip(*all_sizes))
            # FIX: materialize as a tuple. This previously returned a generator
            # expression, which is single-use and not a sequence.
            return tuple(max(self._to_tuple(sz)) for sz in all_sizes)
        else:
            return max(all_sizes)

    def forward_with_args(self,
                          normalized_shape,
                          eps: float,
                          inputs: torch.Tensor) -> torch.Tensor:
        """Layer-normalize over the sampled ``normalized_shape``, slicing weight/bias prefixes.

        Raises ``ValueError`` when ``eps`` carries a weight dict (differentiable mode
        does not support it).
        """
        if any(isinstance(arg, dict) for arg in [eps]):
            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))

        if isinstance(normalized_shape, dict):
            # Weighted (differentiable) case: fall back to the full shape.
            normalized_shape = self.normalized_shape

        # make it as tuple
        if isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape, )
        # FIX: the original overwrote the *sampled* shape with self.normalized_shape
        # here; use a separate variable for the super-kernel shape instead.
        # (nn.LayerNorm stores a tuple, so the int branch is defensive only.)
        max_shape = (self.normalized_shape, ) if isinstance(self.normalized_shape, int) else self.normalized_shape

        # slice all the normalized shape
        # FIX: index with a tuple of slices — indexing a tensor with a *list* of
        # slices is treated as advanced indexing and rejected by PyTorch.
        indices = tuple(slice(0, min(i, j)) for i, j in zip(normalized_shape, max_shape))
        weight = self.weight[indices] if self.weight is not None else None
        bias = self.bias[indices] if self.bias is not None else None

        return F.layer_norm(
            inputs,
            normalized_shape,
            weight,
            bias,
            eps
        )
class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
    """
    Mixed multi-head attention.

    Supported arguments are:

    - ``embed_dim``
    - ``num_heads`` (only supported in path sampling)
    - ``kdim``
    - ``vdim``
    - ``dropout`` (only supported in path sampling)

    At init, it constructs the largest possible Q, K, V dimension.
    At forward, it slices the prefix to weight matrices according to the sampled value.
    For ``in_proj_bias`` and ``in_proj_weight``, three parts will be sliced and concatenated together:
    ``[0, embed_dim)``, ``[max_embed_dim, max_embed_dim + embed_dim)``,
    ``[max_embed_dim * 2, max_embed_dim * 2 + embed_dim)``.

    Warnings
    ----------

    All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
    """

    bound_type = nas_nn.MultiheadAttention
    argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']

    def __post_init__(self):
        # sometimes super-class believes qkv have the same embed_dim.
        # but actually they do not, because we can have dynamic (mutable) kdim/vdim.
        _qkv_same_embed_dim = True
        for dimension in ['kdim', 'vdim']:
            if self.init_arguments[dimension] is None:
                # must follow embed_dim is this case
                continue
            # A dim equal to embed_dim still counts as "different" if either side is mutable,
            # because the sampled values may diverge at forward time.
            if getattr(self, dimension) == self.embed_dim and \
                    (dimension in self.mutable_arguments or 'embed_dim' in self.mutable_arguments):
                _qkv_same_embed_dim = False

        if self._qkv_same_embed_dim and not _qkv_same_embed_dim:
            # Flip to separate q/k/v projections and create the parameters the
            # super-class skipped because it assumed a fused in_proj_weight.
            self._qkv_same_embed_dim = _qkv_same_embed_dim

            # adding back missing parameters
            # factory_kwargs could be empty for legacy pytorch versions
            factory_kwargs = {}
            if 'device' in self.init_arguments:
                factory_kwargs['device'] = self.init_arguments['device']
            if 'dtype' in self.init_arguments:
                factory_kwargs['dtype'] = self.init_arguments['dtype']
            self.q_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.embed_dim), **factory_kwargs))
            self.k_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.kdim), **factory_kwargs))
            self.v_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.vdim), **factory_kwargs))
            self.register_parameter('in_proj_weight', None)

            # reset parameters
            nn.init.xavier_uniform_(self.q_proj_weight)
            nn.init.xavier_uniform_(self.k_proj_weight)
            nn.init.xavier_uniform_(self.v_proj_weight)

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        # Super-kernel holds the largest candidate of each dimension.
        return max(traverse_all_options(value_choice))

    def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
        # slice three parts, corresponding to q, k, v respectively
        return [
            slice(embed_dim),
            slice(self.embed_dim, self.embed_dim + embed_dim),
            slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
        ]

    def forward_with_args(
        self,
        embed_dim: int_or_int_dict, num_heads: int,
        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
        dropout: float,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = True, attn_mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """Attention forward with sliced projection weights.

        Raises ``ValueError`` when ``num_heads`` or ``dropout`` carry weight dicts
        (differentiable mode does not support them).
        """
        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))

        # by default, kdim, vdim can be none
        if kdim is None:
            kdim = embed_dim
        if vdim is None:
            vdim = embed_dim

        qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

        if getattr(self, 'batch_first', False):
            # for backward compatibility: v1.7 doesn't have batch_first
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if isinstance(embed_dim, dict):
            # Weighted (differentiable) case: run at the full embedding dimension.
            used_embed_dim = self.embed_dim
        else:
            used_embed_dim = embed_dim

        embed_dim_ = _W(embed_dim)

        # in projection weights & biases has q, k, v weights concatenated together
        in_proj_bias: Tensor | None = None
        in_proj_weight: Tensor | None = None
        if self.in_proj_bias is not None:
            in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
        if self.in_proj_weight is not None:
            in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]

        bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
        bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
        out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
        out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None

        if not qkv_same_embed_dim:
            # Separate q/k/v projections: slice each one on its own dimensions.
            q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
            k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
            k_proj = _S(k_proj)[:, :_W(kdim)]
            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
            v_proj = _S(v_proj)[:, :_W(vdim)]

            # The rest part is basically same as pytorch
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, used_embed_dim, num_heads,
                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                bias_k, bias_v, self.add_zero_attn,
                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask, use_separate_proj_weight=True,
                q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
        else:
            # Cast tensor here because of a bug in pytorch stub
            attn_output, attn_output_weights = F.multi_head_attention_forward(
                query, key, value, used_embed_dim, num_heads,
                cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
                bias_k, bias_v, self.add_zero_attn,
                dropout, out_proj_weight, cast(Tensor, out_proj_bias),
                training=self.training,
                key_padding_mask=key_padding_mask, need_weights=need_weights,
                attn_mask=attn_mask)

        if getattr(self, 'batch_first', False):  # backward compatibility
            return attn_output.transpose(1, 0), attn_output_weights
        else:
            return attn_output, attn_output_weights
# Mixed-operation replacements shipped with NNI; each is matched against its
# ``bound_type`` when mutating a supernet.
NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
    MixedLinear,
    MixedConv2d,
    MixedBatchNorm2d,
    MixedLayerNorm,
    MixedMultiHeadAttention,
]

# For the supported operations to be properly rendered in documentation
NATIVE_SUPPORTED_OP_NAMES: list[str] = [op.bound_type.__name__ for op in NATIVE_MIXED_OPERATIONS]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Implementation of ProxylessNAS: a hyrbid approach between differentiable and sampling.
The support remains limited. Known limitations include:
- No support for multiple arguments in forward.
- No support for mixed-operation (value choice).
- The code contains duplicates. Needs refactor.
"""
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
class _ArchGradientFunction(torch.autograd.Function):
    """Autograd function used by ProxylessNAS modules.

    Forward runs ``run_func`` on a detached input so the sub-graph is recorded
    independently; backward combines the input gradient (via ``torch.autograd.grad``)
    with custom gradients for ``binary_gates`` computed by ``backward_func``.
    """

    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        # Detach the input but keep its requires_grad flag, so the op's sub-graph
        # is built under enable_grad and can be differentiated in backward().
        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        # Return raw data so the outer graph only sees this Function node.
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        # Gradient w.r.t. the input, through the recorded sub-graph.
        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        # run_func / backward_func receive no gradients (None, None).
        return grad_x[0], binary_grads, None, None
class ProxylessMixedLayer(DifferentiableMixedLayer):
    """Proxyless version of differentiable mixed layer.

    It resamples a single-path every time, rather than go through the softmax.
    Architecture gradients flow to ``_arch_alpha`` indirectly: the forward pass
    routes through ``_binary_gates`` (a one-hot parameter), whose gradient is
    translated into an alpha gradient by :meth:`finalize_grad`.
    """

    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
        # One gate per candidate path; set to one-hot in resample().
        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)

        # like sampling-based methods, it has a ``_sampled``.
        self._sampled: str | None = None
        self._sample_idx: int | None = None

    def forward(self, *args, **kwargs):
        """Run only the sampled path, wrapped in :class:`_ArchGradientFunction`."""
        def run_function(ops, active_id, **kwargs):
            # Forward of the active op only.
            def forward(_x):
                return ops[active_id](_x, **kwargs)
            return forward

        def backward_function(ops, active_id, binary_gates, **kwargs):
            # Estimate d(loss)/d(gate_k) = <output_k, grad_output> for every path;
            # inactive paths are re-run under no_grad just for this estimate.
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(ops)):
                        if k != active_id:
                            out_k = ops[k](_x.data, **kwargs)
                        else:
                            out_k = _output.data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
        x = args[0]

        assert self._sampled is not None, 'Need to call resample() before running fprop.'
        list_ops = [getattr(self, op) for op in self.op_names]

        return _ArchGradientFunction.apply(
            x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
            self._sample_idx = self.op_names.index(self._sampled)
        else:
            probs = self._softmax(self._arch_alpha)
            self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
            self._sampled = self.op_names[self._sample_idx]

        # set binary gates: one-hot on the sampled path, with a zeroed grad buffer
        # so finalize_grad() can accumulate into it.
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[self._sample_idx] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

    def finalize_grad(self):
        """Translate gate gradients into alpha gradients.

        Implements the ProxylessNAS estimator:
        d(loss)/d(alpha_i) += sum_j d(loss)/d(gate_j) * p_j * (delta_ij - p_i).
        """
        binary_grads = self._binary_gates.grad
        assert binary_grads is not None
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(len(self._arch_alpha)):
                for j in range(len(self._arch_alpha)):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessMixedInput(DifferentiableMixedInput):
    """Proxyless version of differentiable input choice.

    See :class:`ProxylessLayerChoice` for implementation details.
    """

    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
        # One gate per candidate input; set to one-hot in resample().
        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
        # Index of the currently sampled candidate; populated by resample().
        self._sampled: int | None = None

    def forward(self, inputs):
        """Select only the sampled input, wrapped in :class:`_ArchGradientFunction`."""
        def run_function(active_sample):
            # Picking an input is just indexing into the stacked candidates.
            return lambda x: x[active_sample]

        def backward_function(binary_gates):
            # Gate gradient estimate: <candidate_k, grad_output> for every input.
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(self.n_candidates):
                        out_k = _x[k].data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        inputs = torch.stack(inputs, 0)
        assert self._sampled is not None, 'Need to call resample() before running fprop.'

        return _ArchGradientFunction.apply(
            inputs, self._binary_gates, run_function(self._sampled),
            backward_function(self._binary_gates)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            probs = self._softmax(self._arch_alpha)
            sample = torch.multinomial(probs, 1)[0].item()
            self._sampled = int(sample)

        # set binary gates: one-hot on the sampled input, with a zeroed grad buffer
        # so finalize_grad() can accumulate into it.
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[cast(int, self._sampled)] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: torch.argmax(self._arch_alpha).item()}

    def finalize_grad(self):
        """Translate gate gradients into alpha gradients (ProxylessNAS estimator):
        d(loss)/d(alpha_i) += sum_j d(loss)/d(gate_j) * p_j * (delta_ij - p_i)."""
        binary_grads = self._binary_gates.grad
        assert binary_grads is not None
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(self.n_candidates):
                for j in range(self.n_candidates):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment