Unverified Commit 2815fb1f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Make models in search space hub work with one-shot (#4921)

parent 80beca52
...@@ -21,6 +21,9 @@ import torch ...@@ -21,6 +21,9 @@ import torch
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper from nni.retiarii import model_wrapper
from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight from .utils.pretrained import load_pretrained_weight
...@@ -348,6 +351,100 @@ class CellBuilder: ...@@ -348,6 +351,100 @@ class CellBuilder:
return cell return cell
class NDSStage(nn.Repeat):
    """This class defines NDSStage, a special type of Repeat, for isinstance check, and shape alignment.

    In NDS, we can't simply use Repeat to stack the blocks,
    because the output shape of each stacked block can be different.
    This is a problem for one-shot strategy because they assume every possible candidate
    should return values of the same shape.

    Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
    to manually align the shapes -- specifically, to transform the first block in each stage.

    This is not required though, when depth is not changing, or the mutable depth causes no problem
    (e.g., when the minimum depth is large enough).

    .. attention::

       Assumption: Loose end is treated as all in ``merge_op`` (the case in one-shot),
       which enforces reduction cell and normal cells in the same stage to have the exact same output shape.
    """

    estimated_out_channels_prev: int
    """Output channels of cells in last stage."""

    estimated_out_channels: int
    """Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""

    downsampling: bool
    """This stage has downsampling"""

    def first_cell_transformation_factory(self) -> Optional[nn.Module]:
        """To make the "previous cell" in first cell's output have the same shape as cells in this stage."""
        if self.downsampling:
            # Downsampling stage: use FactorizedReduce to align the "prev" output with this
            # stage's cells (presumably halves resolution -- defined elsewhere in this file).
            return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
        elif self.estimated_out_channels_prev is not self.estimated_out_channels:
            # Can't use != here, ValueChoice doesn't support
            return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
        # Shapes already agree -- no transformation needed.
        return None
class NDSStagePathSampling(PathSamplingRepeat):
    """The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        # Only NDS stages with a mutable depth need this shape-aligning wrapper.
        if not isinstance(module, NDSStage):
            return None
        if not isinstance(module.depth_choice, nn.api.ValueChoiceX):
            return None
        blocks = cast(List[nn.Module], module.blocks)
        return cls(module.first_cell_transformation_factory(), blocks, module.depth_choice)

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Module aligning the first cell's "prev" output shape; None when no alignment is needed.
        self.first_cell_transformation = first_cell_transformation

    def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        transform = self.first_cell_transformation
        if transform is not None and 1 in sampled:
            # items[0] must be the result of first cell
            assert len(items[0]) == 2
            # Only apply the transformation on "prev" output.
            items[0] = (transform(items[0][0]), items[0][1])
        return super().reduction(items, sampled)
class NDSStageDifferentiable(DifferentiableMixedRepeat):
    """The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        # Only interesting when depth is mutable
        if not (isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX)):
            return None
        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(
            module.first_cell_transformation_factory(),
            cast(List[nn.Module], module.blocks),
            module.depth_choice,
            softmax,
            memo,
        )

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Module aligning the first cell's "prev" output shape; None when no alignment is needed.
        self.first_cell_transformation = first_cell_transformation

    def reduction(
        self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        transform = self.first_cell_transformation
        if transform is not None and 1 in depths:
            # Same as NDSStagePathSampling: items[0] is the output of the first cell.
            assert len(items[0]) == 2
            # Only transform the "prev" half of the pair.
            items[0] = (transform(items[0][0]), items[0][1])
        return super().reduction(items, weights, depths)
_INIT_PARAMETER_DOCS = """ _INIT_PARAMETER_DOCS = """
Parameters Parameters
...@@ -437,6 +534,8 @@ class NDS(nn.Module): ...@@ -437,6 +534,8 @@ class NDS(nn.Module):
C_pprev = C_prev = 3 * C C_pprev = C_prev = 3 * C
C_curr = C C_curr = C
last_cell_reduce = False last_cell_reduce = False
else:
raise ValueError(f'Unsupported dataset: {dataset}')
self.stages = nn.ModuleList() self.stages = nn.ModuleList()
for stage_idx in range(3): for stage_idx in range(3):
...@@ -448,9 +547,19 @@ class NDS(nn.Module): ...@@ -448,9 +547,19 @@ class NDS(nn.Module):
# C_out is usually `C * num_nodes_per_cell` because of concat operator. # C_out is usually `C * num_nodes_per_cell` because of concat operator.
cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell, cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
merge_op, stage_idx > 0, last_cell_reduce) merge_op, stage_idx > 0, last_cell_reduce)
stage = nn.Repeat(cell_builder, num_cells_per_stage[stage_idx]) stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
if isinstance(stage, NDSStage):
stage.estimated_out_channels_prev = cast(int, C_prev)
stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
stage.downsampling = stage_idx > 0
self.stages.append(stage) self.stages.append(stage)
# NOTE: output_node_indices will be computed on-the-fly in trial code.
# When constructing model space, it's just all the nodes in the cell,
# which happens to be the case of one-shot supernet.
# C_pprev is output channel number of last second cell among all the cells already built. # C_pprev is output channel number of last second cell among all the cells already built.
if len(stage) > 1: if len(stage) > 1:
# Contains more than one cell # Contains more than one cell
......
...@@ -98,7 +98,6 @@ class ConvBNReLU(nn.Sequential): ...@@ -98,7 +98,6 @@ class ConvBNReLU(nn.Sequential):
] ]
super().__init__(*simplify_sequential(blocks)) super().__init__(*simplify_sequential(blocks))
self.out_channels = out_channels
class DepthwiseSeparableConv(nn.Sequential): class DepthwiseSeparableConv(nn.Sequential):
...@@ -133,7 +132,8 @@ class DepthwiseSeparableConv(nn.Sequential): ...@@ -133,7 +132,8 @@ class DepthwiseSeparableConv(nn.Sequential):
ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity) ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
] ]
super().__init__(*simplify_sequential(blocks)) super().__init__(*simplify_sequential(blocks))
self.has_skip = stride == 1 and in_channels == out_channels # NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
self.has_skip = stride == 1 and in_channels is out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor: def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip: if self.has_skip:
...@@ -177,8 +177,8 @@ class InvertedResidual(nn.Sequential): ...@@ -177,8 +177,8 @@ class InvertedResidual(nn.Sequential):
hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8)) hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))
# NOTE: this equivalence check should also work for ValueChoice # NOTE: this equivalence check (==) does NOT work for ValueChoice, need to use "is"
self.has_skip = stride == 1 and in_channels == out_channels self.has_skip = stride == 1 and in_channels is out_channels
layers: List[nn.Module] = [ layers: List[nn.Module] = [
# point-wise convolution # point-wise convolution
......
...@@ -7,4 +7,3 @@ from .proxyless import ProxylessTrainer ...@@ -7,4 +7,3 @@ from .proxyless import ProxylessTrainer
from .random import SinglePathTrainer, RandomTrainer from .random import SinglePathTrainer, RandomTrainer
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule from .sampling import EnasLightningModule, RandomSamplingLightningModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
...@@ -60,7 +60,7 @@ class DartsLightningModule(BaseOneShotLightningModule): ...@@ -60,7 +60,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
) )
__doc__ = _darts_note.format( __doc__ = _darts_note.format(
module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.', module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note, module_params=BaseOneShotLightningModule._inner_module_note,
) )
...@@ -161,7 +161,7 @@ class ProxylessLightningModule(DartsLightningModule): ...@@ -161,7 +161,7 @@ class ProxylessLightningModule(DartsLightningModule):
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note) """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _proxyless_note.format( __doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.', module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note, module_params=BaseOneShotLightningModule._inner_module_note,
) )
......
...@@ -115,7 +115,7 @@ def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, ...@@ -115,7 +115,7 @@ def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice,
# this saves an op on computational graph, which will hopefully make training faster # this saves an op on computational graph, which will hopefully make training faster
# Use a dummy array to check this. Otherwise it would be too complex. # Use a dummy array to check this. Otherwise it would be too complex.
dummy_arr = np.zeros(weight.shape, dtype=np.bool) # type: ignore dummy_arr = np.zeros(weight.shape, dtype=bool) # type: ignore
no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
if no_effect: if no_effect:
......
...@@ -7,7 +7,7 @@ in the way that is most convenient to one-shot algorithms.""" ...@@ -7,7 +7,7 @@ in the way that is most convenient to one-shot algorithms."""
from __future__ import annotations from __future__ import annotations
import itertools import itertools
from typing import Any, TypeVar, List, cast from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
import numpy as np import numpy as np
import torch import torch
...@@ -20,7 +20,13 @@ Choice = Any ...@@ -20,7 +20,13 @@ Choice = Any
T = TypeVar('T') T = TypeVar('T')
__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options'] __all__ = [
'dedup_inner_choices',
'evaluate_value_choice_with_dict',
'traverse_all_options',
'weighted_sum',
'evaluate_constant',
]
def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]: def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
...@@ -138,3 +144,101 @@ def traverse_all_options( ...@@ -138,3 +144,101 @@ def traverse_all_options(
return sorted(result.keys()) # type: ignore return sorted(result.keys()) # type: ignore
else: else:
return sorted(result.items()) # type: ignore return sorted(result.items()) # type: ignore
def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant.

    Parameters
    ----------
    expr
        A value choice expression (or a plain constant).

    Returns
    -------
    The single value the expression evaluates to.

    Raises
    ------
    ValueError
        If the expression evaluates to more than one possible value
        (i.e., it is not a constant), or to no value at all.
    """
    all_options = traverse_all_options(expr)
    # Robustness: the previous `> 1` check let an empty option list fall through to
    # `all_options[0]`, raising an opaque IndexError instead of a clear ValueError.
    if len(all_options) != 1:
        raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
    return all_options[0]
