Unverified Commit 481aa292 authored by Maze's avatar Maze Committed by GitHub
Browse files

Fix Autoformer to be compatible with RandomOneShot strategy (#4987)

parent 5a3d82e8
This diff is collapsed.
...@@ -37,6 +37,11 @@ PRETRAINED_WEIGHT_URLS = { ...@@ -37,6 +37,11 @@ PRETRAINED_WEIGHT_URLS = {
# spos # spos
'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth', 'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
# autoformer
'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
} }
......
...@@ -140,7 +140,7 @@ class Slicable(Generic[T]): ...@@ -140,7 +140,7 @@ class Slicable(Generic[T]):
raise TypeError(f'Unsuppoted weight type: {type(weight)}') raise TypeError(f'Unsuppoted weight type: {type(weight)}')
self.weight = weight self.weight = weight
def __getitem__(self, index: slice_type | multidim_slice) -> T: def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
if not isinstance(index, tuple): if not isinstance(index, tuple):
index = (index, ) index = (index, )
index = cast(multidim_slice, index) index = cast(multidim_slice, index)
...@@ -267,7 +267,7 @@ def _iterate_over_slice_type(s: slice_type): ...@@ -267,7 +267,7 @@ def _iterate_over_slice_type(s: slice_type):
def _iterate_over_multidim_slice(ms: multidim_slice): def _iterate_over_multidim_slice(ms: multidim_slice):
"""Get :class:`MaybeWeighted` instances in ``ms``.""" """Get :class:`MaybeWeighted` instances in ``ms``."""
for s in ms: for s in ms:
if s is not None: if s is not None and s is not Ellipsis:
yield from _iterate_over_slice_type(s) yield from _iterate_over_slice_type(s)
...@@ -286,8 +286,8 @@ def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None ...@@ -286,8 +286,8 @@ def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None
"""Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``.""" """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
res = [] res = []
for s in ms: for s in ms:
if s is not None: if s is not None and s is not Ellipsis:
res.append(_evaluate_slice_type(s, value_fn)) res.append(_evaluate_slice_type(s, value_fn))
else: else:
res.append(None) res.append(s)
return tuple(res) return tuple(res)
...@@ -35,6 +35,7 @@ __all__ = [ ...@@ -35,6 +35,7 @@ __all__ = [
'MixedLinear', 'MixedLinear',
'MixedConv2d', 'MixedConv2d',
'MixedBatchNorm2d', 'MixedBatchNorm2d',
'MixedLayerNorm',
'MixedMultiHeadAttention', 'MixedMultiHeadAttention',
'NATIVE_MIXED_OPERATIONS', 'NATIVE_MIXED_OPERATIONS',
] ]
...@@ -472,6 +473,74 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d): ...@@ -472,6 +473,74 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
eps, eps,
) )
class MixedLayerNorm(MixedOperation, nn.LayerNorm):
    """
    Mixed LayerNorm operation.

    Supported arguments are:

    - ``normalized_shape``
    - ``eps`` (only supported in path sampling)

    For path-sampling, prefix of ``weight`` and ``bias`` are sliced.
    For weighted cases, the maximum ``normalized_shape`` is used directly.

    eps is required to be float.
    """

    bound_type = retiarii_nn.LayerNorm
    argument_list = ['normalized_shape', 'eps']

    @staticmethod
    def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
        """Return ``value`` unchanged if it is a tuple, otherwise duplicate it into a pair."""
        if not isinstance(value, tuple):
            return (value, value)
        return value

    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        """
        Compute the value handed to the parent ``__init__`` for a searchable argument.

        Only ``normalized_shape`` may be a value choice. The result is the largest
        candidate: a single maximum for scalar candidates, or a tuple holding the
        per-dimension maximum when candidates are tuples/lists.

        Raises
        ------
        NotImplementedError
            If ``name`` is anything other than ``normalized_shape``.
        """
        if name not in ['normalized_shape']:
            raise NotImplementedError(f'Unsupported value choice on argument: {name}')

        # A list (not a set) is used so that list-valued candidates, which are
        # unhashable, are accepted; de-duplication is not needed for ``max``.
        all_sizes = list(traverse_all_options(value_choice))
        if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
            # Transpose the candidates so the maximum is taken on every dimension.
            # A tuple is returned (rather than a generator) so the value can be
            # read more than once.
            return tuple(max(dim_sizes) for dim_sizes in zip(*all_sizes))
        return max(all_sizes)

    def forward_with_args(self,
                          normalized_shape,
                          eps: float,
                          inputs: torch.Tensor) -> torch.Tensor:
        """
        Apply layer normalization with the (possibly sampled) arguments.

        Parameters
        ----------
        normalized_shape
            The shape to normalize over: an int, a sequence of ints, or a dict
            (weighted case), in which case the module's own ``normalized_shape``
            is used.
        eps
            Must not be a dict; weighted ``eps`` is rejected.
        inputs
            Tensor passed to ``F.layer_norm``.

        Raises
        ------
        ValueError
            If ``eps`` is a dict.
        """
        if isinstance(eps, dict):
            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))

        # Shape the parameters were allocated with, always handled as a tuple.
        full_shape = self.normalized_shape
        if isinstance(full_shape, int):
            full_shape = (full_shape, )

        if isinstance(normalized_shape, dict):
            # Weighted case: fall back to the full shape.
            normalized_shape = full_shape
        elif isinstance(normalized_shape, int):
            normalized_shape = (normalized_shape, )

        # Keep the leading prefix of every normalized dimension. The index must be
        # a tuple: indexing a tensor with a list of slices is deprecated.
        indices = tuple(slice(0, min(i, j)) for i, j in zip(normalized_shape, full_shape))

        weight = self.weight[indices] if self.weight is not None else None
        bias = self.bias[indices] if self.bias is not None else None

        return F.layer_norm(
            inputs,
            normalized_shape,
            weight,
            bias,
            eps
        )
class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention): class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
""" """
...@@ -628,6 +697,7 @@ NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [ ...@@ -628,6 +697,7 @@ NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
MixedLinear, MixedLinear,
MixedConv2d, MixedConv2d,
MixedBatchNorm2d, MixedBatchNorm2d,
MixedLayerNorm,
MixedMultiHeadAttention, MixedMultiHeadAttention,
] ]
......
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import ( from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax, MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
...@@ -28,6 +28,12 @@ def test_slice(): ...@@ -28,6 +28,12 @@ def test_slice():
assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4) assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23) assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)
# Ellipsis
assert S(weight)[..., 9:13].shape == (3, 7, 24, 4)
assert S(weight)[:2, ..., 1:W(3)+1].shape == (2, 7, 24, 3)
assert S(weight)[..., 1:W(3)*2+1].shape == (3, 7, 24, 6)
assert S(weight)[..., :10, 1:W(3)*2+1].shape == (3, 7, 10, 6)
# no effect # no effect
assert S(weight)[:] is weight assert S(weight)[:] is weight
...@@ -227,6 +233,23 @@ def test_mixed_batchnorm2d(): ...@@ -227,6 +233,23 @@ def test_mixed_batchnorm2d():
_mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3)) _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
def test_mixed_layernorm():
    """Run the sampling and differentiable sanity checks on a LayerNorm with a searchable ``normalized_shape``."""
    import itertools

    # Scalar candidates: the last output dimension must equal the sampled size.
    layer = LayerNorm(ValueChoice([32, 64], label='normalized_shape'), elementwise_affine=True)
    for width in (32, 64):
        out = _mixed_operation_sampling_sanity_check(layer, {'normalized_shape': width}, torch.randn(2, 16, width))
        assert out.size(-1) == width
    _mixed_operation_differentiable_sanity_check(layer, torch.randn(2, 16, 64))

    # Tuple candidates: the last two output dimensions must equal the sampled shape.
    candidates = list(itertools.product([16, 32, 64], [8, 16]))
    layer = LayerNorm(ValueChoice(candidates, label='normalized_shape'))
    for rows, cols in ((16, 8), (64, 16)):
        out = _mixed_operation_sampling_sanity_check(layer, {'normalized_shape': (rows, cols)}, torch.randn(2, rows, cols))
        assert list(out.shape[-2:]) == [rows, cols]
    _mixed_operation_differentiable_sanity_check(layer, torch.randn(2, 64, 16))
def test_mixed_mhattn(): def test_mixed_mhattn():
mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4) mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)
......
...@@ -78,6 +78,11 @@ def _strategy_factory(alias, space_type): ...@@ -78,6 +78,11 @@ def _strategy_factory(alias, space_type):
extra_mutation_hooks.append(NDSStagePathSampling.mutate) extra_mutation_hooks.append(NDSStagePathSampling.mutate)
else: else:
extra_mutation_hooks.append(NDSStageDifferentiable.mutate) extra_mutation_hooks.append(NDSStageDifferentiable.mutate)
# Autoformer search space requires specific extra hooks
if space_type == 'autoformer':
from nni.retiarii.hub.pytorch.autoformer import MixedAbsPosEmbed, MixedClsToken
extra_mutation_hooks.extend([MixedAbsPosEmbed.mutate, MixedClsToken.mutate])
if alias == 'darts': if alias == 'darts':
return stg.DARTS(mutation_hooks=extra_mutation_hooks) return stg.DARTS(mutation_hooks=extra_mutation_hooks)
...@@ -149,7 +154,7 @@ def _dataset_factory(dataset_type, subset=20): ...@@ -149,7 +154,7 @@ def _dataset_factory(dataset_type, subset=20):
'mobilenetv3_small', 'mobilenetv3_small',
'proxylessnas', 'proxylessnas',
'shufflenet', 'shufflenet',
# 'autoformer', 'autoformer',
'nasnet', 'nasnet',
'enas', 'enas',
'amoeba', 'amoeba',
...@@ -186,7 +191,7 @@ def test_hub_oneshot(space_type, strategy_type): ...@@ -186,7 +191,7 @@ def test_hub_oneshot(space_type, strategy_type):
NDS_SPACES = ['amoeba', 'darts', 'pnas', 'enas', 'nasnet'] NDS_SPACES = ['amoeba', 'darts', 'pnas', 'enas', 'nasnet']
if strategy_type == 'proxyless': if strategy_type == 'proxyless':
if 'width' in space_type or 'depth' in space_type or \ if 'width' in space_type or 'depth' in space_type or \
any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3']): any(space_type.startswith(prefix) for prefix in NDS_SPACES + ['proxylessnas', 'mobilenetv3', 'autoformer']):
pytest.skip('The space has used unsupported APIs.') pytest.skip('The space has used unsupported APIs.')
if strategy_type in ['darts', 'gumbel'] and space_type == 'mobilenetv3': if strategy_type in ['darts', 'gumbel'] and space_type == 'mobilenetv3':
pytest.skip('Skip as it consumes too much memory.') pytest.skip('Skip as it consumes too much memory.')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment