Unverified commit 9e2a069d authored by Maze, committed by GitHub

One-shot sub state dict implementation (#5054)

parent d68691d0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from typing import Optional, Tuple, cast, Any, Dict, Union

import torch
import torch.nn.functional as F

@@ -135,7 +135,7 @@ class TransformerEncoderLayer(nn.Module):
    The PyTorch built-in nn.TransformerEncoderLayer() does not support customized attention.
    """
    def __init__(
        self, embed_dim, num_heads, mlp_ratio: Union[int, float, nn.ValueChoice] = 4.,
        qkv_bias=False, qk_scale=None, rpe=False,
        drop_rate=0., attn_drop=0., proj_drop=0., drop_path=0.,
        pre_norm=True, rpe_length=14, head_dim=64
@@ -235,13 +235,18 @@ class MixedClsToken(MixedOperation, ClsToken):
    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def slice_param(self, embed_dim, **kwargs) -> Any:
        embed_dim_ = _W(embed_dim)
        cls_token = _S(self.cls_token)[..., :embed_dim_]
        return {'cls_token': cls_token}

    def forward_with_args(self, embed_dim,
                          inputs: torch.Tensor) -> torch.Tensor:
        cls_token = self.slice_param(embed_dim)['cls_token']
        assert isinstance(cls_token, torch.Tensor)
        return torch.cat((cls_token.expand(inputs.shape[0], -1, -1), inputs), dim=1)


@basic_unit
class AbsPosEmbed(nn.Module):
@@ -271,11 +276,17 @@ class MixedAbsPosEmbed(MixedOperation, AbsPosEmbed):
    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def slice_param(self, embed_dim, **kwargs) -> Any:
        embed_dim_ = _W(embed_dim)
        pos_embed = _S(self.pos_embed)[..., :embed_dim_]
        return {'pos_embed': pos_embed}

    def forward_with_args(self, embed_dim,
                          inputs: torch.Tensor) -> torch.Tensor:
        pos_embed = self.slice_param(embed_dim)['pos_embed']
        assert isinstance(pos_embed, torch.Tensor)
        return inputs + pos_embed
...
@@ -77,7 +77,6 @@ def traverse_and_mutate_submodules(
    memo = {}
    module_list = []

    def apply(m):
        # Need to call list() here because the loop body might replace some children in-place.
        for name, child in list(m.named_children()):
@@ -280,16 +279,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
            result.update(module.search_space_spec())
        return result

    def resample(self, memo=None) -> dict[str, Any]:
        """Trigger the resample for each :attr:`nas_modules`.

        Sometimes (e.g., in differentiable cases), it does nothing.

        Parameters
        ----------
        memo : dict[str, Any]
            Used to ensure the consistency of samples with the same label.

        Returns
        -------
        dict
            Sampled architecture.
        """
        result = memo or {}
        for module in self.nas_modules:
            result.update(module.resample(memo=result))
        return result
...
@@ -5,7 +5,7 @@
from __future__ import annotations

import warnings
from typing import Any, cast, Dict

import pytorch_lightning as pl
import torch
@@ -19,7 +19,7 @@ from .supermodule.sampling import (
    PathSamplingCell, PathSamplingRepeat
)
from .enas import ReinforceController, ReinforceField
from .supermodule.base import sub_state_dict


class RandomSamplingLightningModule(BaseOneShotLightningModule):
    _random_note = """
@@ -92,6 +92,45 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
        )
        return super().export()

    def _get_base_model(self):
        assert isinstance(self.model.model, nn.Module)
        base_model: nn.Module = self.model.model
        return base_model

    def state_dict(self, destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
        base_model = self._get_base_model()
        state_dict = base_model.state_dict(destination=destination, prefix=prefix, keep_vars=keep_vars)
        return state_dict

    def load_state_dict(self, state_dict, strict: bool = True) -> None:
        base_model = self._get_base_model()
        base_model.load_state_dict(state_dict=state_dict, strict=strict)

    def sub_state_dict(self, arch: dict[str, Any], destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
        """Given the architecture dict, return the state dict that can be directly loaded by the fixed subnet.

        Parameters
        ----------
        arch : dict[str, Any]
            Subnet architecture dict.
        destination : dict
            If provided, the state of the module will be updated into the dict and the same object is returned.
            Otherwise, an ``OrderedDict`` will be created and returned.
        prefix : str
            A prefix added to parameter and buffer names to compose the keys in the state dict.
        keep_vars : bool
            By default, the :class:`~torch.Tensor` s returned in the state dict are detached from autograd.
            If it's set to ``True``, detaching will not be performed.

        Returns
        -------
        dict
            Subnet state dict.
        """
        self.resample(memo=arch)
        base_model = self._get_base_model()
        state_dict = sub_state_dict(base_model, destination, prefix, keep_vars)
        return state_dict
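
A minimal usage sketch (assuming ``module`` is a ``RandomSamplingLightningModule`` already attached to a supernet; the checkpoint path is illustrative only):

    arch = module.resample()                    # sample one subnet architecture
    subnet_state = module.sub_state_dict(arch)  # weights keyed the way the fixed subnet expects
    torch.save(subnet_state, 'subnet.ckpt')     # illustrative path; load later into a fixed_arch model
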
class EnasLightningModule(RandomSamplingLightningModule):
    _enas_note = """
...
@@ -13,7 +13,7 @@ When adding/modifying a new strategy in this file, don't forget to link it in st
from __future__ import annotations

import warnings
from typing import Any, Type, Union

import torch.nn as nn
@@ -48,31 +48,43 @@ class OneShotStrategy(BaseStrategy):
        """
        return train_dataloaders, val_dataloaders

    def attach_model(self, base_model: Union[Model, nn.Module]):
        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        if isinstance(base_model, Model):
            if not isinstance(base_model.python_object, nn.Module):
                raise TypeError('Model is not a nn.Module. ' + _reason)
            py_model: nn.Module = base_model.python_object
            if not isinstance(base_model.evaluator, Lightning):
                raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
            evaluator_module: LightningModule = base_model.evaluator.module
            evaluator_module.running_mode = 'oneshot'
            evaluator_module.set_model(py_model)
        else:
            from nni.retiarii.evaluator.pytorch.lightning import ClassificationModule
            evaluator_module = ClassificationModule()
            evaluator_module.running_mode = 'oneshot'
            evaluator_module.set_model(base_model)
        self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)

    def run(self, base_model: Model, applied_mutators):
        # one-shot strategy doesn't use ``applied_mutators``
        # but get the "mutators" on their own
        _reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'

        if applied_mutators:
            raise ValueError('Mutator is not empty. ' + _reason)

        if not isinstance(base_model.evaluator, Lightning):
            raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')

        self.attach_model(base_model)
        evaluator: Lightning = base_model.evaluator
        if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
            raise TypeError('Training and validation dataloader are both required to set in evaluator for one-shot strategy.')
        train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
        assert isinstance(self.model, BaseOneShotLightningModule)
        evaluator.trainer.fit(self.model, train_loader, val_loader)

    def export_top_models(self, top_k: int = 1) -> list[Any]:
@@ -144,3 +156,7 @@ class RandomOneShot(OneShotStrategy):
    def __init__(self, **kwargs):
        super().__init__(RandomSamplingLightningModule, **kwargs)

    def sub_state_dict(self, arch: dict[str, Any]):
        assert isinstance(self.model, RandomSamplingLightningModule)
        return self.model.sub_state_dict(arch)
\ No newline at end of file
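
A rough end-to-end sketch of how the strategy-level API is meant to be used (``MyModelSpace`` is a placeholder for any model space; search setup is omitted):

    strategy = RandomOneShot()
    # ... attach a model space / run the one-shot search ...
    arch = strategy.model.resample()                 # pick one subnet
    with fixed_arch(arch):
        subnet = MyModelSpace()                      # instantiate the fixed subnet
    subnet.load_state_dict(strategy.sub_state_dict(arch))  # inherit the supernet weights
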
@@ -3,13 +3,62 @@
from __future__ import annotations

from collections import OrderedDict
import itertools
from typing import Any, Dict

import torch.nn as nn

from nni.common.hpo_utils import ParameterSpec

__all__ = ['BaseSuperNetModule', 'sub_state_dict']


def sub_state_dict(module: Any, destination: Any = None, prefix: str = '', keep_vars: bool = False) -> Dict[str, Any]:
    """Returns a dictionary containing the whole state of the ``BaseSuperNetModule``.

    Both parameters and persistent buffers (e.g., running averages) are
    included. Keys are the corresponding parameter and buffer names.
    Parameters and buffers set to ``None`` are not included.

    Parameters
    ----------
    module : Any
        The module (supernet module or plain ``nn.Module``) whose sub state dict is collected.
    destination : dict, optional
        If provided, the state of the module will be updated into the dict
        and the same object is returned. Otherwise, an ``OrderedDict``
        will be created and returned. Default: ``None``.
    prefix : str, optional
        A prefix added to parameter and buffer names to compose the keys in the state dict.
        Default: ``''``.
    keep_vars : bool, optional
        By default, the :class:`~torch.Tensor` s returned in the state dict are
        detached from autograd. If it's set to ``True``, detaching will not be performed.
        Default: ``False``.

    Returns
    -------
    dict
        Subnet state dictionary.
    """
    if destination is None:
        destination = OrderedDict()
        destination._metadata = OrderedDict()

    local_metadata = dict(version=module._version)
    if hasattr(destination, "_metadata"):
        destination._metadata[prefix[:-1]] = local_metadata

    if isinstance(module, BaseSuperNetModule):
        module._save_to_sub_state_dict(destination, prefix, keep_vars)
    else:
        module._save_to_state_dict(destination, prefix, keep_vars)
        for name, m in module._modules.items():
            if m is not None:
                sub_state_dict(m, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    return destination
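
As a sanity sketch of the recursion (assuming a plain module with no supernet submodules), the helper should reduce to the ordinary ``state_dict`` keys:

    import torch.nn as nn
    plain = nn.Sequential(nn.Linear(4, 4), nn.ReLU())
    assert set(sub_state_dict(plain).keys()) == set(plain.state_dict().keys())
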
class BaseSuperNetModule(nn.Module):
@@ -104,3 +153,22 @@ class BaseSuperNetModule(nn.Module):
        See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
        """
        raise NotImplementedError()

    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
        """Save the parameters and buffers of the current module to the state dict."""
        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):  # direct children
            if value is None or name in self._non_persistent_buffers_set:
                # it won't appear in state dict
                continue
            destination[prefix + name] = value if keep_vars else value.detach()

    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
        """Save the sub-modules to the state dict."""
        for name, module in self._modules.items():
            if module is not None:
                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    def _save_to_sub_state_dict(self, destination, prefix, keep_vars):
        """Save to the sub state dict."""
        self._save_param_buff_to_state_dict(destination, prefix, keep_vars)
        self._save_module_to_state_dict(destination, prefix, keep_vars)
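
The split into two hooks can be summarized as follows (comments only; the mapping to subclasses follows from the rest of this commit):

    # _save_param_buff_to_state_dict -> direct parameters/buffers of this module
    # _save_module_to_state_dict     -> recurse into children via sub_state_dict()
    # MixedOperation overrides the former to slice super-kernel weights;
    # PathSamplingLayer / PathSamplingRepeat override the latter to keep only the sampled path(s).
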
@@ -23,7 +23,7 @@ from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX

from .base import BaseSuperNetModule, sub_state_dict
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
@@ -232,6 +232,22 @@ class MixedOperation(BaseSuperNetModule):
            if param.default is not param.empty and param.name not in self.init_arguments:
                self.init_arguments[param.name] = param.default

    def slice_param(self, **kwargs):
        """Slice the params and buffers for subnet forward and state dict.

        When there is a ``mapping=True`` in kwargs, the returned result will be wrapped in a dict.
        """
        raise NotImplementedError()

    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
        kwargs = {name: self.forward_argument(name) for name in self.argument_list}
        params_mapping: dict[str, Any] = self.slice_param(**kwargs)

        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):  # direct children
            if value is None or name in self._non_persistent_buffers_set:
                # it won't appear in state dict
                continue
            value = params_mapping.get(name, value)
            destination[prefix + name] = value if keep_vars else value.detach()


class MixedLinear(MixedOperation, nn.Linear):
    """Mixed linear operation.
@@ -250,20 +266,23 @@ class MixedLinear(MixedOperation, nn.Linear):
    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def slice_param(self, in_features: int_or_int_dict, out_features: int_or_int_dict, **kwargs) -> Any:
        in_features_ = _W(in_features)
        out_features_ = _W(out_features)

        weight = _S(self.weight)[:out_features_]
        weight = _S(weight)[:, :in_features_]
        bias = self.bias if self.bias is None else _S(self.bias)[:out_features_]

        return {'weight': weight, 'bias': bias}

    def forward_with_args(self,
                          in_features: int_or_int_dict,
                          out_features: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:
        params_mapping = self.slice_param(in_features, out_features)
        weight, bias = [params_mapping.get(name) for name in ['weight', 'bias']]

        return F.linear(inputs, weight, bias)
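
To make the slicing concrete, a small illustration with plain tensors (this mimics what ``slice_param`` does conceptually; the real code goes through the ``_S``/``_W`` wrappers):

    import torch
    weight = torch.randn(8, 9)                        # super-kernel: max out_features x max in_features
    out_features, in_features = 4, 6                  # a sampled subnet configuration
    sub_weight = weight[:out_features, :in_features]  # what ends up under the 'weight' key
    assert sub_weight.shape == (4, 6)
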
@@ -347,19 +366,13 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
        else:
            return max(traverse_all_options(value_choice))

    def slice_param(self,
                    in_channels: int_or_int_dict,
                    out_channels: int_or_int_dict,
                    kernel_size: scalar_or_scalar_dict[_int_or_tuple],
                    groups: int_or_int_dict,
                    **kwargs
                    ) -> Any:
        in_channels_ = _W(in_channels)
        out_channels_ = _W(out_channels)
@@ -369,6 +382,8 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
        if not isinstance(groups, dict):
            weight = _S(weight)[:, :in_channels_ // groups]
            # placeholder
            in_channels_per_group = None
        else:
            assert 'groups' in self.mutable_arguments
            err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
@@ -383,15 +398,51 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
                raise ValueError(err_message)
            if in_channels_per_group != int(in_channels_per_group):
                raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')

            # Compute sliced weights and groups (as an integer)
            weight = _S(weight)[:, :int(in_channels_per_group)]

        kernel_a, kernel_b = self._to_tuple(kernel_size)
        kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
        max_kernel_a, max_kernel_b = self.kernel_size  # self.kernel_size must be a tuple
        kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
        weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]

        bias = _S(self.bias)[:out_channels_] if self.bias is not None else None

        return {'weight': weight, 'bias': bias, 'in_channels_per_group': in_channels_per_group}

    def forward_with_args(self,
                          in_channels: int_or_int_dict,
                          out_channels: int_or_int_dict,
                          kernel_size: scalar_or_scalar_dict[_int_or_tuple],
                          stride: _int_or_tuple,
                          padding: scalar_or_scalar_dict[_int_or_tuple],
                          dilation: int,
                          groups: int_or_int_dict,
                          inputs: torch.Tensor) -> torch.Tensor:

        if any(isinstance(arg, dict) for arg in [stride, dilation]):
            raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))

        params_mapping = self.slice_param(in_channels, out_channels, kernel_size, groups)
        weight, bias, in_channels_per_group = [
            params_mapping.get(name)
            for name in ['weight', 'bias', 'in_channels_per_group']
        ]

        if isinstance(groups, dict):
            if not isinstance(in_channels_per_group, (int, float)):
                raise ValueError(f'Input channels per group is found to be non-numeric: {in_channels_per_group}')
            if inputs.size(1) % in_channels_per_group != 0:
                raise RuntimeError(
                    f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
                    f'while in_channels_per_group = {in_channels_per_group}'
                )
            else:
                groups = inputs.size(1) // int(in_channels_per_group)

        # slice center
        if isinstance(kernel_size, dict):
@@ -400,14 +451,6 @@ class MixedConv2d(MixedOperation, nn.Conv2d):
                raise ValueError(f'Use "{self.padding}" in padding is not supported.')
            padding = self.padding  # max padding, must be a tuple

        # The rest parameters only need to be converted to tuple
        stride_ = self._to_tuple(stride)
        dilation_ = self._to_tuple(dilation)
@@ -441,6 +484,21 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
    def super_init_argument(self, name: str, value_choice: ValueChoiceX):
        return max(traverse_all_options(value_choice))

    def slice_param(self, num_features: int_or_int_dict, **kwargs) -> Any:
        if isinstance(num_features, dict):
            num_features = self.num_features
        weight, bias = self.weight, self.bias
        running_mean, running_var = self.running_mean, self.running_var

        if num_features < self.num_features:
            weight = weight[:num_features]
            bias = bias[:num_features]
            running_mean = None if running_mean is None else running_mean[:num_features]
            running_var = None if running_var is None else running_var[:num_features]

        return {'weight': weight, 'bias': bias,
                'running_mean': running_mean, 'running_var': running_var}

    def forward_with_args(self,
                          num_features: int_or_int_dict,
                          eps: float,
@@ -450,19 +508,11 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
        if any(isinstance(arg, dict) for arg in [eps, momentum]):
            raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))

        params_mapping = self.slice_param(num_features)
        weight, bias, running_mean, running_var = [
            params_mapping.get(name)
            for name in ['weight', 'bias', 'running_mean', 'running_var']
        ]

        if self.training:
            bn_training = True
@@ -481,6 +531,7 @@ class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
            eps,
        )
class MixedLayerNorm(MixedOperation, nn.LayerNorm):
    """
    Mixed LayerNorm operation.
@@ -517,14 +568,7 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
        else:
            return max(all_sizes)

    def slice_param(self, normalized_shape, **kwargs) -> Any:
        if isinstance(normalized_shape, dict):
            normalized_shape = self.normalized_shape
@@ -541,6 +585,22 @@ class MixedLayerNorm(MixedOperation, nn.LayerNorm):
        weight = self.weight[indices] if self.weight is not None else None
        bias = self.bias[indices] if self.bias is not None else None

        return {'weight': weight, 'bias': bias, 'normalized_shape': normalized_shape}

    def forward_with_args(self,
                          normalized_shape,
                          eps: float,
                          inputs: torch.Tensor) -> torch.Tensor:
        if any(isinstance(arg, dict) for arg in [eps]):
            raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))

        params_mapping = self.slice_param(normalized_shape)
        weight, bias, normalized_shape = [
            params_mapping.get(name)
            for name in ['weight', 'bias', 'normalized_shape']
        ]

        return F.layer_norm(
            inputs,
            normalized_shape,
@@ -622,19 +682,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
            slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
        ]

    def slice_param(self, embed_dim, kdim, vdim, **kwargs):
        # by default, kdim, vdim can be none
        if kdim is None:
            kdim = embed_dim
@@ -643,15 +691,6 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
        qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim

        embed_dim_ = _W(embed_dim)

        # in projection weights & biases has q, k, v weights concatenated together
@@ -673,27 +712,84 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
            k_proj = _S(k_proj)[:, :_W(kdim)]
            v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
            v_proj = _S(v_proj)[:, :_W(vdim)]
        else:
            q_proj = k_proj = v_proj = None

        return {
            'in_proj_bias': in_proj_bias, 'in_proj_weight': in_proj_weight,
            'bias_k': bias_k, 'bias_v': bias_v,
            'out_proj.weight': out_proj_weight, 'out_proj.bias': out_proj_bias,
            'q_proj_weight': q_proj, 'k_proj_weight': k_proj, 'v_proj_weight': v_proj,
            'qkv_same_embed_dim': qkv_same_embed_dim
        }

    def _save_param_buff_to_state_dict(self, destination, prefix, keep_vars):
        kwargs = {name: self.forward_argument(name) for name in self.argument_list}
        params_mapping = self.slice_param(**kwargs, mapping=True)

        for name, value in itertools.chain(self._parameters.items(), self._buffers.items()):
            if value is None or name in self._non_persistent_buffers_set:
                continue
            value = params_mapping.get(name, value)
            destination[prefix + name] = value if keep_vars else value.detach()

        # params of out_proj are handled in ``MixedMultiHeadAttention`` rather than in the
        # ``NonDynamicallyQuantizableLinear`` sub-module. We also convert them to state dict here.
        for name in ["out_proj.weight", "out_proj.bias"]:
            value = params_mapping.get(name, None)
            if value is None:
                continue
            destination[prefix + name] = value if keep_vars else value.detach()

    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
        for name, module in self._modules.items():
            # the weights of ``NonDynamicallyQuantizableLinear`` have been handled in ``_save_param_buff_to_state_dict``.
            if isinstance(module, nn.modules.linear.NonDynamicallyQuantizableLinear):
                continue
            if module is not None:
                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)

    def forward_with_args(
        self,
        embed_dim: int_or_int_dict, num_heads: int,
        kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
        dropout: float,
        query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
        key_padding_mask: torch.Tensor | None = None,
        need_weights: bool = True, attn_mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor | None]:

        if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
            raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))

        if getattr(self, 'batch_first', False):
            # for backward compatibility: v1.7 doesn't have batch_first
            query, key, value = [x.transpose(1, 0) for x in (query, key, value)]

        if isinstance(embed_dim, dict):
            used_embed_dim = self.embed_dim
        else:
            used_embed_dim = embed_dim

        params_mapping = self.slice_param(embed_dim, kdim, vdim)
        in_proj_bias, in_proj_weight, bias_k, bias_v, \
            out_proj_weight, out_proj_bias, q_proj, k_proj, v_proj, qkv_same_embed_dim = [
                params_mapping.get(name)
                for name in ['in_proj_bias', 'in_proj_weight', 'bias_k', 'bias_v',
                             'out_proj.weight', 'out_proj.bias', 'q_proj_weight', 'k_proj_weight',
                             'v_proj_weight', 'qkv_same_embed_dim']
            ]

        # The rest part is basically the same as pytorch
        attn_output, attn_output_weights = F.multi_head_attention_forward(
            query, key, value, used_embed_dim, num_heads,
            cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
            bias_k, bias_v, self.add_zero_attn,
            dropout, out_proj_weight, cast(Tensor, out_proj_bias),
            training=self.training,
            key_padding_mask=key_padding_mask, need_weights=need_weights,
            attn_mask=attn_mask, use_separate_proj_weight=not qkv_same_embed_dim,
            q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)

        if getattr(self, 'batch_first', False):  # backward compatibility
            return attn_output.transpose(1, 0), attn_output_weights
@@ -701,6 +797,7 @@ class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
        return attn_output, attn_output_weights


NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
    MixedLinear,
    MixedConv2d,
...
@@ -15,7 +15,7 @@ from nni.nas.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.nas.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs

from .base import BaseSuperNetModule, sub_state_dict
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation
@@ -76,6 +76,14 @@ class PathSamplingLayer(BaseSuperNetModule):
        """Override this to implement customized reduction."""
        return weighted_sum(items)

    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled

        for samp in sampled:
            module = getattr(self, str(samp))
            if module is not None:
                sub_state_dict(module, destination=destination, prefix=prefix, keep_vars=keep_vars)
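
Because the prefix passed here does not include the sampled candidate's own name, the exported keys line up with the fixed subnet. An assumed example, for a ``LayerChoice`` registered as ``op`` whose candidate ``1`` is sampled:

    # supernet state_dict key:   'op.1.weight'
    # sub_state_dict() key:      'op.weight'   (prefix keeps 'op.' but drops the sampled index)
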
    def forward(self, *args, **kwargs):
        if self._sampled is None:
            raise RuntimeError('At least one path needs to be sampled before fprop.')
@@ -229,7 +237,7 @@ class PathSamplingRepeat(BaseSuperNetModule):
    def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
        super().__init__()
        self.blocks: Any = blocks
        self.depth = depth
        self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
        self._sampled: list[int] | int | None = None
@@ -268,6 +276,15 @@ class PathSamplingRepeat(BaseSuperNetModule):
        """Override this to implement customized reduction."""
        return weighted_sum(items)

    def _save_module_to_state_dict(self, destination, prefix, keep_vars):
        sampled: Any = [self._sampled] if not isinstance(self._sampled, list) else self._sampled

        for cur_depth, (name, module) in enumerate(self.blocks.named_children(), start=1):
            if module is not None:
                sub_state_dict(module, destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
            if not any(d > cur_depth for d in sampled):
                break

    def forward(self, x):
        if self._sampled is None:
            raise RuntimeError('At least one depth needs to be sampled before fprop.')
...
@@ -389,3 +389,22 @@ def test_optimizer_lr_scheduler():
    assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
        abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6


def test_one_shot_sub_state_dict():
    from nni.nas.strategy import RandomOneShot
    from nni.nas import fixed_arch

    init_kwargs = {}
    x = torch.rand(1, 1, 28, 28)

    for model_space_cls in [SimpleNet, ValueChoiceConvNet, RepeatNet]:
        strategy = RandomOneShot()
        model_space = model_space_cls()
        strategy.attach_model(model_space)
        arch = strategy.model.resample()

        with fixed_arch(arch):
            model = model_space_cls(**init_kwargs)
        model.load_state_dict(strategy.sub_state_dict(arch))
        model.eval()
        model_space.eval()
        assert torch.allclose(model(x), strategy.model(x))
@@ -154,16 +154,28 @@ def test_differentiable_layerchoice_dedup():
    assert len(memo) == 1 and 'a' in memo


def _mutate_op_path_sampling_policy(operation):
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
            break
    return mutate_op


def _mixed_operation_sampling_sanity_check(operation, memo, *input):
    mutate_op = _mutate_op_path_sampling_policy(operation)
    mutate_op.resample(memo=memo)
    return mutate_op(*input)


from nni.nas.oneshot.pytorch.supermodule.base import sub_state_dict

def _mixed_operation_state_dict_sanity_check(operation, model, memo, *input):
    mutate_op = _mutate_op_path_sampling_policy(operation)
    mutate_op.resample(memo=memo)
    model.load_state_dict(sub_state_dict(mutate_op))
    return mutate_op(*input), model(*input)


def _mixed_operation_differentiable_sanity_check(operation, *input):
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
@@ -188,6 +200,11 @@ def test_mixed_linear():
    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

    linear = Linear(ValueChoice([3, 6, 9], label='in_features'), ValueChoice([2, 4, 8], label='out_features'), bias=True)
    kwargs = {'in_features': 6, 'out_features': 4}
    out1, out2 = _mixed_operation_state_dict_sanity_check(linear, Linear(**kwargs), kwargs, torch.randn(2, 6))
    assert torch.allclose(out1, out2)
def test_mixed_conv2d():
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)
@@ -235,6 +252,17 @@ def test_mixed_conv2d():
    conv.resample({'k': 1})
    assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9

    # only `in_channels`, `out_channels`, `kernel_size`, and `groups` influence state_dict
    conv = Conv2d(
        ValueChoice([2, 4, 8], label='in_channels'), ValueChoice([6, 12, 24], label='out_channels'),
        kernel_size=ValueChoice([3, 5, 7], label='kernel_size'), groups=ValueChoice([1, 2], label='groups')
    )
    kwargs = {
        'in_channels': 8, 'out_channels': 12,
        'kernel_size': 5, 'groups': 2
    }
    out1, out2 = _mixed_operation_state_dict_sanity_check(conv, Conv2d(**kwargs), kwargs, torch.randn(2, 8, 16, 16))
    assert torch.allclose(out1, out2)


def test_mixed_batchnorm2d():
    bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))
@@ -244,6 +272,10 @@ def test_mixed_batchnorm2d():
    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))

    bn = BatchNorm2d(ValueChoice([32, 48, 64], label='num_features'))
    kwargs = {'num_features': 48}
    out1, out2 = _mixed_operation_state_dict_sanity_check(bn, BatchNorm2d(**kwargs), kwargs, torch.randn(2, 48, 3, 3))
    assert torch.allclose(out1, out2)


def test_mixed_layernorm():
    ln = LayerNorm(ValueChoice([32, 64], label='normalized_shape'), elementwise_affine=True)
@@ -261,6 +293,10 @@ def test_mixed_layernorm():
    _mixed_operation_differentiable_sanity_check(ln, torch.randn(2, 64, 16))

    ln = LayerNorm(ValueChoice([32, 48, 64], label='normalized_shape'))
    kwargs = {'normalized_shape': 48}
    out1, out2 = _mixed_operation_state_dict_sanity_check(ln, LayerNorm(**kwargs), kwargs, torch.randn(2, 8, 48))
    assert torch.allclose(out1, out2)


def test_mixed_mhattn():
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)
@@ -293,6 +329,11 @@ def test_mixed_mhattn():
    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))

    mhattn = MultiheadAttention(embed_dim=ValueChoice([4, 8, 16], label='embed_dim'), num_heads=ValueChoice([1, 2, 4], label='num_heads'),
                                kdim=ValueChoice([4, 8, 16], label='kdim'), vdim=ValueChoice([4, 8, 16], label='vdim'))
    kwargs = {'embed_dim': 16, 'num_heads': 2, 'kdim': 4, 'vdim': 8}
    (out1, _), (out2, _) = _mixed_operation_state_dict_sanity_check(mhattn, MultiheadAttention(**kwargs), kwargs, torch.randn(7, 2, 16), torch.randn(7, 2, 4), torch.randn(7, 2, 8))
    assert torch.allclose(out1, out2)


@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
def test_mixed_mhattn_batch_first():
...