def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
"""Return a weighted sum of items.
Items can be list of tensors, numpy arrays, or nested lists / dicts.
If ``weights`` is None, this is simply an unweighted sum.
"""
if weights is None:
weights = [None] * len(items)
assert len(items) == len(weights) > 0
elem = items[0]
unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
if isinstance(elem, str):
# Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
raise TypeError(unsupported_msg)
try:
if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
if weights[0] is None:
res = elem
else:
res = elem * weights[0]
for it, weight in zip(items[1:], weights[1:]):
if type(it) != type(elem):
raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
if weight is None:
res = res + it # type: ignore
else:
res = res + it * weight # type: ignore
return cast(T, res)
if isinstance(elem, Mapping):
for item in items:
if not isinstance(item, Mapping):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if set(item) != set(elem):
raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
return cast(T, {
key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
})
if isinstance(elem, Sequence):
for item in items:
if not isinstance(item, Sequence):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if len(item) != len(elem):
raise ValueError(f'Expect length {len(item)} but found {len(elem)}')
transposed = cast(Iterable[list], zip(*items)) # type: ignore
return cast(T, [weighted_sum(column, weights) for column in transposed])
except (TypeError, ValueError, RuntimeError, KeyError):
raise ValueError(
'Error when summing items. Value format / shape does not match. See full traceback for details.' +
''.join([
f'\n {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
])
)
# Dealing with all unexpected types.
raise TypeError(unsupported_msg)
def _summarize_elem_format(elem: Any) -> Any:
    """Produce a compact, human-readable summary of ``elem`` for error messages.

    Tensors / arrays are replaced by shape-only placeholders; mappings and
    sequences are summarized recursively; scalars pass through unchanged.
    """

    class _ReprOnly:
        # Placeholder whose sole purpose is to render a custom repr.
        def __init__(self, representation):
            self.representation = representation

        def __repr__(self):
            return self.representation

    if isinstance(elem, torch.Tensor):
        dims = ', '.join(str(d) for d in elem.shape)
        return _ReprOnly(f'torch.Tensor({dims})')
    if isinstance(elem, np.ndarray):
        dims = ', '.join(str(d) for d in elem.shape)
        return _ReprOnly(f'np.array({dims})')
    if isinstance(elem, Mapping):
        return {key: _summarize_elem_format(value) for key, value in elem.items()}
    if isinstance(elem, Sequence):
        return [_summarize_elem_format(value) for value in elem]
    # fallback to original, for cases like float, int, ...
    return elem
