"docs/source/reference/vscode:/vscode.git/clone" did not exist on "bbf54a8835811f96bd1e4dc4c2669f94be0bf264"
Unverified Commit c80bda29 authored by Yuge Zhang, committed by GitHub

Support Repeat and Cell in One-shot NAS (#4835)

parent c54a07df
......@@ -247,7 +247,8 @@ class _SupervisedLearningModule(LightningModule):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
if self.running_mode == 'multi':
if not self.trainer.sanity_checking and self.running_mode == 'multi':
# Don't report metric when sanity checking
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
......
......@@ -30,9 +30,51 @@ class _DefaultPostprocessor(nn.Module):
return this_cell
_cell_op_factory_type = Callable[[int, int, Optional[int]], nn.Module]
CellOpFactory = Callable[[int, int, Optional[int]], nn.Module]
def create_cell_op_candidates(
op_candidates, node_index, op_index, chosen
) -> Tuple[Dict[str, nn.Module], bool]:
has_factory = False
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
nonlocal has_factory
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# Yes! It's using a factory to create operations now.
has_factory = True
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
res = {str(i): convert_single_op(op) for i, op in enumerate(op_candidates)}
elif isinstance(op_candidates, dict):
res = {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
res = op_candidates()
has_factory = True
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
return res, has_factory
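# Illustrative usage sketch of the helper above; candidate names and sizes are made up.
# Factories receive (node_index, op_index, chosen_input_index), and `has_factory` flips
# to True as soon as any candidate is a callable rather than a module.
import torch.nn as nn

example_candidates = {
    'linear': lambda node, op, inp: nn.Linear(16, 16),
    'identity': lambda node, op, inp: nn.Identity(),
}
example_ops, example_has_factory = create_cell_op_candidates(
    example_candidates, node_index=2, op_index=0, chosen=None)
assert example_has_factory and set(example_ops) == {'linear', 'identity'}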
def preprocess_cell_inputs(num_predecessors: int, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
if len(inputs) == 1 and isinstance(inputs[0], list):
processed_inputs = list(inputs[0]) # shallow copy
else:
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
return processed_inputs
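# Illustrative sketch of the two call conventions accepted above: a single list of
# tensors, or the same tensors as separate positional arguments (shapes are arbitrary).
import torch

_x, _y = torch.randn(2, 16), torch.randn(2, 16)
assert preprocess_cell_inputs(2, [_x, _y])[0] is _x
assert preprocess_cell_inputs(2, _x, _y)[1] is _y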
class Cell(nn.Module):
"""
Cell structure that is popularly used in NAS literature.
......@@ -108,6 +150,9 @@ class Cell(nn.Module):
The indices are enumerated over all nodes, including predecessors, starting from 0.
When first created, the input index is ``None``, meaning unknown.
Note that in the graph execution engine, support for functions in ``op_candidates`` is limited.
Please also note that, to make :class:`Cell` work with one-shot strategy,
``op_candidates``, in case it's a callable, should not depend on its second input argument,
i.e., the ``op_index`` of the current node.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node: int
......@@ -191,15 +236,19 @@ class Cell(nn.Module):
When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
op_candidates_factory : CellOpFactory or None
If the operations are created with a factory (callable), this field is set to that factory.
One-shot algorithms will use this to make each node a Cartesian product of operations and inputs.
"""
def __init__(self,
op_candidates: Union[
Callable[[], List[nn.Module]],
List[nn.Module],
List[_cell_op_factory_type],
List[CellOpFactory],
Dict[str, nn.Module],
Dict[str, _cell_op_factory_type]
Dict[str, CellOpFactory]
],
num_nodes: int,
num_ops_per_node: int = 1,
......@@ -232,6 +281,8 @@ class Cell(nn.Module):
self.concat_dim = concat_dim
self.op_candidates_factory: Union[List[CellOpFactory], Dict[str, CellOpFactory], None] = None # set later
# fill-in the missing modules
self._create_modules(op_candidates)
......@@ -253,7 +304,9 @@ class Cell(nn.Module):
# this is needed because op_candidates can be very complex
# see the type annotation and docs for details
ops = self._convert_op_candidates(op_candidates, i, k, chosen)
ops, has_factory = create_cell_op_candidates(op_candidates, i, k, chosen)
if has_factory:
self.op_candidates_factory = op_candidates
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
......@@ -279,12 +332,7 @@ class Cell(nn.Module):
By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
of some (possibly all) of the nodes' outputs in the cell.
"""
processed_inputs: List[torch.Tensor]
if len(inputs) == 1 and isinstance(inputs[0], list):
processed_inputs = list(inputs[0]) # shallow copy
else:
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for ops, inps in zip(
cast(Sequence[Sequence[LayerChoice]], self.ops),
......@@ -301,26 +349,3 @@ class Cell(nn.Module):
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
return [convert_single_op(op) for op in op_candidates]
elif isinstance(op_candidates, dict):
return {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
return op_candidates()
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
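# Migration sketch for the deprecation warning above: instead of passing a bare callable
# into Cell, pass a list/dict of modules or of per-op factories. Sizes here are made up;
# the import mirrors the test models later in this change.
import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn

cell = retiarii_nn.Cell(
    {
        'linear': nn.Linear(16, 16),
        'linear_no_bias': nn.Linear(16, 16, bias=False),
    },
    num_nodes=4,
)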
......@@ -106,7 +106,7 @@ class Repeat(Mutable):
'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.',
RuntimeWarning
)
self.depth_choice = depth
self.depth_choice: Union[int, ChoiceOf[int]] = depth
all_values = list(self.depth_choice.all_options())
self.min_depth = min(all_values)
self.max_depth = max(all_values)
......@@ -117,12 +117,12 @@ class Repeat(Mutable):
elif isinstance(depth, tuple):
self.min_depth = depth if isinstance(depth, int) else depth[0]
self.max_depth = depth if isinstance(depth, int) else depth[1]
self.depth_choice = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
self.depth_choice: Union[int, ChoiceOf[int]] = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
self._label = self.depth_choice.label
elif isinstance(depth, int):
self.min_depth = self.max_depth = depth
self.depth_choice = depth
self.depth_choice: Union[int, ChoiceOf[int]] = depth
else:
raise TypeError(f'Unsupported "depth" type: {type(depth)}')
assert self.max_depth >= self.min_depth >= 0 and self.max_depth >= 1, f'Depth of {self.min_depth} to {self.max_depth} is invalid.'
......
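# Illustrative sketch of the three accepted forms of `depth` above: a fixed int, a
# (min, max) tuple, or a value choice; only the latter two make `depth_choice` mutable.
# Sizes and labels are made up; the lambda block form mirrors the test model later in
# this change.
import torch.nn as nn
import nni.retiarii.nn.pytorch as retiarii_nn

fixed = retiarii_nn.Repeat(lambda index: nn.Linear(16, 16), 3)        # always 3 blocks
ranged = retiarii_nn.Repeat(lambda index: nn.Linear(16, 16), (1, 3))  # searchable depth 1..3
chosen = retiarii_nn.Repeat(lambda index: nn.Linear(16, 16),
                            retiarii_nn.ValueChoice([1, 2, 3], label='depth'))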
......@@ -59,7 +59,8 @@ def traverse_and_mutate_submodules(
module_list = []
def apply(m):
for name, child in m.named_children():
# Need to call list() here because the loop body might replace some children in-place.
for name, child in list(m.named_children()):
# post-order DFS
if not topdown:
apply(child)
......@@ -94,6 +95,8 @@ def traverse_and_mutate_submodules(
break
if isinstance(mutate_result, BaseSuperNetModule):
# Replace child with the mutate result, and DFS this one
child = mutate_result
module_list.append(mutate_result)
# pre-order DFS
......@@ -112,9 +115,9 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
primitive_list = (
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.ValueChoice,
nas_nn.Repeat,
nas_nn.NasBench101Cell,
# nas_nn.ValueChoice, # could be false positive
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
)
......@@ -321,9 +324,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
# under v1.5
w_optimizers, lr_schedulers, self.frequencies, monitor = \
self.trainer._configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)  # type: ignore
if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):  # type: ignore
raise Exception(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)
......
......@@ -12,7 +12,8 @@ import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
DifferentiableMixedLayer, DifferentiableMixedInput,
MixedOpDifferentiablePolicy, GumbelSoftmax
MixedOpDifferentiablePolicy, GumbelSoftmax,
DifferentiableMixedCell, DifferentiableMixedRepeat
)
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS
......@@ -52,6 +53,8 @@ class DartsLightningModule(BaseOneShotLightningModule):
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
DifferentiableMixedCell.mutate,
DifferentiableMixedRepeat.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
......@@ -182,16 +185,6 @@ class GumbelDartsLightningModule(DartsLightningModule):
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with gumbel-differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use gumbel softmax."""
return {
......
......@@ -12,7 +12,10 @@ import torch.nn as nn
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.sampling import PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy
from .supermodule.sampling import (
PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
PathSamplingCell, PathSamplingRepeat
)
from .supermodule.operation import NATIVE_MIXED_OPERATIONS
from .enas import ReinforceController, ReinforceField
......@@ -43,6 +46,8 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
hooks = [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
PathSamplingRepeat.mutate,
PathSamplingCell.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
......
......@@ -4,20 +4,26 @@
from __future__ import annotations
import functools
import logging
import warnings
from typing import Any, cast
from typing import Any, Dict, Sequence, List, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from ._valuechoice_utils import traverse_all_options
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
_logger = logging.getLogger(__name__)
class GumbelSoftmax(nn.Softmax):
......@@ -284,10 +290,10 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
return {}
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice."""
"""Export is argmax for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():
if name in result:
if name in memo:
continue
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index]
......@@ -300,3 +306,199 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
}
return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
return operation.init_arguments[name]
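# For intuition, a tiny sketch of the argmax export above over a single architecture
# parameter (values are illustrative):
import torch

_alpha = torch.tensor([0.1, 2.0, -0.5])        # one entry of _arch_alpha
_values = ['conv3x3', 'conv5x5', 'maxpool']    # spec.values for that label
assert _values[int(torch.argmax(_alpha).item())] == 'conv5x5'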
class DifferentiableMixedRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a differentiable supernet.
Result is a weighted sum of possible prefixes, sliced by possible depths.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._softmax = softmax
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._arch_alpha = nn.ParameterDict()
for name, spec in self._space_spec.items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
self._arch_alpha[name] = alpha
def resample(self, memo):
"""Do nothing."""
return {}
def export(self, memo):
"""Choose argmax for each leaf value choice."""
result = {}
for name, spec in self._space_spec.items():
if name in memo:
continue
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
result[name] = spec.values[chosen_index]
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def forward(self, x):
weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
}
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
res: torch.Tensor | None = None
for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4...
x = block(x)
if i in depth_weights:
if res is None:
res = depth_weights[i] * x
else:
res = res + depth_weights[i] * x
return res
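# Illustrative usage sketch, condensed from the unit tests later in this change:
# three blocks, depth drawn from {1, 2, 3}; the output is a softmax-weighted sum of
# the states reached after 1, 2 and 3 blocks. Sizes and labels are made up.
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice

mixed_repeat = DifferentiableMixedRepeat(
    [nn.Linear(16, 16) for _ in range(3)],
    ValueChoice([1, 2, 3], label='depth'),
    nn.Softmax(-1),
    {},
)
assert mixed_repeat(torch.randn(2, 16)).shape == (2, 16)
assert mixed_repeat.export({})['depth'] in [1, 2, 3]
# The `arch` switch in named_parameters() above separates architecture parameters from
# block weights, so DARTS-style training can give each group its own optimizer.
assert len(list(mixed_repeat.parameters(arch=True))) == 1   # just _arch_alpha['depth']
assert len(list(mixed_repeat.parameters())) == 6            # 3 Linear layers x (weight, bias)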
class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context.
An architecture parameter is created on each edge of the fully-connected graph.
"""
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
memo, mutate_kwargs, label
):
super().__init__(
op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor,
concat_dim, memo, mutate_kwargs, label
)
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
def resample(self, memo):
"""Differentiable doesn't need to resample."""
return {}
def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
We don't avoid selecting operations like ``none`` here, because excluding them would effectively define a different search space.
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
))
all_weights.sort(reverse=True)
# We first prefer inputs from different input_index.
# If we have got no other choices, we start to accept duplicates.
# Therefore we gather first occurrences of distinct input_index to the front.
first_occurrence_index: list[int] = [
all_weights.index( # The index of
next(filter(lambda t: t[1] == j, all_weights))  # First occurrence of j
)
for j in range(i) # For j < i
]
first_occurrence_index.sort() # Keep them ordered too.
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
return exported
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: list[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for j in range(i): # for every previous tensors
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
current_state.append(edge_sum)
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
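# For intuition, what a single edge (i, j) computes in the loop above: every candidate
# op is applied to state j and the results are blended by the softmax of that edge's
# alpha. Shapes and names below are illustrative.
import torch
import torch.nn as nn

_edge_ops = {'linear': nn.Linear(16, 16), 'identity': nn.Identity()}
_alpha = torch.randn(len(_edge_ops))
_state_j = torch.randn(2, 16)
_stacked = torch.stack([op(_state_j) for op in _edge_ops.values()])       # (num_ops, 2, 16)
_edge_out = (_stacked * torch.softmax(_alpha, -1).view(-1, 1, 1)).sum(0)  # (2, 16)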
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
......@@ -3,17 +3,20 @@
from __future__ import annotations
import copy
import random
from typing import Any
from typing import Any, List, Dict, Sequence, cast
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices
from .operation import MixedOperationSamplingPolicy, MixedOperation
......@@ -198,3 +201,200 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
if name in operation.mutable_arguments:
return self._sampled[name]
return operation.init_arguments[name]
class PathSamplingRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a path-sampling supernet.
Samples one / some of the prefixes of the repeated blocks.
Attributes
----------
_sampled : int or list of int
Sampled depth.
"""
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._sampled: list[int] | int | None = None
def resample(self, memo):
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
result = {}
for label in self._space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(self._space_spec[label].values)
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
return result
def export(self, memo):
"""Random choose one if every choice not in memo."""
result = {}
for label in self._space_spec:
if label not in memo:
result[label] = random.choice(self._space_spec[label].values)
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def forward(self, x):
if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.')
if isinstance(self._sampled, list):
res = []
for i, block in enumerate(self.blocks):
x = block(x)
if i in self._sampled:
res.append(x)
return sum(res)
else:
for block in self.blocks[:self._sampled]:
x = block(x)
return x
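# Illustrative usage sketch, mirroring the unit test later in this change: fix the
# sampled depth to 2, so only the first two blocks are executed. Sizes are made up.
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice

sampled_repeat = PathSamplingRepeat(
    [nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)],
    ValueChoice([1, 2, 3], label='depth'),
)
sampled_repeat.resample({'depth': 2})
assert sampled_repeat(torch.randn(2, 16)).shape == (2, 8)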
class PathSamplingCell(BaseSuperNetModule):
"""The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.
When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
because for different input index, the cell factory could instantiate different operations (e.g., with different stride).
On export, we first rank the (operation, input) pairs, then select the best ``num_ops_per_node``.
``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
We assume ``loose_end`` to be ``all`` regardless of its configuration.
A supernet cell can't slim its own weight to fit into a sub network, which is also a known issue.
"""
def __init__(
self,
op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
num_nodes: int,
num_ops_per_node: int,
num_predecessors: int,
preprocessor: Any,
postprocessor: Any,
concat_dim: int,
memo: dict, # although not used here, useful in subclass
mutate_kwargs: dict, # same as memo
label: str,
):
super().__init__()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
self.preprocessor = preprocessor
self.ops = nn.ModuleList()
self.postprocessor = postprocessor
self.concat_dim = concat_dim
self.op_names: list[str] = cast(List[str], None)
self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))
# Create a fully-connected graph.
# Each edge is a ModuleDict with op candidates.
# Cannot reuse LayerChoice here, because the spec, resample, and export all need to be customized.
# InputChoice is implicit in this graph.
for i in self.output_node_indices:
self.ops.append(nn.ModuleList())
for k in range(i + self.num_predecessors):
# Second argument in (i, 0, k) is always 0.
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
self.op_names = list(ops.keys())
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
self.label = label
self._sampled: dict[str, str | int] = {}
def search_space_spec(self) -> dict[str, ParameterSpec]:
# TODO: Recreating the space here.
# The spec should be moved to definition of Cell itself.
space_spec = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'{self.label}/op_{i}_{k}'
input_label = f'{self.label}/input_{i}_{k}'
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
return space_spec
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
self._sampled = {}
new_sampled = {}
for label, param_spec in self.search_space_spec().items():
if label in memo:
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
self._sampled[label] = memo[label]
else:
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
return new_sampled
def export(self, memo):
"""Randomly choose one to export."""
return self.resample(memo)
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for k in range(self.num_ops_per_node):
# Select op list based on the input chosen
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
op_candidates = ops[cast(int, input_index)]
# Select op from op list based on the op chosen
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
op = op_candidates[cast(str, op_index)]
current_state.append(op(states[cast(int, input_index)]))
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
op_factory = module.op_candidates_factory
assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
'Only support op_factory of type list or dict.'
elif module.merge_op == 'loose_end':
op_candidates_lc = module.ops[-1][-1] # type: ignore
assert isinstance(op_candidates_lc, LayerChoice)
op_factory = { # create a factory
name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
for name in op_candidates_lc.names
}
if op_factory is not None:
return cls(
op_factory,
module.num_nodes,
module.num_ops_per_node,
module.num_predecessors,
module.preprocessor,
module.postprocessor,
module.concat_dim,
memo,
mutate_kwargs,
module.label
)
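# Illustrative end-to-end sketch, condensed from the unit tests later in this change.
# It assumes the `CellLooseEnd` test model defined below: the loose-end cell is replaced
# by a PathSamplingCell through its mutate hook, resampled once, and run on random inputs.
import torch
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules

model = CellLooseEnd()
nas_modules = traverse_and_mutate_submodules(model, [
    PathSamplingLayer.mutate,
    PathSamplingInput.mutate,
    PathSamplingCell.mutate,
], {})
memo = {}
for module in nas_modules:
    memo.update(module.resample(memo=memo))
out = model(torch.randn(2, 16), torch.randn(2, 16))   # concatenation of all 4 nodes: (2, 64)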
from typing import List, Tuple
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@model_wrapper
class CellSimple(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell(x, y)
@model_wrapper
class CellDefaultArgs(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell(x)
class CellPreprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class CellPostprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class CellCustomProcessor(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=CellPreprocessor(), postprocessor=CellPostprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
@model_wrapper
class CellLooseEnd(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
def forward(self, x, y):
return self.cell([x, y])
@model_wrapper
class CellOpFactory(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
......@@ -23,6 +23,10 @@ from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack, NoContextError, original_state_dict_hooks
from .models import (
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
)
class EnumerateSampler(Sampler):
def __init__(self):
......@@ -924,17 +928,7 @@ class Python(GraphIR):
model = Net()
def test_cell(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell(x, y)
raw_model, mutators = self._get_model_with_mutators(Net())
raw_model, mutators = self._get_model_with_mutators(CellSimple())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
......@@ -943,16 +937,7 @@ class Python(GraphIR):
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
@model_wrapper
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell(x)
raw_model, mutators = self._get_model_with_mutators(Net2())
raw_model, mutators = self._get_model_with_mutators(CellDefaultArgs())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
......@@ -961,34 +946,7 @@ class Python(GraphIR):
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
from typing import List, Tuple
class Preprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class Postprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
raw_model, mutators = self._get_model_with_mutators(CellCustomProcessor())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
......@@ -1000,17 +958,7 @@ class Python(GraphIR):
self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_cell_loose_end(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
raw_model, mutators = self._get_model_with_mutators(CellLooseEnd())
any_not_all = False
for _ in range(10):
sampler = EnumerateSampler()
......@@ -1026,19 +974,7 @@ class Python(GraphIR):
self.assertTrue(any_not_all)
def test_cell_complex(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
raw_model, mutators = self._get_model_with_mutators(CellOpFactory())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
......
......@@ -137,6 +137,31 @@ class RepeatNet(nn.Module):
return F.log_softmax(x, dim=1)
@model_wrapper
class CellNet(nn.Module):
def __init__(self):
super().__init__()
self.stem = nn.Conv2d(1, 5, 7, stride=4)
self.cells = nn.Repeat(
lambda index: nn.Cell({
'conv1': lambda _, __, inp: nn.Conv2d(
(5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 1
),
'conv2': lambda _, __, inp: nn.Conv2d(
(5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 3, padding=1
),
}, 3, merge_op='loose_end'), (1, 3)
)
self.fc = nn.Linear(3 * 4, 10)
def forward(self, x):
x = self.stem(x)
x = self.cells(x)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
@basic_unit
class MyOp(nn.Module):
def __init__(self, some_ch):
......@@ -183,6 +208,8 @@ def _mnist_net(type_, evaluator_kwargs):
base_model = ValueChoiceConvNet()
elif type_ == 'repeat':
base_model = RepeatNet()
elif type_ == 'cell':
base_model = CellNet()
elif type_ == 'custom_op':
base_model = CustomOpValueChoiceNet()
else:
......@@ -246,7 +273,7 @@ def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
(_mnist_net('simple', evaluator_kwargs), True),
(_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('repeat', evaluator_kwargs), False), # no strategy supports repeat currently
(_mnist_net('repeat', evaluator_kwargs), support_value_choice), # repeat depth is a value choice; supported by strategies that support value choice
(_mnist_net('custom_op', evaluator_kwargs), False), # this is definitely a NO
(_multihead_attention_net(evaluator_kwargs), support_value_choice),
]
......
......@@ -4,17 +4,23 @@ import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
DifferentiableMixedRepeat, DifferentiableMixedCell
)
from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell
)
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *
from .models import (
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
)
def test_slice():
weight = np.ones((3, 7, 24, 23))
......@@ -246,3 +252,113 @@ def test_proxyless_layer_input():
assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5))
def test_pathsampling_repeat():
op = PathSamplingRepeat([nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)], ValueChoice([1, 2, 3], label='ccc'))
sample = op.resample({})
assert sample['ccc'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ccc': i})
out = op(torch.randn(2, 16))
assert out.shape[1] == [16, 8, 4][i - 1]
op = PathSamplingRepeat([nn.Linear(i + 1, i + 2) for i in range(7)], 2 * ValueChoice([1, 2, 3], label='ddd') + 1)
sample = op.resample({})
assert sample['ddd'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ddd': i})
out = op(torch.randn(2, 1))
assert out.shape[1] == (2 * i + 1) + 1
def test_differentiable_repeat():
op = DifferentiableMixedRepeat(
[nn.Linear(8 if i == 0 else 16, 16) for i in range(4)],
ValueChoice([0, 1], label='ccc') * 2 + 1,
GumbelSoftmax(-1),
{}
)
op.resample({})
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1]
def test_pathsampling_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
nas_modules = traverse_and_mutate_submodules(model, [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
PathSamplingCell.mutate,
], {})
result = {}
for module in nas_modules:
result.update(module.resample(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert isinstance(model.cell, PathSamplingCell)
else:
assert not isinstance(model.cell, PathSamplingCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
def test_differentiable_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
nas_modules = traverse_and_mutate_submodules(model, [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
DifferentiableMixedCell.mutate,
], {})
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
ctrl_params = []
for m in nas_modules:
ctrl_params += list(m.parameters(arch=True))
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert isinstance(model.cell, DifferentiableMixedCell)
else:
assert not isinstance(model.cell, DifferentiableMixedCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])