Unverified Commit 2815fb1f authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Make models in search space hub work with one-shot (#4921)

parent 80beca52
......@@ -21,6 +21,9 @@ import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import model_wrapper
from nni.retiarii.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
......@@ -348,6 +351,100 @@ class CellBuilder:
return cell
class NDSStage(nn.Repeat):
    """This class defines NDSStage, a special type of Repeat, for isinstance check, and shape alignment.

    In NDS, we can't simply use Repeat to stack the blocks,
    because the output shape of each stacked block can be different.
    This is a problem for one-shot strategy because they assume every possible candidate
    should return values of the same shape.

    Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
    to manually align the shapes -- specifically, to transform the first block in each stage.

    This is not required though, when depth is not changing, or the mutable depth causes no problem
    (e.g., when the minimum depth is large enough).

    .. attention::

        Assumption: Loose end is treated as all in ``merge_op`` (the case in one-shot),
        which enforces reduction cell and normal cells in the same stage to have the exact same output shape.
    """

    estimated_out_channels_prev: int
    """Output channels of cells in last stage."""

    estimated_out_channels: int
    """Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""

    downsampling: bool
    """This stage has downsampling"""

    def first_cell_transformation_factory(self) -> Optional[nn.Module]:
        """To make the "previous cell" in first cell's output have the same shape as cells in this stage.

        Returns a transformation module, or ``None`` when the shapes already agree.
        """
        if self.downsampling:
            # Downsampling stage: halve spatial size while changing channels.
            return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
        elif self.estimated_out_channels_prev is not self.estimated_out_channels:
            # Can't use != here, ValueChoice doesn't support
            return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
        return None
class NDSStagePathSampling(PathSamplingRepeat):
    """The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        # Only take over NDSStage instances whose depth is actually mutable.
        if not (isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX)):
            return None
        return cls(
            module.first_cell_transformation_factory(),
            cast(List[nn.Module], module.blocks),
            module.depth_choice
        )

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        """``first_cell_transformation`` aligns the first cell's "prev" output; the rest goes to the base class."""
        super().__init__(*args, **kwargs)
        self.first_cell_transformation = first_cell_transformation

    def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
        transform = self.first_cell_transformation
        if transform is not None and 1 in sampled:
            # items[0] must be the result of the first cell.
            assert len(items[0]) == 2
            # Only apply the transformation on "prev" output.
            items[0] = (transform(items[0][0]), items[0][1])
        return super().reduction(items, sampled)