...@@ -21,14 +21,14 @@ from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs ...@@ -21,14 +21,14 @@ from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
__all__ = [ __all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput', 'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell', 'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy' 'MixedOpDifferentiablePolicy',
] ]
...@@ -77,7 +77,11 @@ class DifferentiableMixedLayer(BaseSuperNetModule): ...@@ -77,7 +77,11 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
_arch_parameter_names: list[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self,
paths: list[tuple[str, nn.Module]],
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__() super().__init__()
self.op_names = [] self.op_names = []
if len(alpha) != len(paths): if len(alpha) != len(paths):
...@@ -118,11 +122,15 @@ class DifferentiableMixedLayer(BaseSuperNetModule): ...@@ -118,11 +122,15 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label) return cls(list(module.named_children()), alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
    """Override this for customized reduction.

    Default: a weighted sum of candidate outputs, one weight per item.
    """
    # Use weighted_sum to handle complex cases where sequential output is not a single tensor
    return weighted_sum(items, weights)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
"""The forward of mixed layer accepts same arguments as its sub-layer.""" """The forward of mixed layer accepts same arguments as its sub-layer."""
op_results = torch.stack([getattr(self, op)(*args, **kwargs) for op in self.op_names]) all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) return self.reduction(all_op_results, self._softmax(self._arch_alpha))
return torch.sum(op_results * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters.""" """Parameters excluding architecture parameters."""
...@@ -167,7 +175,12 @@ class DifferentiableMixedInput(BaseSuperNetModule): ...@@ -167,7 +175,12 @@ class DifferentiableMixedInput(BaseSuperNetModule):
_arch_parameter_names: list[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self,
n_candidates: int,
n_chosen: int | None,
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__() super().__init__()
self.n_candidates = n_candidates self.n_candidates = n_candidates
if len(alpha) != n_candidates: if len(alpha) != n_candidates:
...@@ -217,11 +230,14 @@ class DifferentiableMixedInput(BaseSuperNetModule): ...@@ -217,11 +230,14 @@ class DifferentiableMixedInput(BaseSuperNetModule):
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label) return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
    """Override this for customized reduction.

    Default: a weighted sum of the candidate inputs, one weight per item.
    """
    # Use weighted_sum to handle complex cases where sequential output is not a single tensor
    return weighted_sum(items, weights)
def forward(self, inputs): def forward(self, inputs):
"""Forward takes a list of input candidates.""" """Forward takes a list of input candidates."""
inputs = torch.stack(inputs) return self.reduction(inputs, self._softmax(self._arch_alpha))
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
def parameters(self, *args, **kwargs): def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters.""" """Parameters excluding architecture parameters."""
...@@ -318,11 +334,18 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): ...@@ -318,11 +334,18 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
""" """
Implementaion of Repeat in a differentiable supernet. Implementaion of Repeat in a differentiable supernet.
Result is a weighted sum of possible prefixes, sliced by possible depths. Result is a weighted sum of possible prefixes, sliced by possible depths.
If the output is not a single tensor, it will be summed at every independant dimension.
See :func:`weighted_sum` for details.
""" """
_arch_parameter_names: list[str] = ['_arch_alpha'] _arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]): def __init__(self,
blocks: list[nn.Module],
depth: ChoiceOf[int],
softmax: nn.Module,
memo: dict[str, Any]):
super().__init__() super().__init__()
self.blocks = blocks self.blocks = blocks
self.depth = depth self.depth = depth
...@@ -377,21 +400,28 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): ...@@ -377,21 +400,28 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
if not arch: if not arch:
yield name, p yield name, p
def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
    """Override this for customized reduction.

    ``depths`` gives the depth each item was produced at; it is unused by the
    default implementation but available to subclasses (e.g., for shape alignment).
    """
    # Use weighted_sum to handle complex cases where sequential output is not a single tensor
    return weighted_sum(items, weights)
def forward(self, x): def forward(self, x):
weights: dict[str, torch.Tensor] = { weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items() label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
} }
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights))) depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
res: torch.Tensor | None = None res: list[torch.Tensor] = []
weight_list: list[float] = []
depths: list[int] = []
for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4... for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4...
x = block(x) x = block(x)
if i in depth_weights: if i in depth_weights:
if res is None: weight_list.append(depth_weights[i])
res = depth_weights[i] * x res.append(x)
else: depths.append(i)
res = res + depth_weights[i] * x
return res return self.reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell): class DifferentiableMixedCell(PathSamplingCell):
......
...@@ -10,7 +10,8 @@ from __future__ import annotations ...@@ -10,7 +10,8 @@ from __future__ import annotations
import inspect import inspect
import itertools import itertools
from typing import Any, Type, TypeVar, cast, Union, Tuple import warnings
from typing import Any, Type, TypeVar, cast, Union, Tuple, List
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -23,7 +24,7 @@ from nni.common.serializer import is_traceable ...@@ -23,7 +24,7 @@ from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX from nni.retiarii.nn.pytorch.api import ValueChoiceX
from .base import BaseSuperNetModule from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
T = TypeVar('T') T = TypeVar('T')
...@@ -268,14 +269,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -268,14 +269,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
- ``in_channels`` - ``in_channels``
- ``out_channels`` - ``out_channels``
- ``groups`` (only supported in path sampling) - ``groups``
- ``stride`` (only supported in path sampling) - ``stride`` (only supported in path sampling)
- ``kernel_size`` - ``kernel_size``
- ``padding`` (only supported in path sampling) - ``padding``
- ``dilation`` (only supported in path sampling) - ``dilation`` (only supported in path sampling)
``padding`` will be the "max" padding in differentiable mode. ``padding`` will be the "max" padding in differentiable mode.
Mutable ``groups`` is NOT supported in most cases of differentiable mode.
However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
This is often the case of depth-wise convolutions.
For channels, prefix will be sliced. For channels, prefix will be sliced.
For kernels, we take the small kernel from the center and round it to floor (left top). For example :: For kernels, we take the small kernel from the center and round it to floor (left top). For example ::
...@@ -315,6 +320,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -315,6 +320,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
return max(all_sizes) return max(all_sizes)
elif name == 'groups': elif name == 'groups':
if 'in_channels' in self.mutable_arguments:
# If the ratio is constant, we don't need to try the maximum groups.
try:
constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
except ValueError:
warnings.warn(
'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
RuntimeWarning
)
# minimum groups, maximum kernel # minimum groups, maximum kernel
return min(traverse_all_options(value_choice)) return min(traverse_all_options(value_choice))
...@@ -328,11 +345,11 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -328,11 +345,11 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
stride: _int_or_tuple, stride: _int_or_tuple,
padding: scalar_or_scalar_dict[_int_or_tuple], padding: scalar_or_scalar_dict[_int_or_tuple],
dilation: int, dilation: int,
groups: int, groups: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor: inputs: torch.Tensor) -> torch.Tensor:
if any(isinstance(arg, dict) for arg in [stride, dilation, groups]): if any(isinstance(arg, dict) for arg in [stride, dilation]):
raise ValueError(_diff_not_compatible_error.format('stride, dilation and groups', 'Conv2d')) raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
in_channels_ = _W(in_channels) in_channels_ = _W(in_channels)
out_channels_ = _W(out_channels) out_channels_ = _W(out_channels)
...@@ -340,7 +357,32 @@ class MixedConv2d(MixedOperation, nn.Conv2d): ...@@ -340,7 +357,32 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
# slice prefix # slice prefix
# For groups > 1, we use groups to slice input weights # For groups > 1, we use groups to slice input weights
weight = _S(self.weight)[:out_channels_] weight = _S(self.weight)[:out_channels_]
weight = _S(weight)[:, :in_channels_ // groups]
if not isinstance(groups, dict):
weight = _S(weight)[:, :in_channels_ // groups]
else:
assert 'groups' in self.mutable_arguments
err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
'in_channels and out_channels should also be a ValueChoice. ' \
'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
'should be constants.'
if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
raise ValueError(err_message)
try:
in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
except ValueError:
raise ValueError(err_message)
if in_channels_per_group != int(in_channels_per_group):
raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
if inputs.size(1) % in_channels_per_group != 0:
raise RuntimeError(
f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
f'while in_channels_per_group = {in_channels_per_group}'
)
# Compute sliced weights and groups (as an integer)
weight = _S(weight)[:, :int(in_channels_per_group)]
groups = inputs.size(1) // int(in_channels_per_group)
# slice center # slice center
if isinstance(kernel_size, dict): if isinstance(kernel_size, dict):
......
...@@ -16,7 +16,7 @@ from nni.retiarii.nn.pytorch.api import ValueChoiceX ...@@ -16,7 +16,7 @@ from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation from .operation import MixedOperationSamplingPolicy, MixedOperation
__all__ = [ __all__ = [
...@@ -72,6 +72,10 @@ class PathSamplingLayer(BaseSuperNetModule): ...@@ -72,6 +72,10 @@ class PathSamplingLayer(BaseSuperNetModule):
if isinstance(module, LayerChoice): if isinstance(module, LayerChoice):
return cls(list(module.named_children()), module.label) return cls(list(module.named_children()), module.label)
def reduction(self, items: list[Any], sampled: list[Any]):
    """Override this to implement customized reduction.

    ``sampled`` is ignored by the default implementation, which simply sums the
    sampled outputs (``weighted_sum`` with no weights), also handling outputs
    that are nested lists / dicts rather than a single tensor.
    """
    return weighted_sum(items)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
if self._sampled is None: if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.') raise RuntimeError('At least one path needs to be sampled before fprop.')
...@@ -79,10 +83,7 @@ class PathSamplingLayer(BaseSuperNetModule): ...@@ -79,10 +83,7 @@ class PathSamplingLayer(BaseSuperNetModule):
# str(samp) is needed here because samp can sometimes be integers, but attr are always str # str(samp) is needed here because samp can sometimes be integers, but attr are always str
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled] res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
if len(res) == 1: return self.reduction(res, sampled)
return res[0]
else:
return sum(res)
class PathSamplingInput(BaseSuperNetModule): class PathSamplingInput(BaseSuperNetModule):
...@@ -95,11 +96,11 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -95,11 +96,11 @@ class PathSamplingInput(BaseSuperNetModule):
Sampled input indices. Sampled input indices.
""" """
def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str): def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
super().__init__() super().__init__()
self.n_candidates = n_candidates self.n_candidates = n_candidates
self.n_chosen = n_chosen self.n_chosen = n_chosen
self.reduction = reduction self.reduction_type = reduction_type
self._sampled: list[int] | int | None = None self._sampled: list[int] | int | None = None
self.label = label self.label = label
...@@ -144,6 +145,19 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -144,6 +145,19 @@ class PathSamplingInput(BaseSuperNetModule):
raise ValueError('n_chosen is None is not supported yet.') raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label) return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
    """Override this to implement customized reduction."""
    # A single sampled path needs no reduction at all.
    if len(items) == 1:
        return items[0]
    kind = self.reduction_type
    if kind == 'sum':
        return sum(items)
    if kind == 'mean':
        return sum(items) / len(items)
    if kind == 'concat':
        # Concatenate along dim 1 (the channel dimension for NCHW tensors).
        return torch.cat(items, 1)
    raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
def forward(self, input_tensors): def forward(self, input_tensors):
if self._sampled is None: if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.') raise RuntimeError('At least one path needs to be sampled before fprop.')
...@@ -151,15 +165,7 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -151,15 +165,7 @@ class PathSamplingInput(BaseSuperNetModule):
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.') raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = [input_tensors[samp] for samp in sampled] res = [input_tensors[samp] for samp in sampled]
if len(res) == 1: return self.reduction(res, sampled)
return res[0]
else:
if self.reduction == 'sum':
return sum(res)
elif self.reduction == 'mean':
return sum(res) / len(res)
elif self.reduction == 'concat':
return torch.cat(res, 1)
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
...@@ -202,6 +208,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): ...@@ -202,6 +208,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result return result
def forward_argument(self, operation: MixedOperation, name: str) -> Any: def forward_argument(self, operation: MixedOperation, name: str) -> Any:
# NOTE: we don't support sampling a list here.
if self._sampled is None: if self._sampled is None:
raise ValueError('Need to call resample() before running forward') raise ValueError('Need to call resample() before running forward')
if name in operation.mutable_arguments: if name in operation.mutable_arguments:
...@@ -257,20 +264,23 @@ class PathSamplingRepeat(BaseSuperNetModule): ...@@ -257,20 +264,23 @@ class PathSamplingRepeat(BaseSuperNetModule):
# Only interesting when depth is mutable # Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice) return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def reduction(self, items: list[Any], sampled: list[Any]):
    """Override this to implement customized reduction."""
    # Default behavior: uniformly-weighted sum of the outputs collected at
    # every sampled depth (``sampled`` is unused here but kept so overrides
    # can reduce based on which depths were chosen).
    return weighted_sum(items)
def forward(self, x): def forward(self, x):
if self._sampled is None: if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.') raise RuntimeError('At least one depth needs to be sampled before fprop.')
if isinstance(self._sampled, list): sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = []
for i, block in enumerate(self.blocks): res = []
x = block(x) for cur_depth, block in enumerate(self.blocks, start=1):
if i in self._sampled: x = block(x)
res.append(x) if cur_depth in sampled:
return sum(res) res.append(x)
else: if not any(d > cur_depth for d in sampled):
for block in self.blocks[:self._sampled]: break
x = block(x) return self.reduction(res, sampled)
return x
class PathSamplingCell(BaseSuperNetModule): class PathSamplingCell(BaseSuperNetModule):
......
...@@ -215,7 +215,7 @@ def _mnist_net(type_, evaluator_kwargs): ...@@ -215,7 +215,7 @@ def _mnist_net(type_, evaluator_kwargs):
base_model = CustomOpValueChoiceNet() base_model = CustomOpValueChoiceNet()
else: else:
raise ValueError(f'Unsupported type: {type_}') raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform) train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
# Multi-GPU combined dataloader will break this subset sampler. Expected though. # Multi-GPU combined dataloader will break this subset sampler. Expected though.
......
...@@ -78,6 +78,46 @@ def test_valuechoice_utils(): ...@@ -78,6 +78,46 @@ def test_valuechoice_utils():
for value, weight in ans.items(): for value, weight in ans.items():
assert abs(weight - weights[value]) < 1e-6 assert abs(weight - weights[value]) < 1e-6
assert evaluate_constant(ValueChoice([3, 4, 6], label='x') - ValueChoice([3, 4, 6], label='x')) == 0
with pytest.raises(ValueError):
evaluate_constant(ValueChoice([3, 4, 6]) - ValueChoice([3, 4, 6]))
assert evaluate_constant(ValueChoice([3, 4, 6], label='x') * 2 / ValueChoice([3, 4, 6], label='x')) == 2
def test_weighted_sum():
    """Exercise ``weighted_sum`` across the payload types it supports."""
    weights = [0.1, 0.2, 0.7]
    items = [1, 2, 3]
    # Plain numbers: 1*0.1 + 2*0.2 + 3*0.7 == 2.6; with weights omitted the
    # items are simply summed.
    assert abs(weighted_sum(items, weights) - 2.6) < 1e-6
    assert weighted_sum(items) == 6

    # Non-numeric items are rejected with a TypeError mentioning 'Unsupported'.
    with pytest.raises(TypeError, match='Unsupported'):
        weighted_sum(['a', 'b', 'c'], weights)

    # An ndarray of scalars also works as the item sequence: 0.2 + 2*0.7 == 1.6.
    assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6

    # Tensors of identical shape are combined elementwise.
    items = [torch.full((2, 3, 5), i) for i in items]
    assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6

    # Mismatched tensor shapes raise a descriptive multi-line error.
    items = [torch.randn(2, 3, i) for i in [1, 2, 3]]
    with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'):
        weighted_sum(items, weights)

    # Tuples are reduced position-by-position...
    items = [(1, 2), (3, 4), (5, 6)]
    res = weighted_sum(items, weights)
    assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6

    # ...but only when all tuples have the same length.
    items = [(1, 2), (3, 4), (5, 6, 7)]
    with pytest.raises(ValueError):
        weighted_sum(items, weights)

    # Dicts are reduced key-by-key, recursing into nested array values.
    items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]]
    res = weighted_sum(items, weights)
    assert res['b'].shape == (2, 3, 5)
    assert abs(res['b'][0][0][0] - res['a']) < 1e-6
    assert abs(res['a'] - 2.6) < 1e-6
def test_pathsampling_valuechoice(): def test_pathsampling_valuechoice():
orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3) orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
...@@ -147,6 +187,26 @@ def test_mixed_conv2d(): ...@@ -147,6 +187,26 @@ def test_mixed_conv2d():
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in')) conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10]) assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
# groups, invalid case
conv = Conv2d(ValueChoice([9, 6, 3], label='in'), ValueChoice([9, 6, 3], label='in'), 1, groups=9)
with pytest.raises(RuntimeError):
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10))
# groups, differentiable
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='out'), 1, groups=ValueChoice([3, 6, 9], label='in'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(ValueError):
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 9], label='groups'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(RuntimeError):
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in') // 3)
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3))
# make sure kernel is sliced correctly # make sure kernel is sliced correctly
conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False) conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy}) conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
...@@ -238,13 +298,18 @@ def test_differentiable_layer_input(): ...@@ -238,13 +298,18 @@ def test_differentiable_layer_input():
assert op.export({})['eee'] in ['a', 'b'] assert op.export({})['eee'] in ['a', 'b']
assert len(list(op.parameters())) == 3 assert len(list(op.parameters())) == 3
with pytest.raises(ValueError):
op = DifferentiableMixedLayer([('a', Linear(2, 3)), ('b', Linear(2, 4))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
op(torch.randn(4, 2))
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2 assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2 assert len(input.export({})['ddd']) == 2
def test_proxyless_layer_input(): def test_proxyless_layer_input():
op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee') op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)),
nn.Softmax(-1), 'eee')
assert op.resample({})['eee'] in ['a', 'b'] assert op.resample({})['eee'] in ['a', 'b']
assert op(torch.randn(4, 2)).size(-1) == 3 assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b'] assert op.export({})['eee'] in ['a', 'b']
...@@ -286,6 +351,31 @@ def test_differentiable_repeat(): ...@@ -286,6 +351,31 @@ def test_differentiable_repeat():
sample = op.export({}) sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1] assert 'ccc' in sample and sample['ccc'] in [0, 1]
class TupleModule(nn.Module):
    """Dummy block whose forward ignores its input and returns a
    (tensor, tensor, dict) triple filled with ``num``."""
    def __init__(self, num):
        super().__init__()
        self.num = num

    def forward(self, *args, **kwargs):
        return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}

class CustomSoftmax(nn.Softmax):
    """Softmax stand-in returning fixed weights so the expected mixture is
    deterministic regardless of the architecture parameters."""
    def forward(self, *args, **kwargs):
        return [0.3, 0.3, 0.4]

op = DifferentiableMixedRepeat(
    [TupleModule(i + 1) for i in range(4)],
    ValueChoice([1, 2, 4], label='ccc'),
    CustomSoftmax(),
    {}
)
op.resample({})
res = op(None)
# Depths [1, 2, 4] produce outputs filled with 1, 2, 4 respectively;
# weights [0.3, 0.3, 0.4] give 0.3*1 + 0.3*2 + 0.4*4 == 2.5 per element,
# applied recursively through tuples and dict values.
assert len(res) == 3
assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
assert res[2]['a'] == 7
assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
def test_pathsampling_cell(): def test_pathsampling_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]: for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
...@@ -363,4 +453,3 @@ def test_differentiable_cell(): ...@@ -363,4 +453,3 @@ def test_differentiable_cell():
else: else:
# no loose-end support for now # no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes]) assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
...@@ -95,7 +95,7 @@ def test_nasbench101(): ...@@ -95,7 +95,7 @@ def test_nasbench101():
def test_nasbench201(): def test_nasbench201():
ss = searchspace.NasBench101() ss = searchspace.NasBench201()
_test_searchspace_on_dataset(ss) _test_searchspace_on_dataset(ss)
......
import logging
import pytest
import numpy as np
import torch
import nni
import nni.retiarii.hub.pytorch as ss
import nni.retiarii.evaluator.pytorch as pl
import nni.retiarii.strategy as stg
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.hub.pytorch.nasnet import NDSStagePathSampling, NDSStageDifferentiable
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageNet
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason='Too slow without CUDA.')
def _hub_factory(alias):
    """Instantiate a search space from the hub by its test alias.

    Fixed aliases map directly to a space constructor; any other alias is an
    NDS-style space name (nasnet / enas / amoeba / pnas / darts) optionally
    suffixed with ``_smalldepth`` / ``_depth`` / ``_width`` / ``_imagenet``
    to tweak the space configuration.
    """
    fixed_spaces = {
        'nasbench101': lambda: ss.NasBench101(),
        'nasbench201': lambda: ss.NasBench201(),
        'mobilenetv3': lambda: ss.MobileNetV3Space(),
        'mobilenetv3_small': lambda: ss.MobileNetV3Space(
            width_multipliers=(0.75, 1, 1.5),
            expand_ratios=(4, 6)
        ),
        'proxylessnas': lambda: ss.ProxylessNAS(),
        'shufflenet': lambda: ss.ShuffleNetSpace(),
        'autoformer': lambda: ss.AutoformerSpace(),
    }
    if alias in fixed_spaces:
        return fixed_spaces[alias]()

    # NDS-style spaces: decode configuration from the alias suffixes.
    # ``_smalldepth`` must be tested before ``_depth``.
    if '_smalldepth' in alias:
        num_cells = (4, 8)
    elif '_depth' in alias:
        num_cells = (8, 12)
    else:
        num_cells = 8
    width = (8, 16) if '_width' in alias else 16
    dataset = 'imagenet' if '_imagenet' in alias else 'cifar'

    nds_spaces = {
        'nasnet': ss.NASNet,
        'enas': ss.ENAS,
        'amoeba': ss.AmoebaNet,
        'pnas': ss.PNAS,
        'darts': ss.DARTS,
    }
    for prefix, space_cls in nds_spaces.items():
        if alias.startswith(prefix):
            return space_cls(width=width, num_cells=num_cells, dataset=dataset)

    raise ValueError(f'Unrecognized space: {alias}')
def _strategy_factory(alias, space_type):
    """Create a one-shot strategy, adding shape-alignment hooks when needed.

    Spaces with a small mutable depth produce stage blocks of differing
    output shapes, so the matching NDSStage mutation hook is attached:
    path-sampling hooks for sampling strategies, differentiable hooks for
    gradient-based ones.
    """
    hooks = []
    if '_smalldepth' in space_type:
        if alias in ('enas', 'random'):
            hooks.append(NDSStagePathSampling.mutate)
        else:
            hooks.append(NDSStageDifferentiable.mutate)

    if alias == 'darts':
        return stg.DARTS(mutation_hooks=hooks)
    if alias == 'gumbel':
        return stg.GumbelDARTS(mutation_hooks=hooks)
    if alias == 'proxyless':
        return stg.Proxyless()
    if alias == 'enas':
        return stg.ENAS(mutation_hooks=hooks, reward_metric_name='val_acc')
    if alias == 'random':
        return stg.RandomOneShot(mutation_hooks=hooks)
    raise ValueError(f'Unrecognized strategy: {alias}')
def _dataset_factory(dataset_type, subset=20):
    """Build ``(train_dataset, valid_dataset)`` for ``cifar10`` or ``imagenet``.

    When ``subset`` is truthy, both datasets are truncated to a random
    sample of that many images — the tests only need a handful of batches.
    """
    if dataset_type == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            normalize,
        ])
        valid_transform = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
        train_dataset = nni.trace(CIFAR10)('../data/cifar10', train=True, transform=train_transform)
        valid_dataset = nni.trace(CIFAR10)('../data/cifar10', train=False, transform=valid_transform)
    elif dataset_type == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_transform = transforms.Compose([
            transforms.RandomResizedCrop(224),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
        ])
        valid_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ])
        # no train data available in tests, so both splits read from 'val'
        train_dataset = nni.trace(ImageNet)('../data/imagenet', split='val', transform=train_transform)
        valid_dataset = nni.trace(ImageNet)('../data/imagenet', split='val', transform=valid_transform)
    else:
        raise ValueError(f'Unsupported dataset type: {dataset_type}')

    if subset:
        train_dataset = Subset(train_dataset, np.random.permutation(len(train_dataset))[:subset])
        valid_dataset = Subset(valid_dataset, np.random.permutation(len(valid_dataset))[:subset])

    return train_dataset, valid_dataset
@pytest.mark.parametrize('space_type', [
    # 'nasbench101',
    'nasbench201',
    'mobilenetv3',
    'mobilenetv3_small',
    'proxylessnas',
    'shufflenet',
    # 'autoformer',
    'nasnet',
    'enas',
    'amoeba',
    'pnas',
    'darts',
    'darts_smalldepth',
    'darts_depth',
    'darts_width',
    'darts_width_smalldepth',
    'darts_width_depth',
    'darts_imagenet',
    'darts_width_smalldepth_imagenet',
    'enas_smalldepth',
    'enas_depth',
    'enas_width',
    'enas_width_smalldepth',
    'enas_width_depth',
    'enas_imagenet',
    'enas_width_smalldepth_imagenet',
    'pnas_width_smalldepth',
    'amoeba_width_smalldepth',
])
@pytest.mark.parametrize('strategy_type', [
    'darts',
    'gumbel',
    'proxyless',
    'enas',
    'random'
])
def test_hub_oneshot(space_type, strategy_type):
    """Smoke-test every (search space, one-shot strategy) combination:
    one epoch on a tiny data subset through the 'oneshot' execution engine."""
    NDS_SPACES = ['amoeba', 'darts', 'pnas', 'enas', 'nasnet']
    # Proxyless cannot handle these spaces at all; skip rather than fail.
    if strategy_type == 'proxyless':
        if 'width' in space_type or 'depth' in space_type or \
                any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3']):
            pytest.skip('The space has used unsupported APIs.')
    if strategy_type in ['darts', 'gumbel'] and space_type == 'mobilenetv3':
        pytest.skip('Skip as it consumes too much memory.')

    model_space = _hub_factory(space_type)

    # Pick the dataset matching the space's expected input resolution.
    dataset_type = 'cifar10'
    if 'imagenet' in space_type or space_type in ['mobilenetv3', 'proxylessnas', 'shufflenet', 'autoformer']:
        dataset_type = 'imagenet'
    # Differentiable strategies on mutable NDS spaces are the heaviest combo;
    # use an even smaller subset to keep the test fast.
    subset_size = 4
    if strategy_type in ['darts', 'gumbel'] and any(space_type.startswith(prefix) for prefix in NDS_SPACES) and '_' in space_type:
        subset_size = 2
    train_dataset, valid_dataset = _dataset_factory(dataset_type, subset=subset_size)
    train_loader = pl.DataLoader(train_dataset, batch_size=2, num_workers=2, shuffle=True)
    valid_loader = pl.DataLoader(valid_dataset, batch_size=2, num_workers=2, shuffle=False)

    evaluator = pl.Classification(
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
        max_epochs=1,
        export_onnx=False,
        gpus=1 if torch.cuda.is_available() else 0,  # 0 for my debug
        logger=False,  # disable logging and checkpoint to avoid too much log
        enable_checkpointing=False,
        enable_model_summary=False
        # profiler='advanced'
    )

    # To test on final model:
    # model = type(model_space).load_searched_model('darts-v2')
    # evaluator.fit(model)

    strategy = _strategy_factory(strategy_type, space_type)
    config = RetiariiExeConfig()
    config.execution_engine = 'oneshot'
    experiment = RetiariiExperiment(model_space, evaluator, strategy=strategy)
    experiment.run(config)
# Saved pytorch-lightning log level, restored in teardown_module.
_original_loglevel = None


def setup_module(module):
    """Quiet pytorch-lightning's INFO logs while this module's tests run."""
    global _original_loglevel
    pl_logger = logging.getLogger("pytorch_lightning")
    _original_loglevel = pl_logger.level
    pl_logger.setLevel(logging.WARNING)


def teardown_module(module):
    """Restore the log level saved by setup_module."""
    logging.getLogger("pytorch_lightning").setLevel(_original_loglevel)
...@@ -50,14 +50,16 @@ def prepare_imagenet_subset(data_dir: Path, imagenet_dir: Path): ...@@ -50,14 +50,16 @@ def prepare_imagenet_subset(data_dir: Path, imagenet_dir: Path):
# Target root dir # Target root dir
subset_dir = data_dir / 'imagenet' subset_dir = data_dir / 'imagenet'
shutil.rmtree(subset_dir, ignore_errors=True) shutil.rmtree(subset_dir, ignore_errors=True)
subset_dir.mkdir(parents=True)
shutil.copyfile(imagenet_dir / 'meta.bin', subset_dir / 'meta.bin')
copied_count = 0 copied_count = 0
for category_id, imgs in images.items(): for category_id, imgs in images.items():
random_state.shuffle(imgs) random_state.shuffle(imgs)
for img in imgs[:len(imgs) // 10]: for img in imgs[:len(imgs) // 10]:
folder_name = Path(img).parent.name folder_name = Path(img).parent.name
file_name = Path(img).name file_name = Path(img).name
(subset_dir / folder_name).mkdir(exist_ok=True, parents=True) (subset_dir / 'val' / folder_name).mkdir(exist_ok=True, parents=True)
shutil.copyfile(img, subset_dir / folder_name / file_name) shutil.copyfile(img, subset_dir / 'val' / folder_name / file_name)
copied_count += 1 copied_count += 1
print(f'Generated a subset of {copied_count} images.') print(f'Generated a subset of {copied_count} images.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment