Unverified Commit c80bda29 authored by Yuge Zhang, committed by GitHub

Support Repeat and Cell in One-shot NAS (#4835)

parent c54a07df
@@ -247,7 +247,8 @@ class _SupervisedLearningModule(LightningModule):
return self.optimizer(self.parameters(), lr=self.hparams.learning_rate, weight_decay=self.hparams.weight_decay) # type: ignore
def on_validation_epoch_end(self):
if not self.trainer.sanity_checking and self.running_mode == 'multi':
# Don't report metric when sanity checking
nni.report_intermediate_result(self._get_validation_metrics())
def on_fit_end(self):
...
@@ -30,9 +30,51 @@ class _DefaultPostprocessor(nn.Module):
return this_cell
CellOpFactory = Callable[[int, int, Optional[int]], nn.Module]
def create_cell_op_candidates(
op_candidates, node_index, op_index, chosen
) -> Tuple[Dict[str, nn.Module], bool]:
has_factory = False
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
nonlocal has_factory
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# Yes! It's using factory to create operations now.
has_factory = True
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
res = {str(i): convert_single_op(op) for i, op in enumerate(op_candidates)}
elif isinstance(op_candidates, dict):
res = {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
res = op_candidates()
has_factory = True
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
return res, has_factory
def preprocess_cell_inputs(num_predecessors: int, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
if len(inputs) == 1 and isinstance(inputs[0], list):
processed_inputs = list(inputs[0]) # shallow copy
else:
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
return processed_inputs
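# Illustrative usage sketch for the two helpers above (not from the original file; the names
# `x` and `y` are placeholders). `create_cell_op_candidates` normalizes `op_candidates` into a
# dict acceptable to LayerChoice and reports whether any candidate was a factory callable;
# `preprocess_cell_inputs` normalizes the positional inputs of a cell.
#
#     ops, has_factory = create_cell_op_candidates(
#         {'linear': nn.Linear(16, 16),
#          'proj': lambda node_index, op_index, chosen: nn.Linear(16, 16)},
#         node_index=2, op_index=0, chosen=None)
#     # ops == {'linear': <deep copy of the Linear>, 'proj': <Linear built by the factory>}
#     # has_factory is True because one candidate is a callable factory.
#
#     preprocess_cell_inputs(2, x, y)    # -> [x, y]
#     preprocess_cell_inputs(2, [x, y])  # -> [x, y]  (a single list argument is unpacked)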
class Cell(nn.Module):
"""
Cell structure that is popularly used in NAS literature.
@@ -108,6 +150,9 @@ class Cell(nn.Module):
The indices are enumerated for all nodes including predecessors from 0.
When first created, the input index is ``None``, meaning unknown.
Note that in the graph execution engine, support of functions in ``op_candidates`` is limited.
Please also note that, to make :class:`Cell` work with one-shot strategy,
``op_candidates``, in case it's a callable, should not depend on the second input argument,
i.e., ``op_index`` in current node.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node: int
@@ -191,15 +236,19 @@ class Cell(nn.Module):
When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
op_candidates_factory : CellOpFactory or None
If the operations are created with a factory (callable), this is set to that factory.
One-shot algorithms will use this to make each node a cartesian product of operations and inputs.
""" """
def __init__(self, def __init__(self,
op_candidates: Union[ op_candidates: Union[
Callable[[], List[nn.Module]], Callable[[], List[nn.Module]],
List[nn.Module], List[nn.Module],
List[_cell_op_factory_type], List[CellOpFactory],
Dict[str, nn.Module], Dict[str, nn.Module],
Dict[str, _cell_op_factory_type] Dict[str, CellOpFactory]
], ],
num_nodes: int, num_nodes: int,
num_ops_per_node: int = 1, num_ops_per_node: int = 1,
...@@ -232,6 +281,8 @@ class Cell(nn.Module): ...@@ -232,6 +281,8 @@ class Cell(nn.Module):
self.concat_dim = concat_dim self.concat_dim = concat_dim
self.op_candidates_factory: Union[List[CellOpFactory], Dict[str, CellOpFactory], None] = None # set later
# fill-in the missing modules
self._create_modules(op_candidates)
@@ -253,7 +304,9 @@ class Cell(nn.Module):
# this is needed because op_candidates can be very complex
# see the type annotation and docs for details
ops, has_factory = create_cell_op_candidates(op_candidates, i, k, chosen)
if has_factory:
self.op_candidates_factory = op_candidates
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
@@ -279,12 +332,7 @@ class Cell(nn.Module):
By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell.
"""
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
if len(inputs) == 1 and isinstance(inputs[0], list):
processed_inputs = list(inputs[0]) # shallow copy
else:
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for ops, inps in zip(
cast(Sequence[Sequence[LayerChoice]], self.ops),
@@ -301,26 +349,3 @@ class Cell(nn.Module):
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
return [convert_single_op(op) for op in op_candidates]
elif isinstance(op_candidates, dict):
return {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
return op_candidates()
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
@@ -106,7 +106,7 @@ class Repeat(Mutable):
'In repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.',
RuntimeWarning
)
self.depth_choice: Union[int, ChoiceOf[int]] = depth
all_values = list(self.depth_choice.all_options())
self.min_depth = min(all_values)
self.max_depth = max(all_values)
@@ -117,12 +117,12 @@ class Repeat(Mutable):
elif isinstance(depth, tuple):
self.min_depth = depth if isinstance(depth, int) else depth[0]
self.max_depth = depth if isinstance(depth, int) else depth[1]
self.depth_choice: Union[int, ChoiceOf[int]] = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
self._label = self.depth_choice.label
elif isinstance(depth, int):
self.min_depth = self.max_depth = depth
self.depth_choice: Union[int, ChoiceOf[int]] = depth
else:
raise TypeError(f'Unsupported "depth" type: {type(depth)}')
assert self.max_depth >= self.min_depth >= 0 and self.max_depth >= 1, f'Depth of {self.min_depth} to {self.max_depth} is invalid.'
...
@@ -59,7 +59,8 @@ def traverse_and_mutate_submodules(
module_list = []
def apply(m):
# Need to call list() here because the loop body might replace some children in-place.
for name, child in list(m.named_children()):
# post-order DFS
if not topdown:
apply(child)
@@ -94,6 +95,8 @@ def traverse_and_mutate_submodules(
break
if isinstance(mutate_result, BaseSuperNetModule):
# Replace child with the mutate result, and DFS this one
child = mutate_result
module_list.append(mutate_result)
# pre-order DFS
@@ -112,9 +115,9 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
primitive_list = (
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.ValueChoice,
nas_nn.Repeat,
nas_nn.NasBench101Cell,
# nas_nn.ValueChoice, # could be false positive
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
)
...
@@ -12,7 +12,8 @@ import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
DifferentiableMixedLayer, DifferentiableMixedInput,
MixedOpDifferentiablePolicy, GumbelSoftmax,
DifferentiableMixedCell, DifferentiableMixedRepeat
)
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS
@@ -52,6 +53,8 @@ class DartsLightningModule(BaseOneShotLightningModule):
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
DifferentiableMixedCell.mutate,
DifferentiableMixedRepeat.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
@@ -182,16 +185,6 @@ class GumbelDartsLightningModule(DartsLightningModule):
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with gumbel-differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use gumbel softmax."""
return {
...
@@ -12,7 +12,10 @@ import torch.nn as nn
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.sampling import (
PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
PathSamplingCell, PathSamplingRepeat
)
from .supermodule.operation import NATIVE_MIXED_OPERATIONS
from .enas import ReinforceController, ReinforceField
@@ -43,6 +46,8 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
hooks = [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
PathSamplingRepeat.mutate,
PathSamplingCell.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
...
@@ -4,20 +4,26 @@
from __future__ import annotations
import functools
import logging
import warnings
from typing import Any, Dict, Sequence, List, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
_logger = logging.getLogger(__name__)
class GumbelSoftmax(nn.Softmax):
@@ -284,10 +290,10 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
return {}
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is argmax for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():
if name in memo:
continue
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index]
@@ -300,3 +306,199 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
}
return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
return operation.init_arguments[name]
class DifferentiableMixedRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a differentiable supernet.
Result is a weighted sum of possible prefixes, sliced by possible depths.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._softmax = softmax
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._arch_alpha = nn.ParameterDict()
for name, spec in self._space_spec.items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
self._arch_alpha[name] = alpha
def resample(self, memo):
"""Do nothing."""
return {}
def export(self, memo):
"""Choose argmax for each leaf value choice."""
result = {}
for name, spec in self._space_spec.items():
if name in memo:
continue
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
result[name] = spec.values[chosen_index]
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def forward(self, x):
weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
}
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
res: torch.Tensor | None = None
for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4...
x = block(x)
if i in depth_weights:
if res is None:
res = depth_weights[i] * x
else:
res = res + depth_weights[i] * x
return res
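# A minimal worked reading of the forward pass above (assuming `depth` is a plain
# ValueChoice([1, 2, 3], label='d')): `traverse_all_options` resolves the depth weights to
# {1: w1, 2: w2, 3: w3} with (w1, w2, w3) = softmax(alpha_d), and the loop computes
#     out = w1 * b1(x) + w2 * b2(b1(x)) + w3 * b3(b2(b1(x)))
# i.e. every feasible prefix of the repeated blocks contributes, weighted by its depth's weight.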
class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context.
An architecture parameter is created on each edge of the fully-connected graph.
"""
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
memo, mutate_kwargs, label
):
super().__init__(
op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor,
concat_dim, memo, mutate_kwargs, label
)
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
def resample(self, memo):
"""Differentiable doesn't need to resample."""
return {}
def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
Unlike the DARTS reference code, we do not exclude operations like ``none`` here, because excluding them would effectively change the search space.
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
))
all_weights.sort(reverse=True)
# We first prefer inputs from different input_index.
# If we have got no other choices, we start to accept duplicates.
# Therefore we gather first occurrences of distinct input_index to the front.
first_occurrence_index: list[int] = [
all_weights.index( # The index of
next(filter(lambda t: t[1] == j, all_weights)) # First occurrence of j
)
for j in range(i) # For j < i
]
first_occurrence_index.sort() # Keep them ordered too.
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
return exported
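# Worked example of the export ordering above (hypothetical numbers; a node with 3 possible
# inputs, two candidate ops 'conv' and 'pool', and num_ops_per_node = 2). Suppose the
# weight-sorted (weight, input_index, op_name) tuples are
#     [(0.9, 1, 'conv'), (0.8, 1, 'pool'), (0.7, 0, 'conv'),
#      (0.3, 2, 'pool'), (0.2, 0, 'pool'), (0.1, 2, 'conv')]
# The first occurrences of input indices 0, 1, 2 are at positions 2, 0, 3, so those entries are
# moved to the front (keeping their relative order):
#     [(0.9, 1, 'conv'), (0.7, 0, 'conv'), (0.3, 2, 'pool'),
#      (0.8, 1, 'pool'), (0.2, 0, 'pool'), (0.1, 2, 'conv')]
# and the top num_ops_per_node entries are exported: (input 1, 'conv') and (input 0, 'conv').
# The higher-weight duplicate (input 1, 'pool') is passed over because edges from distinct
# inputs are preferred.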
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: list[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for j in range(i): # for every previous tensor
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
current_state.append(edge_sum)
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
@@ -3,17 +3,20 @@
from __future__ import annotations
import copy
import random
from typing import Any, List, Dict, Sequence, cast
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices
from .operation import MixedOperationSamplingPolicy, MixedOperation
@@ -198,3 +201,200 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
if name in operation.mutable_arguments:
return self._sampled[name]
return operation.init_arguments[name]
class PathSamplingRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a path-sampling supernet.
Samples one / some of the prefixes of the repeated blocks.
Attributes
----------
_sampled : int or list of int
Sampled depth.
"""
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._sampled: list[int] | int | None = None
def resample(self, memo):
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
result = {}
for label in self._space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(self._space_spec[label].values)
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
return result
def export(self, memo):
"""Random choose one if every choice not in memo."""
result = {}
for label in self._space_spec:
if label not in memo:
result[label] = random.choice(self._space_spec[label].values)
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def forward(self, x):
if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.')
if isinstance(self._sampled, list):
res = []
for i, block in enumerate(self.blocks):
x = block(x)
if i in self._sampled:
res.append(x)
return sum(res)
else:
for block in self.blocks[:self._sampled]:
x = block(x)
return x
class PathSamplingCell(BaseSuperNetModule):
"""The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.
When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
because for different input index, the cell factory could instantiate different operations (e.g., with different stride).
On export, we first rank the (operation, input) pairs, then select the best ``num_ops_per_node`` of them.
``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
We assume ``loose_end`` to be ``all`` regardless of its configuration.
A supernet cell can't slim its own weight to fit into a sub network, which is also a known issue.
"""
def __init__(
self,
op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
num_nodes: int,
num_ops_per_node: int,
num_predecessors: int,
preprocessor: Any,
postprocessor: Any,
concat_dim: int,
memo: dict, # although not used here, useful in subclass
mutate_kwargs: dict, # same as memo
label: str,
):
super().__init__()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
self.preprocessor = preprocessor
self.ops = nn.ModuleList()
self.postprocessor = postprocessor
self.concat_dim = concat_dim
self.op_names: list[str] = cast(List[str], None)
self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))
# Create a fully-connected graph.
# Each edge is a ModuleDict with op candidates.
# Can not reuse LayerChoice here, because the spec, resample, export all need to be customized.
# InputChoice is implicit in this graph.
for i in self.output_node_indices:
self.ops.append(nn.ModuleList())
for k in range(i + self.num_predecessors):
# Second argument in (i, **0**, k) is always 0.
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
self.op_names = list(ops.keys())
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
self.label = label
self._sampled: dict[str, str | int] = {}
def search_space_spec(self) -> dict[str, ParameterSpec]:
# TODO: Recreating the space here.
# The spec should be moved to definition of Cell itself.
space_spec = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'{self.label}/op_{i}_{k}'
input_label = f'{self.label}/input_{i}_{k}'
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
return space_spec
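# For example (hypothetical instantiation): with label='cell', num_nodes=4, num_predecessors=2
# and num_ops_per_node=2, the spec above contains 'cell/op_{i}_{k}' and 'cell/input_{i}_{k}'
# for i in 2..5 and k in 0..1, i.e. num_nodes * num_ops_per_node * 2 = 16 parameters in total,
# which matches the count asserted in the unit tests further below.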
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
self._sampled = {}
new_sampled = {}
for label, param_spec in self.search_space_spec().items():
if label in memo:
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
self._sampled[label] = memo[label]
else:
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
return new_sampled
def export(self, memo):
"""Randomly choose one to export."""
return self.resample(memo)
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for k in range(self.num_ops_per_node):
# Select op list based on the input chosen
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
op_candidates = ops[cast(int, input_index)]
# Select op from op list based on the op chosen
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
op = op_candidates[cast(str, op_index)]
current_state.append(op(states[cast(int, input_index)]))
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
op_factory = module.op_candidates_factory
assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
'Only support op_factory of type list or dict.'
elif module.merge_op == 'loose_end':
op_candidates_lc = module.ops[-1][-1] # type: ignore
assert isinstance(op_candidates_lc, LayerChoice)
op_factory = { # create a factory
name: lambda _, __, ___: copy.deepcopy(op_candidates_lc[name])
for name in op_candidates_lc.names
}
if op_factory is not None:
return cls(
op_factory,
module.num_nodes,
module.num_ops_per_node,
module.num_predecessors,
module.preprocessor,
module.postprocessor,
module.concat_dim,
memo,
mutate_kwargs,
module.label
)
from typing import List, Tuple
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
@model_wrapper
class CellSimple(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell(x, y)
@model_wrapper
class CellDefaultArgs(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell(x)
class CellPreprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class CellPostprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class CellCustomProcessor(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=CellPreprocessor(), postprocessor=CellPostprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
@model_wrapper
class CellLooseEnd(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
def forward(self, x, y):
return self.cell([x, y])
@model_wrapper
class CellOpFactory(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
@@ -23,6 +23,10 @@ from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack, NoContextError, original_state_dict_hooks
from .models import (
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
)
class EnumerateSampler(Sampler):
def __init__(self):
@@ -924,17 +928,7 @@ class Python(GraphIR):
model = Net()
def test_cell(self):
raw_model, mutators = self._get_model_with_mutators(CellSimple())
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell(x, y)
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
@@ -943,16 +937,7 @@ class Python(GraphIR):
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
raw_model, mutators = self._get_model_with_mutators(CellDefaultArgs())
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell(x)
raw_model, mutators = self._get_model_with_mutators(Net2())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
@@ -961,34 +946,7 @@ class Python(GraphIR):
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
raw_model, mutators = self._get_model_with_mutators(CellCustomProcessor())
class Preprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class Postprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
@@ -1000,17 +958,7 @@ class Python(GraphIR):
self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_cell_loose_end(self):
raw_model, mutators = self._get_model_with_mutators(CellLooseEnd())
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
any_not_all = False
for _ in range(10):
sampler = EnumerateSampler()
@@ -1026,19 +974,7 @@ class Python(GraphIR):
self.assertTrue(any_not_all)
def test_cell_complex(self):
raw_model, mutators = self._get_model_with_mutators(CellOpFactory())
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
...
@@ -137,6 +137,31 @@ class RepeatNet(nn.Module):
return F.log_softmax(x, dim=1)
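# A note on the channel arithmetic in the CellNet below (this reading is an assumption-based
# gloss, not an original comment): the third factory argument is the chosen input index.
# Input 0 is the cell's single default predecessor, which carries 5 channels (stem output) for
# the first repeated cell and 3 * 4 channels (loose-end concat of 3 nodes with 4 output
# channels each) for later cells; any other input is an intermediate node inside the cell and
# carries 4 channels, which is also the fallback when the index is still unknown (None).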
@model_wrapper
class CellNet(nn.Module):
def __init__(self):
super().__init__()
self.stem = nn.Conv2d(1, 5, 7, stride=4)
self.cells = nn.Repeat(
lambda index: nn.Cell({
'conv1': lambda _, __, inp: nn.Conv2d(
(5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 1
),
'conv2': lambda _, __, inp: nn.Conv2d(
(5 if index == 0 else 3 * 4) if inp is not None and inp < 1 else 4, 4, 3, padding=1
),
}, 3, merge_op='loose_end'), (1, 3)
)
self.fc = nn.Linear(3 * 4, 10)
def forward(self, x):
x = self.stem(x)
x = self.cells(x)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
@basic_unit
class MyOp(nn.Module):
def __init__(self, some_ch):
@@ -183,6 +208,8 @@ def _mnist_net(type_, evaluator_kwargs):
base_model = ValueChoiceConvNet()
elif type_ == 'repeat':
base_model = RepeatNet()
elif type_ == 'cell':
base_model = CellNet()
elif type_ == 'custom_op':
base_model = CustomOpValueChoiceNet()
else:
@@ -246,7 +273,7 @@ def _test_strategy(strategy_, support_value_choice=True, multi_gpu=False):
(_mnist_net('simple', evaluator_kwargs), True),
(_mnist_net('simple_value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('value_choice', evaluator_kwargs), support_value_choice),
(_mnist_net('repeat', evaluator_kwargs), support_value_choice),
(_mnist_net('custom_op', evaluator_kwargs), False), # this is definitely a NO
(_multihead_attention_net(evaluator_kwargs), support_value_choice),
]
...
@@ -4,17 +4,23 @@ import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
DifferentiableMixedRepeat, DifferentiableMixedCell
)
from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput, PathSamplingRepeat, PathSamplingCell
)
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *
from .models import (
CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory
)
def test_slice():
weight = np.ones((3, 7, 24, 23))
@@ -246,3 +252,113 @@ def test_proxyless_layer_input():
assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5))
def test_pathsampling_repeat():
op = PathSamplingRepeat([nn.Linear(16, 16), nn.Linear(16, 8), nn.Linear(8, 4)], ValueChoice([1, 2, 3], label='ccc'))
sample = op.resample({})
assert sample['ccc'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ccc': i})
out = op(torch.randn(2, 16))
assert out.shape[1] == [16, 8, 4][i - 1]
op = PathSamplingRepeat([nn.Linear(i + 1, i + 2) for i in range(7)], 2 * ValueChoice([1, 2, 3], label='ddd') + 1)
sample = op.resample({})
assert sample['ddd'] in [1, 2, 3]
for i in range(1, 4):
op.resample({'ddd': i})
out = op(torch.randn(2, 1))
assert out.shape[1] == (2 * i + 1) + 1
def test_differentiable_repeat():
op = DifferentiableMixedRepeat(
[nn.Linear(8 if i == 0 else 16, 16) for i in range(4)],
ValueChoice([0, 1], label='ccc') * 2 + 1,
GumbelSoftmax(-1),
{}
)
op.resample({})
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1]
def test_pathsampling_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
nas_modules = traverse_and_mutate_submodules(model, [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
PathSamplingCell.mutate,
], {})
result = {}
for module in nas_modules:
result.update(module.resample(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert isinstance(model.cell, PathSamplingCell)
else:
assert not isinstance(model.cell, PathSamplingCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
def test_differentiable_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
model = cell_cls()
nas_modules = traverse_and_mutate_submodules(model, [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
DifferentiableMixedCell.mutate,
], {})
result = {}
for module in nas_modules:
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
ctrl_params = []
for m in nas_modules:
ctrl_params += list(m.parameters(arch=True))
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert isinstance(model.cell, DifferentiableMixedCell)
else:
assert not isinstance(model.cell, DifferentiableMixedCell)
inputs = {
CellSimple: (torch.randn(2, 16), torch.randn(2, 16)),
CellDefaultArgs: (torch.randn(2, 16),),
CellCustomProcessor: (torch.randn(2, 3), torch.randn(2, 16)),
CellLooseEnd: (torch.randn(2, 16), torch.randn(2, 16)),
CellOpFactory: (torch.randn(2, 3), torch.randn(2, 16)),
}[cell_cls]
output = model(*inputs)
if cell_cls == CellCustomProcessor:
assert isinstance(output, tuple) and len(output) == 2 and \
output[1].shape == torch.Size([2, 16 * model.cell.num_nodes])
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])