class NDSStageDifferentiable(DifferentiableMixedRepeat):
    """The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        # Only interesting when depth is mutable
        if not (isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.api.ValueChoiceX)):
            return None
        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(
            module.first_cell_transformation_factory(),
            cast(List[nn.Module], module.blocks),
            module.depth_choice,
            softmax,
            memo
        )

    def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
        """``first_cell_transformation`` aligns the first cell's "prev" output; the rest goes to the base class."""
        super().__init__(*args, **kwargs)
        self.first_cell_transformation = first_cell_transformation

    def reduction(
        self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        transform = self.first_cell_transformation
        if transform is not None and 1 in depths:
            # Same as NDSStagePathSampling: align the first cell's "prev" output.
            assert len(items[0]) == 2
            items[0] = (transform(items[0][0]), items[0][1])
        return super().reduction(items, weights, depths)
_INIT_PARAMETER_DOCS = """
Parameters
......@@ -437,6 +534,8 @@ class NDS(nn.Module):
C_pprev = C_prev = 3 * C
C_curr = C
last_cell_reduce = False
else:
raise ValueError(f'Unsupported dataset: {dataset}')
self.stages = nn.ModuleList()
for stage_idx in range(3):
......@@ -448,9 +547,19 @@ class NDS(nn.Module):
# C_out is usually `C * num_nodes_per_cell` because of concat operator.
cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
merge_op, stage_idx > 0, last_cell_reduce)
stage = nn.Repeat(cell_builder, num_cells_per_stage[stage_idx])
stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
if isinstance(stage, NDSStage):
stage.estimated_out_channels_prev = cast(int, C_prev)
stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
stage.downsampling = stage_idx > 0
self.stages.append(stage)
# NOTE: output_node_indices will be computed on-the-fly in trial code.
# When constructing model space, it's just all the nodes in the cell,
# which happens to be the case of one-shot supernet.
# C_pprev is output channel number of last second cell among all the cells already built.
if len(stage) > 1:
# Contains more than one cell
......
......@@ -98,7 +98,6 @@ class ConvBNReLU(nn.Sequential):
]
super().__init__(*simplify_sequential(blocks))
self.out_channels = out_channels
class DepthwiseSeparableConv(nn.Sequential):
......@@ -133,7 +132,8 @@ class DepthwiseSeparableConv(nn.Sequential):
ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
]
super().__init__(*simplify_sequential(blocks))
self.has_skip = stride == 1 and in_channels == out_channels
# NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
self.has_skip = stride == 1 and in_channels is out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip:
......@@ -177,8 +177,8 @@ class InvertedResidual(nn.Sequential):
hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))
# NOTE: this equivalence check should also work for ValueChoice
self.has_skip = stride == 1 and in_channels == out_channels
# NOTE: this equivalence check (==) does NOT work for ValueChoice, need to use "is"
self.has_skip = stride == 1 and in_channels is out_channels
layers: List[nn.Module] = [
# point-wise convolution
......
......@@ -7,4 +7,3 @@ from .proxyless import ProxylessTrainer
from .random import SinglePathTrainer, RandomTrainer
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
......@@ -60,7 +60,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
)
__doc__ = _darts_note.format(
module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
......@@ -161,7 +161,7 @@ class ProxylessLightningModule(DartsLightningModule):
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
......
......@@ -115,7 +115,7 @@ def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice,
# this saves an op on computational graph, which will hopefully make training faster
# Use a dummy array to check this. Otherwise it would be too complex.
dummy_arr = np.zeros(weight.shape, dtype=np.bool) # type: ignore
dummy_arr = np.zeros(weight.shape, dtype=bool) # type: ignore
no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
if no_effect:
......
......@@ -7,7 +7,7 @@ in the way that is most convenient to one-shot algorithms."""
from __future__ import annotations
import itertools
from typing import Any, TypeVar, List, cast
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
import numpy as np
import torch
......@@ -20,7 +20,13 @@ Choice = Any
T = TypeVar('T')
__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']
__all__ = [
'dedup_inner_choices',
'evaluate_value_choice_with_dict',
'traverse_all_options',
'weighted_sum',
'evaluate_constant',
]
def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
......@@ -138,3 +144,101 @@ def traverse_all_options(
return sorted(result.keys()) # type: ignore
else:
return sorted(result.items()) # type: ignore
def evaluate_constant(expr: Any) -> Any:
    """Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
    # A constant expression has exactly one possible value.
    options = traverse_all_options(expr)
    if len(options) > 1:
        raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {options}')
    return options[0]
def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
"""Return a weighted sum of items.
Items can be list of tensors, numpy arrays, or nested lists / dicts.
If ``weights`` is None, this is simply an unweighted sum.
"""
if weights is None:
weights = [None] * len(items)
assert len(items) == len(weights) > 0
elem = items[0]
unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
if isinstance(elem, str):
# Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
raise TypeError(unsupported_msg)
try:
if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
if weights[0] is None:
res = elem
else:
res = elem * weights[0]
for it, weight in zip(items[1:], weights[1:]):
if type(it) != type(elem):
raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
if weight is None:
res = res + it # type: ignore
else:
res = res + it * weight # type: ignore
return cast(T, res)
if isinstance(elem, Mapping):
for item in items:
if not isinstance(item, Mapping):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if set(item) != set(elem):
raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
return cast(T, {
key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
})
if isinstance(elem, Sequence):
for item in items:
if not isinstance(item, Sequence):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if len(item) != len(elem):
raise ValueError(f'Expect length {len(item)} but found {len(elem)}')
transposed = cast(Iterable[list], zip(*items)) # type: ignore
return cast(T, [weighted_sum(column, weights) for column in transposed])
except (TypeError, ValueError, RuntimeError, KeyError):
raise ValueError(
'Error when summing items. Value format / shape does not match. See full traceback for details.' +
''.join([
f'\n {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
])
)
# Dealing with all unexpected types.
raise TypeError(unsupported_msg)
def _summarize_elem_format(elem: Any) -> Any:
# Get a summary of one elem
# Helps generate human-readable error messages
class _repr_object:
# empty object is only repr
def __init__(self, representation):
self.representation = representation
def __repr__(self):
return self.representation
if isinstance(elem, torch.Tensor):
return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, np.ndarray):
return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, Mapping):
return {key: _summarize_elem_format(value) for key, value in elem.items()}
if isinstance(elem, Sequence):
return [_summarize_elem_format(value) for value in elem]
# fallback to original, for cases like float, int, ...
return elem
......@@ -21,14 +21,14 @@ from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
_logger = logging.getLogger(__name__)
__all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy'
'MixedOpDifferentiablePolicy',
]
......@@ -77,7 +77,11 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
def __init__(self,
paths: list[tuple[str, nn.Module]],
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.op_names = []
if len(alpha) != len(paths):
......@@ -118,11 +122,15 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, *args, **kwargs):
"""The forward of mixed layer accepts same arguments as its sub-layer."""
op_results = torch.stack([getattr(self, op)(*args, **kwargs) for op in self.op_names])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
return self.reduction(all_op_results, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
......@@ -167,7 +175,12 @@ class DifferentiableMixedInput(BaseSuperNetModule):
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
def __init__(self,
n_candidates: int,
n_chosen: int | None,
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.n_candidates = n_candidates
if len(alpha) != n_candidates:
......@@ -217,11 +230,14 @@ class DifferentiableMixedInput(BaseSuperNetModule):
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, inputs):
"""Forward takes a list of input candidates."""
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self._softmax(self._arch_alpha).view(*alpha_shape), 0)
return self.reduction(inputs, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
......@@ -318,11 +334,18 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a differentiable supernet.
Result is a weighted sum of possible prefixes, sliced by possible depths.
If the output is not a single tensor, it will be summed at every independent dimension.
See :func:`weighted_sum` for details.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int], softmax: nn.Module, memo: dict[str, Any]):
def __init__(self,
blocks: list[nn.Module],
depth: ChoiceOf[int],
softmax: nn.Module,
memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
......@@ -377,21 +400,28 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
if not arch:
yield name, p
def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
    """Merge the outputs collected at each candidate depth into one result.

    ``depths`` is unused in this base implementation; it is part of the
    signature so that subclasses can perform depth-aware reductions.
    Override this for customized reduction.
    """
    # Use weighted_sum to handle complex cases where sequential output is not a single tensor
    return weighted_sum(items, weights)
def forward(self, x):
    """Forward through all blocks, collecting the output at every candidate depth.

    The collected outputs, their (softmaxed) architecture weights and the
    corresponding depths are merged by :meth:`reduction`.
    """
    weights: dict[str, torch.Tensor] = {
        label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
    }
    depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))

    # NOTE: the flattened diff retained a superseded accumulator-style
    # implementation here whose in-loop `return` made the code below dead;
    # only the list-collecting version is kept.
    res: list[torch.Tensor] = []
    weight_list: list[float] = []
    depths: list[int] = []
    for i, block in enumerate(self.blocks, start=1):  # start=1 because depths are 1, 2, 3, 4...
        x = block(x)
        if i in depth_weights:
            weight_list.append(depth_weights[i])
            res.append(x)
            depths.append(i)

    return self.reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell):
......
......@@ -10,7 +10,8 @@ from __future__ import annotations
import inspect
import itertools
from typing import Any, Type, TypeVar, cast, Union, Tuple
import warnings
from typing import Any, Type, TypeVar, cast, Union, Tuple, List
import torch
import torch.nn as nn
......@@ -23,7 +24,7 @@ from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
T = TypeVar('T')
......@@ -268,14 +269,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
- ``in_channels``
- ``out_channels``
- ``groups`` (only supported in path sampling)
- ``groups``
- ``stride`` (only supported in path sampling)
- ``kernel_size``
- ``padding`` (only supported in path sampling)
- ``padding``
- ``dilation`` (only supported in path sampling)
``padding`` will be the "max" padding in differentiable mode.
Mutable ``groups`` is NOT supported in most cases of differentiable mode.
However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
This is often the case of depth-wise convolutions.
For channels, prefix will be sliced.
For kernels, we take the small kernel from the center and round it to floor (left top). For example ::
......@@ -315,6 +320,18 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
return max(all_sizes)
elif name == 'groups':
if 'in_channels' in self.mutable_arguments:
# If the ratio is constant, we don't need to try the maximum groups.
try:
constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
except ValueError:
warnings.warn(
'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
RuntimeWarning
)
# minimum groups, maximum kernel
return min(traverse_all_options(value_choice))
......@@ -328,11 +345,11 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
stride: _int_or_tuple,
padding: scalar_or_scalar_dict[_int_or_tuple],
dilation: int,
groups: int,
groups: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor:
if any(isinstance(arg, dict) for arg in [stride, dilation, groups]):
raise ValueError(_diff_not_compatible_error.format('stride, dilation and groups', 'Conv2d'))
if any(isinstance(arg, dict) for arg in [stride, dilation]):
raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
in_channels_ = _W(in_channels)
out_channels_ = _W(out_channels)
......@@ -340,7 +357,32 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
# slice prefix
# For groups > 1, we use groups to slice input weights
weight = _S(self.weight)[:out_channels_]
weight = _S(weight)[:, :in_channels_ // groups]
if not isinstance(groups, dict):
weight = _S(weight)[:, :in_channels_ // groups]
else:
assert 'groups' in self.mutable_arguments
err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
'in_channels and out_channels should also be a ValueChoice. ' \
'Also, the ratios of in_channels divided by groups, and out_channels divided by groups ' \
'should be constants.'
if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
raise ValueError(err_message)
try:
in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
except ValueError:
raise ValueError(err_message)
if in_channels_per_group != int(in_channels_per_group):
raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
if inputs.size(1) % in_channels_per_group != 0:
raise RuntimeError(
f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
f'while in_channels_per_group = {in_channels_per_group}'
)
# Compute sliced weights and groups (as an integer)
weight = _S(weight)[:, :int(in_channels_per_group)]
groups = inputs.size(1) // int(in_channels_per_group)
# slice center
if isinstance(kernel_size, dict):
......
......@@ -16,7 +16,7 @@ from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation
__all__ = [
......@@ -72,6 +72,10 @@ class PathSamplingLayer(BaseSuperNetModule):
if isinstance(module, LayerChoice):
return cls(list(module.named_children()), module.label)
def reduction(self, items: list[Any], sampled: list[Any]):
    """Merge the outputs of the sampled paths (unweighted sum).

    ``sampled`` is ignored in this base implementation; it is part of the
    signature so that subclasses can perform sample-aware reductions.
    Override this to implement customized reduction.
    """
    return weighted_sum(items)
def forward(self, *args, **kwargs):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
......@@ -79,10 +83,7 @@ class PathSamplingLayer(BaseSuperNetModule):
# str(samp) is needed here because samp can sometimes be integers, but attr are always str
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
if len(res) == 1:
return res[0]
else:
return sum(res)
return self.reduction(res, sampled)
class PathSamplingInput(BaseSuperNetModule):
......@@ -95,11 +96,11 @@ class PathSamplingInput(BaseSuperNetModule):
Sampled input indices.
"""
def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
    """Path-sampled implementation of InputChoice.

    Parameters
    ----------
    n_candidates
        Number of input candidates to choose from.
    n_chosen
        Number of inputs chosen in each sample.
    reduction_type
        How the chosen inputs are merged: ``sum``, ``mean`` or ``concat``.
        (Named ``reduction_type`` so it does not clash with the ``reduction`` method.)
    label
        Mutation label.
    """
    # NOTE: the flattened diff retained the old signature line and the old
    # `self.reduction = reduction` assignment, which referenced a removed
    # parameter and shadowed the `reduction` method; only the new version is kept.
    super().__init__()
    self.n_candidates = n_candidates
    self.n_chosen = n_chosen
    self.reduction_type = reduction_type
    # Sampled input indices; None until resample() is called.
    self._sampled: list[int] | int | None = None
    self.label = label
......@@ -144,6 +145,19 @@ class PathSamplingInput(BaseSuperNetModule):
raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
    """Merge the chosen inputs according to ``self.reduction_type``.

    A single item is returned as-is. Override this to implement customized reduction.
    """
    if len(items) == 1:
        return items[0]
    if self.reduction_type == 'sum':
        return sum(items)
    if self.reduction_type == 'mean':
        return sum(items) / len(items)
    if self.reduction_type == 'concat':
        # Concatenate along the channel dimension.
        return torch.cat(items, 1)
    raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
def forward(self, input_tensors):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
......@@ -151,15 +165,7 @@ class PathSamplingInput(BaseSuperNetModule):
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = [input_tensors[samp] for samp in sampled]
if len(res) == 1:
return res[0]
else:
if self.reduction == 'sum':
return sum(res)
elif self.reduction == 'mean':
return sum(res) / len(res)
elif self.reduction == 'concat':
return torch.cat(res, 1)
return self.reduction(res, sampled)
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
......@@ -202,6 +208,7 @@ class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
return result
def forward_argument(self, operation: MixedOperation, name: str) -> Any:
# NOTE: we don't support sampling a list here.
if self._sampled is None:
raise ValueError('Need to call resample() before running forward')
if name in operation.mutable_arguments:
......@@ -257,20 +264,23 @@ class PathSamplingRepeat(BaseSuperNetModule):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def reduction(self, items: list[Any], sampled: list[Any]):
    """Merge the outputs collected at each sampled depth (unweighted sum).

    ``sampled`` is ignored in this base implementation; it is part of the
    signature so that subclasses can perform sample-aware reductions.
    Override this to implement customized reduction.
    """
    return weighted_sum(items)
def forward(self, x):
    """Forward through the blocks up to the deepest sampled depth.

    The output at every sampled depth is collected and merged by :meth:`reduction`.

    Raises
    ------
    RuntimeError
        If :meth:`resample` has not been called yet.
    """
    if self._sampled is None:
        raise RuntimeError('At least one depth needs to be sampled before fprop.')
    # Normalize to a list so single-depth and multi-depth sampling share one code path.
    sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled

    # NOTE: the flattened diff retained a superseded implementation here
    # (branching on `isinstance(self._sampled, list)`) that returned before
    # this code could run; only the unified version is kept.
    res = []
    for cur_depth, block in enumerate(self.blocks, start=1):
        x = block(x)
        if cur_depth in sampled:
            res.append(x)
        # Stop early once every sampled depth has been reached.
        if not any(d > cur_depth for d in sampled):
            break
    return self.reduction(res, sampled)
class PathSamplingCell(BaseSuperNetModule):
......
......@@ -215,7 +215,7 @@ def _mnist_net(type_, evaluator_kwargs):
base_model = CustomOpValueChoiceNet()
else:
raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = nni.trace(MNIST)('data/mnist', train=True, download=True, transform=transform)
# Multi-GPU combined dataloader will break this subset sampler. Expected though.
......
......@@ -78,6 +78,46 @@ def test_valuechoice_utils():
for value, weight in ans.items():
assert abs(weight - weights[value]) < 1e-6
assert evaluate_constant(ValueChoice([3, 4, 6], label='x') - ValueChoice([3, 4, 6], label='x')) == 0
with pytest.raises(ValueError):
evaluate_constant(ValueChoice([3, 4, 6]) - ValueChoice([3, 4, 6]))
assert evaluate_constant(ValueChoice([3, 4, 6], label='x') * 2 / ValueChoice([3, 4, 6], label='x')) == 2
def test_weighted_sum():
weights = [0.1, 0.2, 0.7]
items = [1, 2, 3]
assert abs(weighted_sum(items, weights) - 2.6) < 1e-6
assert weighted_sum(items) == 6
with pytest.raises(TypeError, match='Unsupported'):
weighted_sum(['a', 'b', 'c'], weights)
assert abs(weighted_sum(np.arange(3), weights).item() - 1.6) < 1e-6
items = [torch.full((2, 3, 5), i) for i in items]
assert abs(weighted_sum(items, weights).flatten()[0].item() - 2.6) < 1e-6
items = [torch.randn(2, 3, i) for i in [1, 2, 3]]
with pytest.raises(ValueError, match=r'does not match.*\n.*torch\.Tensor\(2, 3, 1\)'):
weighted_sum(items, weights)
items = [(1, 2), (3, 4), (5, 6)]
res = weighted_sum(items, weights)
assert len(res) == 2 and abs(res[0] - 4.2) < 1e-6 and abs(res[1] - 5.2) < 1e-6
items = [(1, 2), (3, 4), (5, 6, 7)]
with pytest.raises(ValueError):
weighted_sum(items, weights)
items = [{"a": i, "b": np.full((2, 3, 5), i)} for i in [1, 2, 3]]
res = weighted_sum(items, weights)
assert res['b'].shape == (2, 3, 5)
assert abs(res['b'][0][0][0] - res['a']) < 1e-6
assert abs(res['a'] - 2.6) < 1e-6
def test_pathsampling_valuechoice():
orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
......@@ -147,6 +187,26 @@ def test_mixed_conv2d():
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])
# groups, invalid case
conv = Conv2d(ValueChoice([9, 6, 3], label='in'), ValueChoice([9, 6, 3], label='in'), 1, groups=9)
with pytest.raises(RuntimeError):
assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10))
# groups, differentiable
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='out'), 1, groups=ValueChoice([3, 6, 9], label='in'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(ValueError):
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 9], label='groups'))
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))
with pytest.raises(RuntimeError):
conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in') // 3)
_mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 10, 3, 3))
# make sure kernel is sliced correctly
conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
......@@ -238,13 +298,18 @@ def test_differentiable_layer_input():
assert op.export({})['eee'] in ['a', 'b']
assert len(list(op.parameters())) == 3
with pytest.raises(ValueError):
op = DifferentiableMixedLayer([('a', Linear(2, 3)), ('b', Linear(2, 4))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
op(torch.randn(4, 2))
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2
def test_proxyless_layer_input():
op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)),
nn.Softmax(-1), 'eee')
assert op.resample({})['eee'] in ['a', 'b']
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
......@@ -286,6 +351,31 @@ def test_differentiable_repeat():
sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1]
class TupleModule(nn.Module):
def __init__(self, num):
super().__init__()
self.num = num
def forward(self, *args, **kwargs):
return torch.full((2, 3), self.num), torch.full((3, 5), self.num), {'a': 7, 'b': [self.num] * 11}
class CustomSoftmax(nn.Softmax):
def forward(self, *args, **kwargs):
return [0.3, 0.3, 0.4]
op = DifferentiableMixedRepeat(
[TupleModule(i + 1) for i in range(4)],
ValueChoice([1, 2, 4], label='ccc'),
CustomSoftmax(),
{}
)
op.resample({})
res = op(None)
assert len(res) == 3
assert res[0].shape == (2, 3) and res[0][0][0].item() == 2.5
assert res[2]['a'] == 7
assert len(res[2]['b']) == 11 and res[2]['b'][-1] == 2.5
def test_pathsampling_cell():
for cell_cls in [CellSimple, CellDefaultArgs, CellCustomProcessor, CellLooseEnd, CellOpFactory]:
......@@ -363,4 +453,3 @@ def test_differentiable_cell():
else:
# no loose-end support for now
assert output.shape == torch.Size([2, 16 * model.cell.num_nodes])
......@@ -95,7 +95,7 @@ def test_nasbench101():
def test_nasbench201():
    """NAS-Bench-201 search space should run end-to-end on the standard test datasets."""
    # Fix: the old code first built a NasBench101 space and immediately
    # discarded it (dead assignment left over from copy-paste), wasting the
    # construction cost of an entire search space.
    ss = searchspace.NasBench201()
    _test_searchspace_on_dataset(ss)
......
import logging
import pytest
import numpy as np
import torch
import nni
import nni.retiarii.hub.pytorch as ss
import nni.retiarii.evaluator.pytorch as pl
import nni.retiarii.strategy as stg
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.hub.pytorch.nasnet import NDSStagePathSampling, NDSStageDifferentiable
from torch.utils.data import Subset
from torchvision import transforms
from torchvision.datasets import CIFAR10, ImageNet
# Module-wide marker: every test here trains a (tiny) supernet, which is
# impractically slow on CPU, so the whole module is skipped without CUDA.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason='Too slow without CUDA.')
def _hub_factory(alias):
    """Instantiate a search space from the hub given its string alias.

    Exact aliases map straight to a constructor; NDS-style aliases
    (``nasnet``/``enas``/``amoeba``/``pnas``/``darts``) may carry
    ``_smalldepth``/``_depth``/``_width``/``_imagenet`` suffixes that tune
    the space's depth, width and target dataset.
    """
    exact_aliases = {
        'nasbench101': lambda: ss.NasBench101(),
        'nasbench201': lambda: ss.NasBench201(),
        'mobilenetv3': lambda: ss.MobileNetV3Space(),
        'mobilenetv3_small': lambda: ss.MobileNetV3Space(
            width_multipliers=(0.75, 1, 1.5),
            expand_ratios=(4, 6)
        ),
        'proxylessnas': lambda: ss.ProxylessNAS(),
        'shufflenet': lambda: ss.ShuffleNetSpace(),
        'autoformer': lambda: ss.AutoformerSpace(),
    }
    if alias in exact_aliases:
        return exact_aliases[alias]()

    # Decode the suffix-encoded configuration for NDS-style spaces.
    # NOTE: '_smalldepth' must be tested before '_depth' — although with the
    # current aliases '_depth' never matches inside '_smalldepth', the order
    # keeps the intent explicit.
    if '_smalldepth' in alias:
        num_cells = (4, 8)
    elif '_depth' in alias:
        num_cells = (8, 12)
    else:
        num_cells = 8
    width = (8, 16) if '_width' in alias else 16
    dataset = 'imagenet' if '_imagenet' in alias else 'cifar'

    nds_spaces = {
        'nasnet': ss.NASNet,
        'enas': ss.ENAS,
        'amoeba': ss.AmoebaNet,
        'pnas': ss.PNAS,
        'darts': ss.DARTS,
    }
    for prefix, space_cls in nds_spaces.items():
        if alias.startswith(prefix):
            return space_cls(width=width, num_cells=num_cells, dataset=dataset)
    raise ValueError(f'Unrecognized space: {alias}')
def _strategy_factory(alias, space_type):
    """Build a one-shot strategy by alias, adding NDS shape-alignment hooks when needed.

    Spaces with a small mutable depth ('_smalldepth') produce candidate blocks
    of differing output shapes; the ``NDSStage*`` mutation hooks realign them
    so one-shot supernets can stack them.
    """
    hooks = []
    if '_smalldepth' in space_type:
        # Sampling-based strategies get the path-sampling variant; the
        # gradient-based ones get the differentiable variant.
        if alias in ('enas', 'random'):
            hooks.append(NDSStagePathSampling.mutate)
        else:
            hooks.append(NDSStageDifferentiable.mutate)

    if alias == 'darts':
        return stg.DARTS(mutation_hooks=hooks)
    if alias == 'gumbel':
        return stg.GumbelDARTS(mutation_hooks=hooks)
    if alias == 'proxyless':
        return stg.Proxyless()
    if alias == 'enas':
        return stg.ENAS(mutation_hooks=hooks, reward_metric_name='val_acc')
    if alias == 'random':
        return stg.RandomOneShot(mutation_hooks=hooks)
    raise ValueError(f'Unrecognized strategy: {alias}')
def _dataset_factory(dataset_type, subset=20):
    """Return ``(train, valid)`` datasets for *dataset_type* ('cifar10' or 'imagenet').

    When *subset* is truthy, each split is shrunk to a random sample of that
    many examples so the tests stay fast.
    """
    if dataset_type == 'cifar10':
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
        train_set = nni.trace(CIFAR10)(
            '../data/cifar10',
            train=True,
            transform=transforms.Compose([
                transforms.RandomHorizontalFlip(),
                transforms.RandomCrop(32, 4),
                transforms.ToTensor(),
                normalize,
            ]))
        valid_set = nni.trace(CIFAR10)(
            '../data/cifar10',
            train=False,
            transform=transforms.Compose([transforms.ToTensor(), normalize]))
    elif dataset_type == 'imagenet':
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        train_set = nni.trace(ImageNet)(
            '../data/imagenet',
            split='val',  # no train data available in tests
            transform=transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))
        valid_set = nni.trace(ImageNet)(
            '../data/imagenet',
            split='val',
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ]))
    else:
        raise ValueError(f'Unsupported dataset type: {dataset_type}')

    if subset:
        train_set = Subset(train_set, np.random.permutation(len(train_set))[:subset])
        valid_set = Subset(valid_set, np.random.permutation(len(valid_set))[:subset])
    return train_set, valid_set
@pytest.mark.parametrize('space_type', [
    # 'nasbench101',
    'nasbench201',
    'mobilenetv3',
    'mobilenetv3_small',
    'proxylessnas',
    'shufflenet',
    # 'autoformer',
    'nasnet',
    'enas',
    'amoeba',
    'pnas',
    'darts',
    'darts_smalldepth',
    'darts_depth',
    'darts_width',
    'darts_width_smalldepth',
    'darts_width_depth',
    'darts_imagenet',
    'darts_width_smalldepth_imagenet',
    'enas_smalldepth',
    'enas_depth',
    'enas_width',
    'enas_width_smalldepth',
    'enas_width_depth',
    'enas_imagenet',
    'enas_width_smalldepth_imagenet',
    'pnas_width_smalldepth',
    'amoeba_width_smalldepth',
])
@pytest.mark.parametrize('strategy_type', [
    'darts',
    'gumbel',
    'proxyless',
    'enas',
    'random'
])
def test_hub_oneshot(space_type, strategy_type):
    """Smoke-test: every (space, strategy) pair trains for one epoch on a tiny data subset."""
    nds_prefixes = ('amoeba', 'darts', 'pnas', 'enas', 'nasnet')
    is_nds = any(space_type.startswith(p) for p in nds_prefixes)

    # Proxyless does not support value-choice widths/depths nor these spaces' APIs.
    if strategy_type == 'proxyless':
        unsupported = (
            'width' in space_type
            or 'depth' in space_type
            or is_nds
            or space_type.startswith(('proxylessnas', 'mobilenetv3'))
        )
        if unsupported:
            pytest.skip('The space has used unsupported APIs.')
    if space_type == 'mobilenetv3' and strategy_type in ('darts', 'gumbel'):
        pytest.skip('Skip as it consumes too much memory.')

    model_space = _hub_factory(space_type)

    imagenet_like = 'imagenet' in space_type or \
        space_type in ['mobilenetv3', 'proxylessnas', 'shufflenet', 'autoformer']
    dataset_type = 'imagenet' if imagenet_like else 'cifar10'

    # Differentiable strategies on configurable NDS spaces are the heaviest
    # combination; use an even smaller subset there.
    heavy_combo = strategy_type in ('darts', 'gumbel') and is_nds and '_' in space_type
    subset_size = 2 if heavy_combo else 4

    train_dataset, valid_dataset = _dataset_factory(dataset_type, subset=subset_size)
    train_loader = pl.DataLoader(train_dataset, batch_size=2, num_workers=2, shuffle=True)
    valid_loader = pl.DataLoader(valid_dataset, batch_size=2, num_workers=2, shuffle=False)

    evaluator = pl.Classification(
        train_dataloaders=train_loader,
        val_dataloaders=valid_loader,
        max_epochs=1,
        export_onnx=False,
        gpus=1 if torch.cuda.is_available() else 0,
        # Disable logging and checkpointing to keep test output quiet.
        logger=False,
        enable_checkpointing=False,
        enable_model_summary=False,
    )

    # To test on final model:
    # model = type(model_space).load_searched_model('darts-v2')
    # evaluator.fit(model)

    strategy = _strategy_factory(strategy_type, space_type)
    config = RetiariiExeConfig()
    config.execution_engine = 'oneshot'
    experiment = RetiariiExperiment(model_space, evaluator, strategy=strategy)
    experiment.run(config)
# Saved pytorch-lightning log level, restored after the module's tests finish.
_original_loglevel = None


def setup_module(module):
    """Quiet pytorch-lightning logging for this module, remembering the old level."""
    global _original_loglevel
    pl_logger = logging.getLogger("pytorch_lightning")
    _original_loglevel = pl_logger.level
    pl_logger.setLevel(logging.WARNING)


def teardown_module(module):
    """Restore the pytorch-lightning log level saved by :func:`setup_module`."""
    logging.getLogger("pytorch_lightning").setLevel(_original_loglevel)
......@@ -50,14 +50,16 @@ def prepare_imagenet_subset(data_dir: Path, imagenet_dir: Path):
# Target root dir
subset_dir = data_dir / 'imagenet'
shutil.rmtree(subset_dir, ignore_errors=True)
subset_dir.mkdir(parents=True)
shutil.copyfile(imagenet_dir / 'meta.bin', subset_dir / 'meta.bin')
copied_count = 0
for category_id, imgs in images.items():
random_state.shuffle(imgs)
for img in imgs[:len(imgs) // 10]:
folder_name = Path(img).parent.name
file_name = Path(img).name
(subset_dir / folder_name).mkdir(exist_ok=True, parents=True)
shutil.copyfile(img, subset_dir / folder_name / file_name)
(subset_dir / 'val' / folder_name).mkdir(exist_ok=True, parents=True)
shutil.copyfile(img, subset_dir / 'val' / folder_name / file_name)
copied_count += 1
print(f'Generated a subset of {copied_count} images.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment