Unverified Commit 14d2966b authored by Frandium's avatar Frandium Committed by GitHub
Browse files

Valuechoice oneshot lightning (#4602)

parent 5b7dac5c
...@@ -39,6 +39,8 @@ class ParameterSpec(NamedTuple): ...@@ -39,6 +39,8 @@ class ParameterSpec(NamedTuple):
categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered) categorical: bool # Whether this paramter is categorical (unordered) or numerical (ordered)
size: int = None # If it's categorical, how many candidates it has size: int = None # If it's categorical, how many candidates it has
chosen_size: Optional[int] = 1 # If it's categorical, it should choose how many candidates.
# By default, 1. If none, arbitrary number of candidates can be chosen.
# uniform distributed # uniform distributed
low: float = None # Lower bound of uniform parameter low: float = None # Lower bound of uniform parameter
......
...@@ -5,6 +5,6 @@ from .darts import DartsTrainer ...@@ -5,6 +5,6 @@ from .darts import DartsTrainer
from .enas import EnasTrainer from .enas import EnasTrainer
from .proxyless import ProxylessTrainer from .proxyless import ProxylessTrainer
from .random import SinglePathTrainer, RandomTrainer from .random import SinglePathTrainer, RandomTrainer
from .differentiable import DartsModule, ProxylessModule, SnasModule from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasModule, RandomSamplingModule from .sampling import EnasLightningModule, RandomSamplingLightningModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Dict, Type, Callable, List, Optional import warnings
from itertools import chain
from typing import Dict, Callable, List, Union, Any, Tuple
import pytorch_lightning as pl import pytorch_lightning as pl
import torch.optim as optim import torch.optim as optim
...@@ -9,51 +11,163 @@ import torch.nn as nn ...@@ -9,51 +11,163 @@ import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler from torch.optim.lr_scheduler import _LRScheduler
ReplaceDictType = Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]] import nni.retiarii.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from .supermodule.base import BaseSuperNetModule
__all__ = ['MutationHook', 'BaseSuperNetModule', 'BaseOneShotLightningModule', 'traverse_and_mutate_submodules']
def _replace_module_with_type(root_module: nn.Module, replace_dict: ReplaceDictType, modules: List[nn.Module]):
MutationHook = Callable[[nn.Module, str, Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
def traverse_and_mutate_submodules(
root_module: nn.Module, hooks: List[MutationHook], mutate_kwargs: Dict[str, Any], topdown: bool = True
) -> List[BaseSuperNetModule]:
""" """
Replace xxxChoice in user's model with NAS modules. Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.
Parameters Parameters
---------- ----------
root_module : nn.Module root_module : nn.Module
User-defined module with xxxChoice in it. In fact, since this method is called in the ``__init__`` of User-defined model space.
``BaseOneShotLightningModule``, this will be a pl.LightningModule. Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]] it's usually a ``pytorch_lightning.LightningModule``.
Functions to replace xxxChoice modules. Keys should be xxxChoice type and values should be a The mutation will be in-place on ``root_module``.
function that return an nn.module. hooks : List[MutationHook]
modules : List[nn.Module] List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
The replace result. This is also the return value of this function. When a hook returns an module, the module will be replaced (mutated) to the new module.
mutate_kwargs : dict
Extra keyword arguments passed to hooks.
topdown : bool, default = False
If topdown is true, hooks are first called, before traversing its sub-module (i.e., pre-order DFS).
Otherwise, sub-modules are first traversed, before calling hooks on this node (i.e., post-order DFS).
Returns Returns
---------- ----------
modules : List[nn.Module] modules : Dict[str, nn.Module]
The replace result. The replace result.
""" """
if modules is None: memo = {}
modules = []
module_list = []
def apply(m): def apply(m):
for name, child in m.named_children(): for name, child in m.named_children():
child_type = type(child) # post-order DFS
if child_type in replace_dict.keys(): if not topdown:
setattr(m, name, replace_dict[child_type](child)) apply(child)
modules.append((child.key, getattr(m, name)))
else: mutate_result = None
for hook in hooks:
hook_suggest = hook(child, name, memo, mutate_kwargs)
# parse the mutate result
if isinstance(hook_suggest, tuple):
hook_suggest, suppress = hook_suggest
elif hook_suggest is True:
hook_suggest, suppress = None, True
elif not hook_suggest: # none / false
hook_suggest, suppress = None, False
elif isinstance(hook_suggest, nn.Module):
suppress = True
else:
raise TypeError(f'Mutation hook returned {hook_suggest} of unsupported type: {type(hook_suggest)}.')
if hook_suggest is not None:
if not isinstance(hook_suggest, BaseSuperNetModule):
warnings.warn("Mutation hook didn't return a BaseSuperNetModule. It will be ignored in hooked module list.",
RuntimeWarning)
setattr(m, name, hook_suggest)
mutate_result = hook_suggest
# if suppress, no further mutation hooks are called
if suppress:
break
if isinstance(mutate_result, BaseSuperNetModule):
module_list.append(mutate_result)
# pre-order DFS
if topdown:
apply(child) apply(child)
apply(root_module) apply(root_module)
return modules
return module_list
def no_default_hook(module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> bool:
"""Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
# Forward IS NOT supernet
primitive_list = (
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.ValueChoice,
nas_nn.Repeat,
nas_nn.NasBench101Cell,
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
)
if isinstance(module, primitive_list):
raise TypeError(f'{type(module).__name__} is not supported')
if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
# need output_node_indices, which depends on super-net
raise TypeError(f'Cell with merge_op `{module.merge_op}` is not supported')
if is_traceable(module):
# check whether there is a value-choice in its arguments
has_valuechoice = False
for arg in chain(module.trace_args, module.trace_kwargs.values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
break
if has_valuechoice:
raise TypeError(f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.')
return True # suppress all other hooks
class BaseOneShotLightningModule(pl.LightningModule): class BaseOneShotLightningModule(pl.LightningModule):
_custom_replace_dict_note = """custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None _mutation_hooks_note = """mutation_hooks : List[MutationHook]
The custom xxxChoice replace method. Keys should be ``xxxChoice`` type. Mutation hooks are callable that inputs an Module and returns a :class:`BaseSuperNetModule`.
Values should callable accepting an ``nn.Module`` and returning an ``nn.Module``. They are invoked in :meth:`traverse_and_mutate_submodules`, on each submodules.
This custom replace dict will override the default replace dict of each NAS method. For each submodule, the hook list are invoked subsequently,
the later hooks can see the result from previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
stored in ``nas_modules``, and be the focus of the NAS algorithm.
The hook list will be appended by ``default_mutation_hooks`` in each one-shot module.
To be more specific, the input arguments are three arguments:
#. a module that might be processed,
#. name of the module in its parent module,
#. a memo dict whose usage depends on the particular algorithm.
Note that the memo should be read/written by hooks.
There won't be any hooks called on root module.
The returned arguments can be also one of the three kinds:
#. tuple of: :class:`BaseSuperNetModule` or None, and boolean,
#. boolean,
#. :class:`BaseSuperNetModule` or None.
The boolean value is ``suppress`` indicates whether the folliwng hooks should be called.
When it's true, it suppresses the subsequent hooks, and they will never be invoked.
Without boolean value specified, it's assumed to be false.
If a none value appears on the place of :class:`BaseSuperNetModule`, it means the hook suggests to
keep the module unchanged, and nothing will happen.
""" """
_inner_module_note = """inner_module : pytorch_lightning.LightningModule _inner_module_note = """inner_module : pytorch_lightning.LightningModule
...@@ -79,30 +193,76 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -79,30 +193,76 @@ class BaseOneShotLightningModule(pl.LightningModule):
Attributes Attributes
---------- ----------
nas_modules : List[nn.Module] nas_modules : List[BaseSuperNetModule]
The replace result of a specific NAS method. Modules that have been mutated, which the search algorithms should care about.
xxxChoice will be replaced with some other modules with respect to the NAS method.
Parameters Parameters
---------- ----------
""" + _inner_module_note + _custom_replace_dict_note """ + _inner_module_note + _mutation_hooks_note
automatic_optimization = False automatic_optimization = False
def __init__(self, inner_module: pl.LightningModule, custom_replace_dict: Optional[ReplaceDictType] = None): def default_mutation_hooks(self) -> List[MutationHook]:
"""Override this to define class-default mutation hooks."""
return [no_default_hook]
def mutate_kwargs(self) -> Dict[str, Any]:
"""Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
return {}
def __init__(self, base_model: pl.LightningModule, mutation_hooks: List[MutationHook] = None):
super().__init__() super().__init__()
assert isinstance(inner_module, pl.LightningModule) assert isinstance(base_model, pl.LightningModule)
self.model = inner_module self.model = base_model
# replace xxxChoice with respect to NAS alg # append the default hooks
# replaced modules are stored in self.nas_modules mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()
self.nas_modules = []
choice_replace_dict = self.default_replace_dict # traverse the model, calling hooks on every submodule
if custom_replace_dict is not None: self.nas_modules: List[BaseSuperNetModule] = traverse_and_mutate_submodules(
for k, v in custom_replace_dict.items(): self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)
assert isinstance(v, nn.Module)
choice_replace_dict[k] = v def search_space_spec(self) -> Dict[str, ParameterSpec]:
_replace_module_with_type(self.model, choice_replace_dict, self.nas_modules) """Get the search space specification from ``nas_module``.
Returns
-------
dict
Key is the name of the choice, value is the corresponding :class:`ParameterSpec`.
"""
result = {}
for module in self.nas_modules:
result.update(module.search_space_spec())
return result
def resample(self) -> Dict[str, Any]:
"""Trigger the resample for each ``nas_module``.
Sometimes (e.g., in differentiable cases), it does nothing.
Returns
-------
dict
Sampled architecture.
"""
result = {}
for module in self.nas_modules:
result.update(module.resample(memo=result))
return result
def export(self) -> Dict[str, Any]:
"""
Export the NAS result, ideally the best choice of each ``nas_module``.
You may implement an ``export`` method for your customized ``nas_module``.
Returns
--------
dict
Keys are names of ``nas_modules``, and values are the choice indices of them.
"""
result = {}
for module in self.nas_modules:
result.update(module.export(memo=result))
return result
def forward(self, x): def forward(self, x):
return self.model(x) return self.model(x)
...@@ -138,8 +298,8 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -138,8 +298,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization) lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers): if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
raise Exception( raise Exception(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`." "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
) )
# variables used to handle optimizer frequency # variables used to handle optimizer frequency
self.cur_optimizer_step = 0 self.cur_optimizer_step = 0
...@@ -148,6 +308,9 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -148,6 +308,9 @@ class BaseOneShotLightningModule(pl.LightningModule):
return arc_optimizers + w_optimizers, lr_schedulers return arc_optimizers + w_optimizers, lr_schedulers
def on_train_start(self): def on_train_start(self):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer self.model.trainer = self.trainer
self.model.log = self.log self.model.log = self.log
return self.model.on_train_start() return self.model.on_train_start()
...@@ -161,10 +324,10 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -161,10 +324,10 @@ class BaseOneShotLightningModule(pl.LightningModule):
def on_fit_end(self): def on_fit_end(self):
return self.model.on_train_end() return self.model.on_train_end()
def on_train_batch_start(self, batch, batch_idx, unused = 0): def on_train_batch_start(self, batch, batch_idx, unused=0):
return self.model.on_train_batch_start(batch, batch_idx, unused) return self.model.on_train_batch_start(batch, batch_idx, unused)
def on_train_batch_end(self, outputs, batch, batch_idx, unused = 0): def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
return self.model.on_train_batch_end(outputs, batch, batch_idx, unused) return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)
def on_epoch_start(self): def on_epoch_start(self):
...@@ -185,7 +348,7 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -185,7 +348,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
def on_after_backward(self): def on_after_backward(self):
return self.model.on_after_backward() return self.model.on_after_backward()
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val = None, gradient_clip_algorithm = None): def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm) return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)
def configure_architecture_optimizers(self): def configure_architecture_optimizers(self):
...@@ -200,20 +363,6 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -200,20 +363,6 @@ class BaseOneShotLightningModule(pl.LightningModule):
""" """
return None return None
@property
def default_replace_dict(self):
"""
Default ``xxxChoice`` replace dict. This is called in ``__init__`` to get the default replace functions for your NAS algorithm.
Note that your default replace functions may be overridden by user-defined ``custom_replace_dict``.
Returns
----------
replace_dict : Dict[Type, Callable[nn.Module, nn.Module]]
Same as ``custom_replace_dict`` in ``__init__``, but this will be overridden if users define their own replace functions.
"""
replace_dict = {}
return replace_dict
def call_lr_schedulers(self, batch_index): def call_lr_schedulers(self, batch_index):
""" """
Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
...@@ -226,14 +375,14 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -226,14 +375,14 @@ class BaseOneShotLightningModule(pl.LightningModule):
""" """
def apply(lr_scheduler): def apply(lr_scheduler):
# single scheduler is called every epoch # single scheduler is called every epoch
if isinstance(lr_scheduler, _LRScheduler) and \ if isinstance(lr_scheduler, _LRScheduler) and \
self.trainer.is_last_batch: self.trainer.is_last_batch:
lr_schedulers.step() lr_schedulers.step()
# lr_scheduler_config is called as configured # lr_scheduler_config is called as configured
elif isinstance(lr_scheduler, dict): elif isinstance(lr_scheduler, dict):
interval = lr_scheduler['interval'] interval = lr_scheduler['interval']
frequency = lr_scheduler['frequency'] frequency = lr_scheduler['frequency']
if ( if (
interval == 'step' and interval == 'step' and
batch_index % frequency == 0 batch_index % frequency == 0
) or \ ) or \
...@@ -241,8 +390,8 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -241,8 +390,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
interval == 'epoch' and interval == 'epoch' and
self.trainer.is_last_batch and self.trainer.is_last_batch and
(self.trainer.current_epoch + 1) % frequency == 0 (self.trainer.current_epoch + 1) % frequency == 0
): ):
lr_scheduler.step() lr_scheduler.step()
lr_schedulers = self.lr_schedulers() lr_schedulers = self.lr_schedulers()
...@@ -254,13 +403,13 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -254,13 +403,13 @@ class BaseOneShotLightningModule(pl.LightningModule):
def call_user_optimizers(self, method): def call_user_optimizers(self, method):
""" """
Function that imitates lightning trainer's behaviour of calling user's optimizers. Since auto_optimization is turned off by this Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer. class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
Parameters Parameters
---------- ----------
method : str method : str
Method to call. Only 'step' and 'zero_grad' are supported now. Method to call. Only ``step`` and ``zero_grad`` are supported now.
""" """
def apply_method(optimizer, method): def apply_method(optimizer, method):
if method == 'step': if method == 'step':
...@@ -296,7 +445,7 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -296,7 +445,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
architecture optimizers. architecture optimizers.
""" """
opts = self.optimizers() opts = self.optimizers()
if isinstance(opts,list): if isinstance(opts, list):
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
arc_opts = opts[:self.arc_optim_count] arc_opts = opts[:self.arc_optim_count]
if len(arc_opts) == 1: if len(arc_opts) == 1:
...@@ -310,7 +459,7 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -310,7 +459,7 @@ class BaseOneShotLightningModule(pl.LightningModule):
@property @property
def user_optimizers(self): def user_optimizers(self):
""" """
Get user optimizers from all optimizers. Use this to get user optimizers in ``training step``. Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.
Returns Returns
---------- ----------
...@@ -318,26 +467,10 @@ class BaseOneShotLightningModule(pl.LightningModule): ...@@ -318,26 +467,10 @@ class BaseOneShotLightningModule(pl.LightningModule):
Optimizers defined by user's model. This will be None if there is no user optimizers. Optimizers defined by user's model. This will be None if there is no user optimizers.
""" """
opts = self.optimizers() opts = self.optimizers()
if isinstance(opts,list): if isinstance(opts, list):
# pylint: disable=unsubscriptable-object # pylint: disable=unsubscriptable-object
return opts[self.arc_optim_count:] return opts[self.arc_optim_count:]
# If there is only 1 optimizer and no architecture optimizer # If there is only 1 optimizer and no architecture optimizer
if self.arc_optim_count == 0: if self.arc_optim_count == 0:
return opts return opts
return None return None
def export(self):
"""
Export the NAS result, ideally the best choice of each nas_modules.
You may implement an ``export`` method for your customized nas_module.
Returns
--------
result : Dict[str, int]
Keys are names of nas_modules, and values are the choice indices of them.
"""
result = {}
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from collections import OrderedDict """Experimental version of differentiable one-shot implementation."""
from typing import Optional
from typing import List
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base_lightning import BaseOneShotLightningModule, ReplaceDictType
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.name = layer_choice.label
self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self, recurse=False):
for name, p in super(DartsLayerChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.name = input_choice.label
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
def forward(self, inputs):
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self, recurse=False): from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
for name, p in super(DartsInputChoice, self).named_parameters(): from .supermodule.differentiable import (
if name == 'alpha': DifferentiableMixedLayer, DifferentiableMixedInput,
continue MixedOpDifferentiablePolicy, GumbelSoftmax
yield name, p )
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS
def export(self):
return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
class DartsLightningModule(BaseOneShotLightningModule):
class DartsModule(BaseOneShotLightningModule):
_darts_note = """ _darts_note = """
DARTS :cite:p:`liu2018darts` algorithm is one of the most fundamental one-shot algorithm. DARTS :cite:p:`liu2018darts` algorithm is one of the most fundamental one-shot algorithm.
...@@ -74,6 +26,10 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -74,6 +26,10 @@ class DartsModule(BaseOneShotLightningModule):
The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet. The current implementation is for DARTS in first order. Second order (unrolled) is not supported yet.
*New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.
{{module_notes}} {{module_notes}}
Parameters Parameters
...@@ -82,18 +38,34 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -82,18 +38,34 @@ class DartsModule(BaseOneShotLightningModule):
{base_params} {base_params}
arc_learning_rate : float arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4 Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note) """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _darts_note.format( __doc__ = _darts_note.format(
module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.', module_notes='The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note, module_params=BaseOneShotLightningModule._inner_module_note,
) )
def default_mutation_hooks(self) -> List[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use differentiable strategy for mixed operations."""
return {
'mixed_op_sampling': MixedOpDifferentiablePolicy
}
def __init__(self, inner_module: pl.LightningModule, def __init__(self, inner_module: pl.LightningModule,
custom_replace_dict: Optional[ReplaceDictType] = None, mutation_hooks: List[MutationHook] = None,
arc_learning_rate: float = 3.0E-4): arc_learning_rate: float = 3.0E-4):
super().__init__(inner_module, custom_replace_dict=custom_replace_dict)
self.arc_learning_rate = arc_learning_rate self.arc_learning_rate = arc_learning_rate
super().__init__(inner_module, mutation_hooks=mutation_hooks)
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# grad manually # grad manually
...@@ -105,7 +77,7 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -105,7 +77,7 @@ class DartsModule(BaseOneShotLightningModule):
# phase 1: architecture step # phase 1: architecture step
# The _resample hook is kept for some darts-based NAS methods like proxyless. # The _resample hook is kept for some darts-based NAS methods like proxyless.
# See code of those methods for details. # See code of those methods for details.
self._resample() self.resample()
arc_optim.zero_grad() arc_optim.zero_grad()
arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx) arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
if isinstance(arc_step_loss, dict): if isinstance(arc_step_loss, dict):
...@@ -115,7 +87,7 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -115,7 +87,7 @@ class DartsModule(BaseOneShotLightningModule):
arc_optim.step() arc_optim.step()
# phase 2: model step # phase 2: model step
self._resample() self.resample()
self.call_user_optimizers('zero_grad') self.call_user_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1) loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
w_step_loss = loss_and_metrics['loss'] \ w_step_loss = loss_and_metrics['loss'] \
...@@ -127,178 +99,21 @@ class DartsModule(BaseOneShotLightningModule): ...@@ -127,178 +99,21 @@ class DartsModule(BaseOneShotLightningModule):
return loss_and_metrics return loss_and_metrics
def _resample(self):
# Note: This hook is kept for following darts-based NAS algs.
pass
def finalize_grad(self): def finalize_grad(self):
# Note: This hook is currently kept for Proxyless NAS. # Note: This hook is currently kept for Proxyless NAS.
pass pass
@property
def default_replace_dict(self):
return {
LayerChoice: DartsLayerChoice,
InputChoice: DartsInputChoice
}
def configure_architecture_optimizers(self): def configure_architecture_optimizers(self):
# The alpha in DartsXXXChoices is the architecture parameter of DARTS. All alphas share one optimizer. # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
ctrl_params = {} ctrl_params = []
for _, m in self.nas_modules: for m in self.nas_modules:
if m.name in ctrl_params: ctrl_params += list(m.parameters(arch=True))
assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.' ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
m.alpha = ctrl_params[m.name]
else:
ctrl_params[m.name] = m.alpha
ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999),
weight_decay=1.0E-3) weight_decay=1.0E-3)
return ctrl_optim return ctrl_optim
class _ArchGradientFunction(torch.autograd.Function): class ProxylessLightningModule(DartsLightningModule):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class ProxylessLayerChoice(nn.Module):
    """ProxylessNAS-style layer choice.

    Keeps one architecture weight (``alpha``) per candidate op, plus a one-hot
    "binary gate" parameter. During training, only the op chosen by
    :meth:`resample` is executed; gradients w.r.t. the gates are computed by
    ``_ArchGradientFunction`` and folded into ``alpha`` via
    :meth:`finalize_grad`.
    """

    def __init__(self, ops):
        super().__init__()
        self.ops = nn.ModuleList(ops)
        num_candidates = len(self.ops)
        # architecture weights and the one-hot sampling gates
        self.alpha = nn.Parameter(torch.randn(num_candidates) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(num_candidates) * 1E-3)
        self.sampled = None  # index of the currently sampled op

    def forward(self, *args, **kwargs):
        if not self.training:
            return super().forward(*args, **kwargs)

        def run_function(ops, active_id, **kwargs):
            # closure running only the active candidate
            def forward(_x):
                return ops[active_id](_x, **kwargs)
            return forward

        def backward_function(ops, active_id, binary_gates, **kwargs):
            # closure computing gradients w.r.t. the binary gates by probing
            # every candidate op against the incoming gradient
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(ops)):
                        out_k = _output.data if k == active_id else ops[k](_x.data, **kwargs)
                        binary_grads[k] = torch.sum(out_k * grad_output)
                return binary_grads
            return backward

        assert len(args) == 1
        return _ArchGradientFunction.apply(
            args[0], self._binary_gates, run_function(self.ops, self.sampled, **kwargs),
            backward_function(self.ops, self.sampled, self._binary_gates, **kwargs)
        )

    def resample(self):
        """Sample one candidate from softmax(alpha) and set the one-hot gates."""
        probs = F.softmax(self.alpha, dim=-1)
        chosen = torch.multinomial(probs, 1)[0].item()
        self.sampled = chosen
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[chosen] = 1.0

    def finalize_grad(self):
        """Fold accumulated binary-gate gradients into ``alpha.grad``."""
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            num_candidates = len(self.ops)
            for i in range(num_candidates):
                for j in range(num_candidates):
                    # d(loss)/d(alpha_i) via the softmax Jacobian
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * ((1 if i == j else 0) - probs[i])

    def export(self):
        """Return the index of the strongest candidate."""
        return torch.argmax(self.alpha).item()

    def export_prob(self):
        """Return the softmax distribution over candidates."""
        return F.softmax(self.alpha, dim=-1)
class ProxylessInputChoice(nn.Module):
    """ProxylessNAS-style input choice: selects one tensor among several candidates.

    Mirrors :class:`ProxylessLayerChoice` — one ``alpha`` per candidate input,
    a one-hot binary gate, and manual gate gradients through
    ``_ArchGradientFunction``.
    """

    def __init__(self, input_choice):
        super().__init__()
        # ``input_choice`` is expected to expose ``n_candidates``
        self.num_input_candidates = input_choice.n_candidates
        self.alpha = nn.Parameter(torch.randn(self.num_input_candidates) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(self.num_input_candidates) * 1E-3)
        self.sampled = None  # index of the currently sampled input

    def forward(self, inputs):
        if not self.training:
            return super().forward(inputs)

        def run_function(active_sample):
            # pick the active candidate out of the stacked inputs
            return lambda x: x[active_sample]

        def backward_function(binary_gates):
            # gate gradients: inner product of each candidate with grad_output
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(self.num_input_candidates):
                        binary_grads[k] = torch.sum(_x[k].data * grad_output)
                return binary_grads
            return backward

        stacked_inputs = torch.stack(inputs, 0)
        return _ArchGradientFunction.apply(
            stacked_inputs, self._binary_gates, run_function(self.sampled),
            backward_function(self._binary_gates)
        )

    def resample(self, sample=None):
        """Sample an input index (or force ``sample``) and set the one-hot gates."""
        if sample is None:
            probs = F.softmax(self.alpha, dim=-1)
            sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0
        return self.sampled

    def finalize_grad(self):
        """Fold accumulated binary-gate gradients into ``alpha.grad``."""
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            for i in range(self.num_input_candidates):
                for j in range(self.num_input_candidates):
                    # d(loss)/d(alpha_i) via the softmax Jacobian
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * ((1 if i == j else 0) - probs[i])
class ProxylessModule(DartsModule):
_proxyless_note = """ _proxyless_note = """
Implementation of ProxylessNAS :cite:p:`cai2018proxylessnas`. Implementation of ProxylessNAS :cite:p:`cai2018proxylessnas`.
It's a DARTS-based method that resamples the architecture to reduce memory consumption. It's a DARTS-based method that resamples the architecture to reduce memory consumption.
...@@ -313,54 +128,38 @@ class ProxylessModule(DartsModule): ...@@ -313,54 +128,38 @@ class ProxylessModule(DartsModule):
{base_params} {base_params}
arc_learning_rate : float arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4 Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note) """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _proxyless_note.format( __doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.', module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note, module_params=BaseOneShotLightningModule._inner_module_note,
) )
@property def default_mutation_hooks(self) -> List[MutationHook]:
def default_replace_dict(self): """Replace modules with gumbel-differentiable versions"""
return { hooks = [
LayerChoice: ProxylessLayerChoice, ProxylessMixedLayer.mutate,
InputChoice: ProxylessInputChoice ProxylessMixedInput.mutate,
} no_default_hook,
]
def _resample(self): # FIXME: no support for mixed operation currently
for _, m in self.nas_modules: return hooks
m.resample()
def finalize_grad(self): def finalize_grad(self):
for _, m in self.nas_modules: for m in self.nas_modules:
m.finalize_grad() m.finalize_grad()
class SNASLayerChoice(DartsLayerChoice): class GumbelDartsLightningModule(DartsLightningModule):
def forward(self, *args, **kwargs): _gumbel_darts_note = """
one_hot = F.gumbel_softmax(self.alpha, self.temp)
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
yhat = torch.sum(op_results * one_hot.view(*alpha_shape), 0)
return yhat
class SNASInputChoice(DartsInputChoice):
def forward(self, inputs):
one_hot = F.gumbel_softmax(self.alpha, self.temp)
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
yhat = torch.sum(inputs * one_hot.view(*alpha_shape), 0)
return yhat
class SnasModule(DartsModule):
_snas_note = """
Implementation of SNAS :cite:p:`xie2018snas`. Implementation of SNAS :cite:p:`xie2018snas`.
It's a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution. It's a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
Essentially, it samples one path on forward, Essentially, it samples one path on forward,
and implements its own backward to update the architecture parameters based on only one path. and implements its own backward to update the architecture parameters based on only one path.
*New in v2.8*: Supports searching for ValueChoices on operations, with the technique described in
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
{{module_notes}} {{module_notes}}
Parameters Parameters
...@@ -376,20 +175,32 @@ class SnasModule(DartsModule): ...@@ -376,20 +175,32 @@ class SnasModule(DartsModule):
The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False. The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
arc_learning_rate : float arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4 Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note) """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _snas_note.format( def default_mutation_hooks(self) -> List[MutationHook]:
module_notes='This module should be trained with :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.', """Replace modules with gumbel-differentiable versions"""
module_params=BaseOneShotLightningModule._inner_module_note, hooks = [
) DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use gumbel softmax."""
return {
'mixed_op_sampling': MixedOpDifferentiablePolicy,
'softmax': GumbelSoftmax(),
}
def __init__(self, inner_module, def __init__(self, inner_module,
custom_replace_dict: Optional[ReplaceDictType] = None, mutation_hooks: List[MutationHook] = None,
arc_learning_rate: float = 3.0e-4, arc_learning_rate: float = 3.0e-4,
gumbel_temperature: float = 1., gumbel_temperature: float = 1.,
use_temp_anneal: bool = False, use_temp_anneal: bool = False,
min_temp: float = .33): min_temp: float = .33):
super().__init__(inner_module, custom_replace_dict, arc_learning_rate=arc_learning_rate) super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
self.temp = gumbel_temperature self.temp = gumbel_temperature
self.init_temp = gumbel_temperature self.init_temp = gumbel_temperature
self.use_temp_anneal = use_temp_anneal self.use_temp_anneal = use_temp_anneal
...@@ -400,14 +211,7 @@ class SnasModule(DartsModule): ...@@ -400,14 +211,7 @@ class SnasModule(DartsModule):
self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
self.temp = max(self.temp, self.min_temp) self.temp = max(self.temp, self.min_temp)
for _, nas_module in self.nas_modules: for module in self.nas_modules:
nas_module.temp = self.temp module._softmax.temp = self.temp
return self.model.on_epoch_start() return self.model.on_epoch_start()
@property
def default_replace_dict(self):
return {
LayerChoice: SNASLayerChoice,
InputChoice: SNASInputChoice
}
# Copyright (c) Microsoft Corporation. # Copyright (c) Microsoft Corporation.
# Licensed under the MIT license. # Licensed under the MIT license.
from typing import Dict, Any, Optional """Experimental version of sampling-based one-shot implementation."""
from typing import Dict, Any, List
import random
import pytorch_lightning as pl import pytorch_lightning as pl
import torch import torch
import torch.nn as nn import torch.nn as nn
import torch.optim as optim import torch.optim as optim
from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .random import PathSamplingLayerChoice, PathSamplingInputChoice from .supermodule.sampling import PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy
from .base_lightning import BaseOneShotLightningModule, ReplaceDictType from .supermodule.operation import NATIVE_MIXED_OPERATIONS
from .enas import ReinforceController, ReinforceField from .enas import ReinforceController, ReinforceField
class EnasModule(BaseOneShotLightningModule): class RandomSamplingLightningModule(BaseOneShotLightningModule):
_random_note = """
Random Sampling NAS Algorithm.
In each epoch, model parameters are trained after a uniformly random sampling of each choice.
Notably, the exporting result is **also a random sample** of the search space.
Parameters
----------
{{module_params}}
{base_params}
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _random_note.format(
module_params=BaseOneShotLightningModule._inner_module_note,
)
# turn on automatic optimization because nothing interesting is going on here.
automatic_optimization = True
def default_mutation_hooks(self) -> List[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use path sampling strategy for mixed-operations."""
return {
'mixed_op_sampling': MixedOpPathSamplingPolicy
}
def training_step(self, batch, batch_idx):
self.resample()
return self.model.training_step(batch, batch_idx)
class EnasLightningModule(RandomSamplingLightningModule):
_enas_note = """ _enas_note = """
The implementation of ENAS :cite:p:`pham2018efficient`. There are 2 steps in an epoch. The implementation of ENAS :cite:p:`pham2018efficient`. There are 2 steps in an epoch.
Firstly, training model parameters. Firstly, training model parameters.
...@@ -39,27 +80,34 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -39,27 +80,34 @@ class EnasModule(BaseOneShotLightningModule):
Number of steps that will be aggregated into one mini-batch for RL controller. Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_grad_clip : float ctrl_grad_clip : float
Gradient clipping value of controller. Gradient clipping value of controller.
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note) """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _enas_note.format( __doc__ = _enas_note.format(
module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.', module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
module_params=BaseOneShotLightningModule._inner_module_note, module_params=BaseOneShotLightningModule._inner_module_note,
) )
automatic_optimization = False
def __init__(self, def __init__(self,
inner_module: pl.LightningModule, inner_module: pl.LightningModule,
*,
ctrl_kwargs: Dict[str, Any] = None, ctrl_kwargs: Dict[str, Any] = None,
entropy_weight: float = 1e-4, entropy_weight: float = 1e-4,
skip_weight: float = .8, skip_weight: float = .8,
baseline_decay: float = .999, baseline_decay: float = .999,
ctrl_steps_aggregate: float = 20, ctrl_steps_aggregate: float = 20,
ctrl_grad_clip: float = 0, ctrl_grad_clip: float = 0,
custom_replace_dict: Optional[ReplaceDictType] = None): mutation_hooks: List[MutationHook] = None):
super().__init__(inner_module, custom_replace_dict) super().__init__(inner_module, mutation_hooks)
self.nas_fields = [ReinforceField(name, len(module), # convert parameter spec to legacy ReinforceField
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1) # this part will be refactored
for name, module in self.nas_modules] self.nas_fields: List[ReinforceField] = []
for name, param_spec in self.search_space_spec().items():
if param_spec.chosen_size not in (1, None):
raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
self.nas_fields.append(ReinforceField(name, param_spec.size, param_spec.chosen_size == 1))
self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {})) self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
self.entropy_weight = entropy_weight self.entropy_weight = entropy_weight
...@@ -72,20 +120,13 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -72,20 +120,13 @@ class EnasModule(BaseOneShotLightningModule):
def configure_architecture_optimizers(self): def configure_architecture_optimizers(self):
return optim.Adam(self.controller.parameters(), lr=3.5e-4) return optim.Adam(self.controller.parameters(), lr=3.5e-4)
@property
def default_replace_dict(self):
return {
LayerChoice : PathSamplingLayerChoice,
InputChoice : PathSamplingInputChoice
}
def training_step(self, batch, batch_idx): def training_step(self, batch, batch_idx):
# The ConcatenateTrainValDataloader yields both data and which dataloader it comes from. # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
batch, source = batch batch, source = batch
if source == 'train': if source == 'train':
# step 1: train model params # step 1: train model params
self._resample() self.resample()
self.call_user_optimizers('zero_grad') self.call_user_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(batch, batch_idx) loss_and_metrics = self.model.training_step(batch, batch_idx)
w_step_loss = loss_and_metrics['loss'] \ w_step_loss = loss_and_metrics['loss'] \
...@@ -99,7 +140,7 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -99,7 +140,7 @@ class EnasModule(BaseOneShotLightningModule):
x, y = batch x, y = batch
arc_opt = self.architecture_optimizers arc_opt = self.architecture_optimizers
arc_opt.zero_grad() arc_opt.zero_grad()
self._resample() self.resample()
with torch.no_grad(): with torch.no_grad():
logits = self.model(x) logits = self.model(x)
# use the default metric of self.model as reward function # use the default metric of self.model as reward function
...@@ -107,8 +148,8 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -107,8 +148,8 @@ class EnasModule(BaseOneShotLightningModule):
_, metric = next(iter(self.model.metrics.items())) _, metric = next(iter(self.model.metrics.items()))
else: else:
if 'default' not in self.model.metrics.keys(): if 'default' not in self.model.metrics.keys():
raise KeyError('model.metrics should contain a ``default`` key when' \ raise KeyError('model.metrics should contain a ``default`` key when'
'there are multiple metrics') 'there are multiple metrics')
metric = self.model.metrics['default'] metric = self.model.metrics['default']
reward = metric(logits, y) reward = metric(logits, y)
...@@ -128,60 +169,23 @@ class EnasModule(BaseOneShotLightningModule): ...@@ -128,60 +169,23 @@ class EnasModule(BaseOneShotLightningModule):
arc_opt.step() arc_opt.step()
arc_opt.zero_grad() arc_opt.zero_grad()
def _resample(self): def resample(self):
""" """Resample the architecture with ENAS controller."""
Resample the architecture as ENAS result. This doesn't require an ``export`` method in nas_modules to work. sample = self.controller.resample()
""" result = self._interpret_controller_sampling_result(sample)
result = self.controller.resample() for module in self.nas_modules:
for name, module in self.nas_modules: module.resample(memo=result)
module.sampled = result[name] return result
def export(self): def export(self):
"""Run one more inference of ENAS controller."""
self.controller.eval() self.controller.eval()
with torch.no_grad(): with torch.no_grad():
return self.controller.resample() return self._interpret_controller_sampling_result(self.controller.resample())
def _interpret_controller_sampling_result(self, sample: Dict[str, int]) -> Dict[str, Any]:
class RandomSamplingModule(BaseOneShotLightningModule): """Convert ``{label: index}`` to ``{label: name}``"""
_random_note = """ space_spec = self.search_space_spec()
Random Sampling NAS Algorithm. for key in list(sample.keys()):
In each epoch, model parameters are trained after a uniformly random sampling of each choice. sample[key] = space_spec[key].values[sample[key]]
Notably, the exporting result is **also a random sample** of the search space. return sample
Parameters
----------
{{module_params}}
{base_params}
""".format(base_params=BaseOneShotLightningModule._custom_replace_dict_note)
__doc__ = _random_note.format(
module_params=BaseOneShotLightningModule._inner_module_note,
)
automatic_optimization = True
def training_step(self, batch, batch_idx):
self._resample()
return self.model.training_step(batch, batch_idx)
@property
def default_replace_dict(self):
return {
LayerChoice : PathSamplingLayerChoice,
InputChoice : PathSamplingInputChoice
}
def _resample(self):
"""
Resample the architecture as RandomSample result. This is simply a uniformly sampling that doesn't require an ``export``
method in nas_modules to work.
"""
result = {}
for name, module in self.nas_modules:
if name not in result:
result[name] = random.randint(0, len(module) - 1)
module.sampled = result[name]
return result
def export(self):
return self._resample()
...@@ -6,6 +6,8 @@ ...@@ -6,6 +6,8 @@
This file is put here simply because it relies on "pytorch". This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``. For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course). For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
""" """
import warnings import warnings
...@@ -19,8 +21,8 @@ from nni.retiarii.strategy.base import BaseStrategy ...@@ -19,8 +21,8 @@ from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsModule, ProxylessModule, SnasModule from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasModule, RandomSamplingModule from .sampling import EnasLightningModule, RandomSamplingLightningModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
...@@ -80,50 +82,50 @@ class OneShotStrategy(BaseStrategy): ...@@ -80,50 +82,50 @@ class OneShotStrategy(BaseStrategy):
class DARTS(OneShotStrategy): class DARTS(OneShotStrategy):
__doc__ = DartsModule._darts_note.format(module_notes='', module_params='') __doc__ = DartsLightningModule._darts_note.format(module_notes='', module_params='')
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(DartsModule, **kwargs) super().__init__(DartsLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders): def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders) return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
class Proxyless(OneShotStrategy): class Proxyless(OneShotStrategy):
__doc__ = ProxylessModule._proxyless_note.format(module_notes='', module_params='') __doc__ = ProxylessLightningModule._proxyless_note.format(module_notes='', module_params='')
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(EnasModule, **kwargs) super().__init__(ProxylessLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders): def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders) return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
class SNAS(OneShotStrategy): class GumbelDARTS(OneShotStrategy):
__doc__ = SnasModule._snas_note.format(module_notes='', module_params='') __doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(module_notes='', module_params='')
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(SnasModule, **kwargs) super().__init__(GumbelDartsLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders): def _get_dataloader(self, train_dataloader, val_dataloaders):
return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders) return InterleavedTrainValDataLoader(train_dataloader, val_dataloaders)
class ENAS(OneShotStrategy): class ENAS(OneShotStrategy):
__doc__ = EnasModule._enas_note.format(module_notes='', module_params='') __doc__ = EnasLightningModule._enas_note.format(module_notes='', module_params='')
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(EnasModule, **kwargs) super().__init__(EnasLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders): def _get_dataloader(self, train_dataloader, val_dataloaders):
return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders) return ConcatenateTrainValDataLoader(train_dataloader, val_dataloaders)
class RandomOneShot(OneShotStrategy): class RandomOneShot(OneShotStrategy):
__doc__ = RandomSamplingModule._random_note.format(module_notes='', module_params='') __doc__ = RandomSamplingLightningModule._random_note.format(module_notes='', module_params='')
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(RandomSamplingModule, **kwargs) super().__init__(RandomSamplingLightningModule, **kwargs)
def _get_dataloader(self, train_dataloader, val_dataloaders): def _get_dataloader(self, train_dataloader, val_dataloaders):
return train_dataloader, val_dataloaders return train_dataloader, val_dataloaders
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Thie file handles "slice" commonly used in mixed-operation.
The ``slice_type`` we support here, is "slice" or "list of slice".
The reason is that sometimes (e.g., in multi-head attention),
the tensor slice could be from multiple parts. This type is extensible.
We can support arbitrary masks in future if we need them.
To slice a tensor, we need ``multidim_slice``,
which is simply a tuple consists of ``slice_type``.
Usually in python programs, the variable put into slice's start, stop and step
should be integers (or NoneType).
But in our case, it could also be a dict from integer to float,
representing a distribution over integers. When that happens,
we convert a "slice with some weighted values" to a "weighted slice".
To this end, we track the computation with ``MaybeWeighted``,
and replay the computation with each possible value.
Meanwhile, we record their weights.
Note that ``MaybeWeighted`` is also extensible.
We can support more types of objects on slice in future.
The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
"""
import operator
from typing import Tuple, Union, List, Dict, Callable, Optional, Iterator, TypeVar, Any, Generic
import numpy as np
import torch
# Generic weight type: in practice an ``np.ndarray`` or ``torch.Tensor``.
T = TypeVar('T')

# A slice on one dimension: either a single ``slice`` or a list of slices
# (e.g., multi-head attention takes several chunks of one dimension).
slice_type = Union[slice, List[slice]]
# One ``slice_type`` per dimension of the tensor being sliced.
multidim_slice = Tuple[slice_type, ...]

# Either a plain scalar, or a dict mapping candidate scalar -> probability weight.
scalar_or_scalar_dict = Union[T, Dict[T, float]]
int_or_int_dict = scalar_or_scalar_dict[int]

# Callback mapping a (possibly weighted) int leaf to a concrete int during evaluation.
_value_fn_type = Optional[Callable[[int_or_int_dict], int]]
def zeros_like(arr: 'T') -> 'T':
    """Return an all-zero array/tensor matching the shape and dtype of ``arr``.

    Dispatches to :func:`numpy.zeros_like` or :func:`torch.zeros_like`;
    any other input type raises ``TypeError``.
    """
    if isinstance(arr, torch.Tensor):
        return torch.zeros_like(arr)
    if isinstance(arr, np.ndarray):
        return np.zeros_like(arr)
    raise TypeError(f'Unsupported type for {arr}: {type(arr)}')
def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slice:
# get rid of list of slice
result = []
for i in range(len(slice_)):
if isinstance(slice_[i], list):
# convert list of slices to mask
mask = np.zeros(shape[i], dtype=np.bool)
for sl in slice_[i]:
mask[sl] = 1
result.append(mask)
else:
result.append(slice_[i])
return tuple(result)
def _slice_weight(weight: 'T', slice_: 'Union[multidim_slice, List[Tuple[multidim_slice, float]]]') -> 'T':
    """Slice ``weight`` with either a fixed or a weighted multi-dim slice.

    ``slice_`` can be a tuple of slices, e.g., ``([3:6], [2:4])``,
    or a list of (tuple of slices, weight) pairs, e.g.,
    ``[(([3:6],), 0.6), (([2:4],), 0.3)]``. A list of pairs is used instead of
    a dict because slices are unhashable.

    For the weighted case, each candidate slice becomes a mask scaled by its
    weight; the summed mask is broadcast-multiplied onto the weight, e.g.,
    ``{([3:6],): 0.6, ([2:4],): 0.3} => [0, 0, 0.3, 0.9, 0.6, 0.6]`` (length 6).
    For the unweighted case, the weight is sliced directly, skipping the op
    entirely when the slice has no effect.
    """
    if isinstance(slice_, list):
        masks = []
        for sl, wt in slice_:
            # create a mask with weight wt
            with torch.no_grad():
                mask = zeros_like(weight)
                mask[_eliminate_list_slice(weight.shape, sl)] = 1
            # multiplied outside no_grad so gradients can track the weight wt
            masks.append(mask * wt)
        masks = sum(masks)
        return masks * weight

    # for unweighted case, we slice it directly.
    def _do_slice(arr, slice_):
        return arr[_eliminate_list_slice(arr.shape, slice_)]

    # sometimes, we don't need slice.
    # this saves an op on the computational graph, which will hopefully make training faster.
    # Use a dummy array to check this; otherwise it would be too complex.
    # NOTE: ``np.bool`` was removed in NumPy 1.24; use the builtin ``bool``.
    dummy_arr = np.zeros(weight.shape, dtype=bool)
    no_effect = _do_slice(dummy_arr, slice_).shape == dummy_arr.shape

    if no_effect:
        return weight
    return _do_slice(weight, slice_)
class Slicable(Generic[T]):
    """Wraps the weight so that it can be sliced with a ``multidim_slice``.

    The value within the slice can be instances of :class:`MaybeWeighted`.

    Examples
    --------
    >>> weight = conv2d.weight
    >>> Slicable(weight)[:MaybeWeighted({32: 0.4, 64: 0.6})]
    Tensor of shape (64, 64, 3, 3)
    """

    def __init__(self, weight: 'T'):
        # only numpy arrays and torch tensors are supported
        if not isinstance(weight, np.ndarray) and not torch.is_tensor(weight):
            # fixed typo in error message ("Unsuppoted" -> "Unsupported")
            raise TypeError(f'Unsupported weight type: {type(weight)}')
        self.weight = weight

    def __getitem__(self, index: 'multidim_slice') -> 'T':
        if not isinstance(index, tuple):
            index = (index, )

        # Get the dict value in index's leafs.
        # There can be at most one dict (i.e., one weighted dimension at a time).
        leaf_dict: Optional[Dict[int, float]] = None
        for maybe_weighted in _iterate_over_multidim_slice(index):
            for d in maybe_weighted.leaf_values():
                if isinstance(d, dict):
                    if leaf_dict is None:
                        leaf_dict = d
                    elif leaf_dict is not d:
                        raise ValueError('There can be at most one distinct dict in leaf values.')

        if leaf_dict is None:
            # in case of simple types with no dict
            res_index = _evaluate_multidim_slice(index)
        else:
            # there is a dict; evaluate the slice once per candidate value,
            # pairing each concrete slice with its probability weight
            res_index = []
            for val, wt in leaf_dict.items():
                # bind ``val`` as a default argument to avoid the classic
                # late-binding-closure pitfall if evaluation is ever deferred
                res_index_item = _evaluate_multidim_slice(index, lambda _, v=val: v)
                res_index.append((res_index_item, wt))

        return _slice_weight(self.weight, res_index)
class MaybeWeighted:
    """Wrap a value (int or dict with int keys), so that the computation on it can be replayed.

    It builds a binary tree. If ``value`` is not None, it's a leaf node.
    Otherwise, it has left sub-tree and right sub-tree and an operation.

    Only support basic arithmetic operations: ``+``, ``-``, ``*``, ``//``.
    """

    def __init__(self,
                 value: 'Optional[int_or_int_dict]' = None, *,
                 lhs: 'Optional[Union["MaybeWeighted", int]]' = None,
                 rhs: 'Optional[Union["MaybeWeighted", int]]' = None,
                 operation: 'Optional[Callable[[int, int], int]]' = None):
        # A leaf node carries ``value``; an inner node carries lhs/rhs/operation.
        if operation is None and not isinstance(value, (int, dict)):
            raise TypeError(f'Unsupported value type: {type(value)}')
        self.value = value
        self.lhs = lhs
        self.rhs = rhs
        self.operation = operation

    def leaf_values(self) -> 'Iterator[Dict[int, float]]':
        """Iterate over values on leaf nodes."""
        if self.value is not None:
            yield self.value
            return
        for child in (self.lhs, self.rhs):
            if isinstance(child, MaybeWeighted):
                yield from child.leaf_values()

    def evaluate(self, value_fn: '_value_fn_type' = None) -> int:
        """Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.

        If ``value_fn`` is none, no replacement will happen and the raw value will be used.
        """
        if self.value is not None:
            # leaf: optionally transform the raw value
            return value_fn(self.value) if value_fn is not None else self.value

        def _resolve(node):
            # recurse into sub-trees; plain ints pass through unchanged
            return node.evaluate(value_fn) if isinstance(node, MaybeWeighted) else node

        return self.operation(_resolve(self.lhs), _resolve(self.rhs))

    def __repr__(self):
        if self.value is not None:
            return f'{self.__class__.__name__}({self.value})'
        return f'{self.__class__.__name__}(lhs={self.lhs}, rhs={self.rhs}, op={self.operation})'

    # Arithmetic operators build inner nodes instead of computing immediately,
    # so the whole expression can be replayed later via ``evaluate``.

    def __add__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.add)

    def __radd__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.add)

    def __sub__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.sub)

    def __rsub__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.sub)

    def __mul__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.mul)

    def __rmul__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.mul)

    def __floordiv__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=self, rhs=other, operation=operator.floordiv)

    def __rfloordiv__(self, other: 'Any') -> 'MaybeWeighted':
        return MaybeWeighted(lhs=other, rhs=self, operation=operator.floordiv)
def _iterate_over_slice_type(s: 'slice_type'):
    """Yield every :class:`MaybeWeighted` appearing in ``s`` (a slice or list of slices)."""
    if isinstance(s, list):
        for sub_slice in s:
            yield from _iterate_over_slice_type(sub_slice)
        return
    # s must be a plain "slice" now; inspect start / stop / step in order
    for field in (s.start, s.stop, s.step):
        if isinstance(field, MaybeWeighted):
            yield field
def _iterate_over_multidim_slice(ms: 'multidim_slice'):
    """Get :class:`MaybeWeighted` instances in ``ms``."""
    for dim_slice in ms:
        if dim_slice is None:
            continue
        yield from _iterate_over_slice_type(dim_slice)
def _evaluate_slice_type(s: 'slice_type', value_fn: '_value_fn_type' = None):
    """Evaluate ``s`` into a concrete ``slice`` (or list of slices),
    resolving any :class:`MaybeWeighted` bounds via ``value_fn``."""
    if isinstance(s, list):
        return [_evaluate_slice_type(element, value_fn) for element in s]

    def _concrete(bound):
        # resolve a MaybeWeighted bound; pass plain ints / None through
        return bound.evaluate(value_fn) if isinstance(bound, MaybeWeighted) else bound

    return slice(_concrete(s.start), _concrete(s.stop), _concrete(s.step))
def _evaluate_multidim_slice(ms: 'multidim_slice', value_fn: '_value_fn_type' = None):
    """Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
    return tuple(
        _evaluate_slice_type(dim_slice, value_fn) if dim_slice is not None else None
        for dim_slice in ms
    )
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. The complete support is an undergoing work item.
Keep this file here so that it can be "blamed".
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.nn.pytorch import ValueChoice
class DifferentiableSuperConv2d(nn.Conv2d):
    """
    Differentiable super conv2d (single-path NAS style).

    Only ``kernel_size``, ``in_channels`` and ``out_channels`` may be a :class:`ValueChoice`.
    Kernel size candidates should be larger or smaller than each other on both dimensions.
    See examples below:

    the following example is not allowed:
        >>> ValueChoice(candidates = [(5, 3), (3, 5)])
            # candidates are not bigger or smaller on both dimensions

    the following 3 examples are valid:
        >>> ValueChoice(candidates = [5, 3, 1])
        >>> ValueChoice(candidates = [(5, 7), (3, 5), (1, 3)])
        >>> # when the difference between any two candidates is not even, the left upper corner will be picked:
        >>> ValueChoice(candidates = [(5, 5), (4, 4), (3, 3)])
    """
    def __init__(self, module, name):
        self.label = name
        # NOTE(review): ``trace_kwargs`` is mutated in place below (max candidate substituted)
        # — confirm callers do not rely on the original dict.
        args = module.trace_kwargs

        # compulsory params: instantiate the conv with the *largest* candidate everywhere
        if isinstance(args['in_channels'], ValueChoice):
            args['in_channels'] = max(args['in_channels'].candidates)

        self.out_channel_candidates = None
        if isinstance(args['out_channels'], ValueChoice):
            self.out_channel_candidates = sorted(args['out_channels'].candidates, reverse=True)
            args['out_channels'] = self.out_channel_candidates[0]

        # kernel_size may be an int or tuple, we turn it into a tuple for simplicity
        self.kernel_size_candidates = None
        if isinstance(args['kernel_size'], ValueChoice):
            candidates = args['kernel_size'].candidates
            if not isinstance(candidates[0], tuple):
                candidates = [(k, k) for k in candidates]

            # sort kernel sizes in descending order and validate total ordering on both dims
            self.kernel_size_candidates = sorted(candidates, key=lambda t: t[0], reverse=True)
            for i in range(0, len(self.kernel_size_candidates) - 1):
                bigger = self.kernel_size_candidates[i]
                smaller = self.kernel_size_candidates[i + 1]
                assert bigger[1] > smaller[1] or (bigger[1] == smaller[1] and bigger[0] > smaller[0]), f'Kernel_size candidates ' \
                    f'should be larger or smaller than each other on both dimensions, but found {bigger} and {smaller}.'
            args['kernel_size'] = self.kernel_size_candidates[0]

        super().__init__(**args)
        self.generate_architecture_params()

    def forward(self, input):
        # Note that there is no need to handle ``in_channels`` here since it is already handled
        # by the ``out_channels`` of the previous module. If we multiplied alpha with respect to
        # ``in_channels`` here again, the alpha would be considered twice, which is not what we expect.
        weight = self.weight

        def sum_weight(input_weight, masks, thresholds, indicator):
            """
            Weighted sum of the masked parts of ``input_weight``.

            Parameters
            ----------
            input_weight : Tensor
                the weight to be weighted summed
            masks : List[Tensor]
                weight masks.
            thresholds : List[float]
                thresholds, should have a length of ``len(masks) - 1``
            indicator : Callable[[Tensor, float], float]
                take a tensor and a threshold as input, and output the weight

            Returns
            ----------
            weight : Tensor
                weighted sum of ``input_weight``, of the same shape as ``input_weight``
            """
            # ``masks`` and ``thresholds`` have different lengths. Their alignment is:
            #   self.xxx_candidates = [ c_0    , c_1    , ... , c_n-2   , c_n-1   ]  # descending order
            #   self.xxx_mask       = [ mask_0 , mask_1 , ... , mask_n-2, mask_n-1]
            #   self.t_xxx          = [ t_0    , t_1    , ... , t_n-2   ]
            # So we zip the first n-1 items, and add masks[-1] unconditionally in the end.
            weight = torch.zeros_like(input_weight)
            for mask, t in zip(masks[:-1], thresholds):
                cur_part = input_weight * mask
                alpha = indicator(cur_part, t)
                weight = (weight + cur_part) * alpha
            # we do not consider skip-op here for out_channel/expansion candidates, which means
            # at least the smallest candidate is always included
            weight += input_weight * masks[-1]
            return weight

        if self.kernel_size_candidates is not None:
            weight = sum_weight(weight, self.kernel_masks, self.t_kernel, self.Lasso_sigmoid)
        if self.out_channel_candidates is not None:
            weight = sum_weight(weight, self.channel_masks, self.t_expansion, self.Lasso_sigmoid)

        output = self._conv_forward(input, weight, self.bias)
        return output

    def parameters(self):
        """Parameters, filtered through :meth:`named_parameters`."""
        for _, p in self.named_parameters():
            yield p

    def named_parameters(self):
        """Named parameters excluding (the container named) ``alpha``.

        NOTE(review): ``self.alpha`` is a plain dict attribute, so no parameter is actually
        named ``'alpha'``; the thresholds are registered as ``t_kernel`` / ``t_expansion``
        and are therefore still yielded here — confirm whether they should be filtered too.
        """
        for name, p in super().named_parameters():
            if name == 'alpha':
                continue
            yield name, p

    def export(self):
        """
        Pick the best candidate for each :class:`ValueChoice` argument.

        Returns a dict like ``{'kernel_size': <candidate>, 'out_channels': <candidate>}``, where
        each value is the chosen candidate itself (candidates are examined in descending order).
        """
        result = {}
        eps = 1e-5
        with torch.no_grad():
            if self.kernel_size_candidates is not None:
                weight = torch.zeros_like(self.weight)
                # iterate candidates in ascending size (i descends over the descending-sorted list)
                for i in range(len(self.kernel_size_candidates) - 2, -1, -1):
                    mask = self.kernel_masks[i]
                    t = self.t_kernel[i]
                    cur_part = self.weight * mask
                    alpha = self.Lasso_sigmoid(cur_part, t)
                    if alpha <= eps:  # takes the smaller one
                        result['kernel_size'] = self.kernel_size_candidates[i + 1]
                        break
                    weight = (weight + cur_part) * alpha
                if 'kernel_size' not in result:
                    result['kernel_size'] = self.kernel_size_candidates[0]
            else:
                weight = self.weight

            if self.out_channel_candidates is not None:
                for i in range(len(self.out_channel_candidates) - 2, -1, -1):
                    mask = self.channel_masks[i]
                    t = self.t_expansion[i]
                    alpha = self.Lasso_sigmoid(weight * mask, t)
                    if alpha <= eps:
                        result['out_channels'] = self.out_channel_candidates[i + 1]
                        # FIX: stop at the first candidate that fails the threshold, mirroring the
                        # kernel_size branch above; previously the loop kept running and later
                        # iterations silently overwrote the decision with larger candidates.
                        break
                if 'out_channels' not in result:
                    result['out_channels'] = self.out_channel_candidates[0]

        return result

    @staticmethod
    def Lasso_sigmoid(matrix, t):
        """
        A trick that combines the value of ``bool(lasso > t)`` with the gradient of
        ``sigmoid(lasso - t)`` (a straight-through estimator).

        Parameters
        ----------
        matrix : Tensor
            the matrix to calculate lasso norm
        t : float
            the threshold
        """
        lasso = torch.norm(matrix) - t
        indicator = (lasso > 0).float()  # hard 0/1 value in the forward pass
        # subtract the sigmoid with gradients disabled, then add it back with gradients enabled:
        # the value is unchanged but grad(indicator) == grad(sigmoid(lasso)).
        with torch.no_grad():
            indicator -= torch.sigmoid(lasso)  # F.sigmoid is deprecated; torch.sigmoid is identical
        indicator += torch.sigmoid(lasso)
        return indicator

    def generate_architecture_params(self):
        """Create the threshold architecture parameters and the per-candidate weight masks."""
        self.alpha = {}
        if self.kernel_size_candidates is not None:
            # kernel size thresholds: one between every two adjacent candidates
            self.t_kernel = nn.Parameter(torch.rand(len(self.kernel_size_candidates) - 1))
            self.alpha['kernel_size'] = self.t_kernel
            # Kernel size masks. mask[i] covers the "ring" that candidate i adds around
            # candidate i+1. E.g. with weight shape (out, in, 7, 7), big_size = (5, 5) and
            # small_size = (3, 3), the mask looks like:
            #   0 0 0 0 0 0 0
            #   0 1 1 1 1 1 0
            #   0 1 0 0 0 1 0
            #   0 1 0 0 0 1 0
            #   0 1 0 0 0 1 0
            #   0 1 1 1 1 1 0
            #   0 0 0 0 0 0 0
            # NOTE(review): masks are plain tensors in Python lists, so they do not follow the
            # module across ``.to(device)`` or into ``state_dict`` — confirm this is acceptable.
            self.kernel_masks = []
            for i in range(0, len(self.kernel_size_candidates) - 1):
                big_size = self.kernel_size_candidates[i]
                small_size = self.kernel_size_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                mask[:, :, :big_size[0], :big_size[1]] = 1
                mask[:, :, :small_size[0], :small_size[1]] = 0
                self.kernel_masks.append(mask)
            # the innermost (smallest) kernel is always included
            mask = torch.zeros_like(self.weight)
            mask[:, :, :self.kernel_size_candidates[-1][0], :self.kernel_size_candidates[-1][1]] = 1
            self.kernel_masks.append(mask)

        if self.out_channel_candidates is not None:
            # out_channel (or expansion) arch params. we do not consider skip-op here, so we
            # only generate ``len(self.out_channel_candidates) - 1`` thresholds
            self.t_expansion = nn.Parameter(torch.rand(len(self.out_channel_candidates) - 1))
            self.alpha['out_channels'] = self.t_expansion
            self.channel_masks = []
            for i in range(0, len(self.out_channel_candidates) - 1):
                big_channel, small_channel = self.out_channel_candidates[i], self.out_channel_candidates[i + 1]
                mask = torch.zeros_like(self.weight)
                mask[:big_channel] = 1
                mask[:small_channel] = 0
                # if self.weight.shape = (32, in, W, H), big_channel = 16 and small_channel = 8,
                # the mask (along dim 0) looks like:
                # 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
                self.channel_masks.append(mask)
            # the smallest channel candidate is always included
            mask = torch.zeros_like(self.weight)
            mask[:self.out_channel_candidates[-1]] = 1
            self.channel_masks.append(mask)
class DifferentiableBatchNorm2d(nn.BatchNorm2d):
    """BatchNorm2d instantiated with the largest ``num_features`` candidate.

    No real architecture parameter is attached — ``alpha`` is an empty placeholder.
    """

    def __init__(self, module, name):
        self.label = name
        kwargs = module.trace_kwargs
        num_features = kwargs['num_features']
        if isinstance(num_features, ValueChoice):
            # pick the largest candidate (note: this writes back into trace_kwargs)
            kwargs['num_features'] = max(num_features.candidates)
        super().__init__(**kwargs)

        # no architecture parameter is needed for BatchNorm2d layers
        self.alpha = nn.Parameter(torch.tensor([]))

    def export(self):
        """
        No need to export ``BatchNorm2d``. Refer to the ``Conv2d`` layer that has the ``ValueChoice`` as ``out_channels``.
        """
        return -1
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""
import itertools
from typing import List, Any, Dict, Tuple, Optional, Union
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ValueChoiceX
Choice = Any
__all__ = ['dedup_inner_choices', 'evaluate_value_choice_with_dict', 'traverse_all_options']
def dedup_inner_choices(value_choices: List[ValueChoiceX]) -> Dict[str, ParameterSpec]:
    """Collect every leaf choice inside ``value_choices`` as ``{label: parameter_spec}``.

    Raises
    ------
    ValueError
        If two leaves share a label but disagree on their candidates.
    """
    specs: Dict[str, ParameterSpec] = {}
    for composition in value_choices:
        for leaf in composition.inner_choices():
            spec = ParameterSpec(leaf.label, 'choice', leaf.candidates, (leaf.label, ), True, size=len(leaf.candidates))
            existing = specs.get(leaf.label)
            if existing is None:
                specs[leaf.label] = spec
            elif existing != spec:
                raise ValueError('Value choice conflict: same label with different candidates: '
                                 f'{spec} vs. {existing}')
    return specs
def evaluate_value_choice_with_dict(value_choice: ValueChoiceX, chosen: Dict[str, Choice]) -> Any:
    """Evaluate a value-choice composition given ``{label: chosen_value}``.

    Two passes: gather the chosen value of every inner choice, then delegate to
    ``value_choice.evaluate``. This can be potentially optimized in terms of speed.

    Examples
    --------
    >>> chosen = {"exp_ratio": 3}
    >>> evaluate_value_choice_with_dict(value_choice_in, chosen)
    48
    >>> evaluate_value_choice_with_dict(value_choice_out, chosen)
    96
    """
    def _lookup(label: str) -> Choice:
        # fail loudly with context if a leaf's label is missing from ``chosen``
        if label not in chosen:
            raise KeyError(f'{value_choice} depends on a value with key {label}, but not found in {chosen}')
        return chosen[label]

    return value_choice.evaluate([_lookup(leaf.label) for leaf in value_choice.inner_choices()])
def traverse_all_options(value_choice: ValueChoiceX,
                         weights: Optional[Dict[str, List[float]]] = None) -> List[Union[Tuple[Any, float], Any]]:
    """Enumerate every possible outcome of a value-choice composition.

    If ``weights`` is given, the (joint) probability of each outcome is computed too.

    Parameters
    ----------
    value_choice : ValueChoiceX
        The value choice to traverse.
    weights : Optional[Dict[str, List[float]]], default = None
        Prior over the leaf choices: maps label to one float per candidate.
        They normally sum to 1, but this is not checked here.

    Returns
    -------
    List[Union[Tuple[Any, float], Any]]
        Sorted, de-duplicated outcomes. With ``weights``, each element is an
        ``(option, weight)`` tuple; otherwise just the options.
    """
    # {label: [(candidate, weight), ...]}
    leafs: Dict[str, List[Tuple[Choice, float]]] = {}
    for label, param_spec in dedup_inner_choices([value_choice]).items():
        if weights is None:
            # dummy zero weights keep the data layout uniform when no prior is given
            leafs[label] = [(v, 0.) for v in param_spec.values]
        else:
            if label not in weights:
                raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
            if len(weights[label]) != param_spec.size:
                raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
            leafs[label] = list(zip(param_spec.values, weights[label]))

    if not leafs:
        raise ValueError(f'There expects at least one leaf value choice in {value_choice}, but nothing found')

    # maps an outcome to its accumulated weight (None when no weights were supplied)
    result: Dict[Any, Optional[float]] = {}
    labels = list(leafs.keys())
    for combination in itertools.product(*leafs.values()):
        # e.g. combination = ((3, 0.1), ("cat", 0.3), ({"in": 5}, 0.5)):
        # first element of each pair is the chosen value, second is its probability
        assignment = {label: candidate for label, (candidate, _) in zip(labels, combination)}
        outcome = evaluate_value_choice_with_dict(value_choice, assignment)

        if weights is None:
            result[outcome] = None
            continue

        # multiply step by step: weights may be tensors, so no reduce / in-place product
        combined = None
        for _, w in combination:
            combined = w if combined is None else combined * w
        if outcome in result:
            result[outcome] = result[outcome] + combined
        else:
            result[outcome] = combined

    if weights is None:
        return sorted(result.keys())
    return sorted(result.items())
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Any, Dict, Tuple, Union
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
class BaseSuperNetModule(nn.Module):
    """
    A mutated module living inside a super-net.

    The feed-forward of the module itself is usually undefined until ``resample()``
    selects a specific path. (Some variants, e.g. differentiable super-nets, do not
    require resampling.)

    A super-net module usually corresponds to one sample, with two exceptions:

    * One module can own several parameter specs — e.g. a convolution-2d sampling
      kernel size and channels at the same time.
    * Several modules can share one parameter spec — e.g. multiple layer choices
      carrying the same label.

    For value choice compositions, parameter specs are bound to the underlying
    (original) value choices, not to their compositions.
    """

    def resample(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Re-draw a sample for this super-net module.

        Parameters
        ----------
        memo : Dict[str, Any]
            Mapping from label to sampled value, used to keep samples with the
            same label consistent across modules.

        Returns
        -------
        dict
            The newly sampled values; an empty dict if nothing new was sampled.
        """
        raise NotImplementedError()

    def export(self, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Export the final architecture chosen within this module.

        Keys are the same as those of ``search_space_spec()``.

        Parameters
        ----------
        memo : Dict[str, Any]
            Labels exported so far; used so the same label is not exported twice.
        """
        raise NotImplementedError()

    def search_space_spec(self) -> Dict[str, ParameterSpec]:
        """Describe this module's sample points.

        Returns a mapping from spec name to :class:`ParameterSpec`; names use the
        same format as ``export``. For example: ::

            {"layer1": ParameterSpec(values=["conv", "pool"])}
        """
        raise NotImplementedError()

    @classmethod
    def mutate(cls, module: nn.Module, name: str, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> \
        Union['BaseSuperNetModule', bool, Tuple['BaseSuperNetModule', bool]]:
        """Mutation hook that turns ``module`` into a :class:`BaseSuperNetModule`.

        Implemented by each concrete super-net module, since each has its own rule
        about which kinds of modules it operates on.

        Parameters
        ----------
        module : nn.Module
            The module to be mutated (replaced).
        name : str
            Full dotted name of this module, e.g., ``module1.block1.conv``.
        memo : dict
            Shared dict enabling parameter sharing among mutated modules; mutate
            functions both read and write it.
        mutate_kwargs : dict
            Algorithm-related hyper-parameters and auxiliary information.

        Returns
        -------
        Union[BaseSuperNetModule, bool, Tuple[BaseSuperNetModule, bool]]
            The mutation result, optionally paired with a flag that suppresses
            follow-up mutation hooks.
            See :class:`nni.retiarii.oneshot.pytorch.base.BaseOneShotLightningModule` for details.
        """
        raise NotImplementedError()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import functools
import warnings
from typing import List, Tuple, Optional, Dict, Any, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from ._valuechoice_utils import traverse_all_options
class GumbelSoftmax(nn.Softmax):
    """Softmax with Gumbel reparameterization (wraps ``F.gumbel_softmax``).

    ``dim`` defaults to -1; ``tau`` (temperature) and ``hard`` (straight-through
    one-hot) are plain attributes that can be adjusted after construction.
    """

    def __init__(self, dim: Optional[int] = -1) -> None:
        super().__init__(dim)
        self.tau = 1
        self.hard = False

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        sampled = F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
        return sampled
class DifferentiableMixedLayer(BaseSuperNetModule):
    """
    Mixed layer whose forward pass is a softmax-weighted sum over candidate layers.

    Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.
    The weight ``alpha`` is usually learnable, and optimized on validation dataset.

    All candidate operators must return tensors of the same shape for one input,
    because the outputs are summed elementwise after weighting.

    Parameters
    ----------
    paths : List[Tuple[str, nn.Module]]
        Layers to choose from. Each is a tuple of name, and its module.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    op_names : str
        Operator names.
    label : str
        Name of the choice.
    """

    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.op_names = []
        if len(alpha) != len(paths):
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
        for op_name, op_module in paths:
            self.add_module(op_name, op_module)
            self.op_names.append(op_name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        self.label = label
        self._arch_alpha = alpha
        self._softmax = softmax

    def resample(self, memo):
        """No-op: a differentiable layer needs no resampling."""
        return {}

    def export(self, memo):
        """Select the operator whose architecture weight (logit) is largest."""
        if self.label in memo:
            return {}  # nothing new to export
        best_index = torch.argmax(self._arch_alpha).item()
        return {self.label: self.op_names[best_index]}

    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Replace a :class:`LayerChoice` with this mixed layer, sharing alpha via ``memo``."""
        if not isinstance(module, LayerChoice):
            return None
        n_candidates = len(module)
        if module.label in memo:
            alpha = memo[module.label]
            if len(alpha) != n_candidates:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {n_candidates}')
        else:
            alpha = nn.Parameter(torch.randn(n_candidates) * 1E-3)  # this can be reinitialized later
        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(list(module.named_children()), alpha, softmax, module.label)

    def forward(self, *args, **kwargs):
        """Accepts the sub-layers' arguments; returns the softmax-weighted sum of their outputs."""
        candidate_outputs = torch.stack([getattr(self, op)(*args, **kwargs) for op in self.op_names])
        broadcast_shape = [-1] + [1] * (candidate_outputs.dim() - 1)
        probs = self._softmax(self._arch_alpha).view(*broadcast_shape)
        return torch.sum(candidate_outputs * probs, 0)

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters; pass ``arch=True`` to get architecture parameters instead of weights."""
        arch = kwargs.pop('arch', False)
        for param_name, param in super().named_parameters(*args, **kwargs):
            if (param_name in self._arch_parameter_names) == arch:
                yield param_name, param
class DifferentiableMixedInput(BaseSuperNetModule):
    """
    Mixed input. Forward returns a weighted sum of candidates.
    Implementation is very similar to :class:`DifferentiableMixedLayer`.

    Parameters
    ----------
    n_candidates : int
        Expect number of input candidates.
    n_chosen : int
        Expect number of inputs finally chosen.
    alpha : Tensor
        Tensor that stores the "learnable" weights.
    softmax : nn.Module
        Customizable softmax function. Usually ``nn.Softmax(-1)``.
    label : str
        Name of the choice.

    Attributes
    ----------
    label : str
        Name of the choice.
    """

    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__()
        self.n_candidates = n_candidates
        if len(alpha) != n_candidates:
            raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
        if n_chosen is None:
            warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
                          RuntimeWarning)
            # BUGFIX: the fallback to 1 used to be clobbered by an unconditional
            # ``self.n_chosen = n_chosen`` right after this branch, leaking ``None``
            # into ``export``'s ``[:self.n_chosen]`` slice (which then kept everything).
            self.n_chosen = 1
        else:
            self.n_chosen = n_chosen
        self.label = label

        self._softmax = softmax
        self._arch_alpha = alpha

    def resample(self, memo):
        """Do nothing. Differentiable layer doesn't need resample."""
        return {}

    def export(self, memo):
        """Choose the inputs with the top ``n_chosen`` logits."""
        if self.label in memo:
            return {}  # nothing new to export
        chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
        if len(chosen) == 1:
            chosen = chosen[0]
        return {self.label: chosen}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        """Replace an :class:`InputChoice` (sum/mean reduction only) with this module; share alpha via ``memo``."""
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean']:
                raise ValueError('Only input choice of sum/mean reduction is supported.')
            size = module.n_candidates
            if module.label in memo:
                alpha = memo[module.label]
                if len(alpha) != size:
                    raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
            else:
                alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
            softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
            return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)

    def forward(self, inputs):
        """Forward takes a list of input candidates and returns their softmax-weighted sum."""
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * self._softmax(self._arch_alpha).view(*alpha_shape), 0)

    def parameters(self, *args, **kwargs):
        """Parameters excluding architecture parameters."""
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    def named_parameters(self, *args, **kwargs):
        """Named parameters; pass ``arch=True`` to get architecture parameters instead of weights."""
        arch = kwargs.pop('arch', False)
        for name, p in super().named_parameters(*args, **kwargs):
            if any(name == par_name for par_name in self._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
    """Implements the differentiable sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
    filters out multiple parameters with ``_arch_alpha`` as its prefix.

    When this class is asked for ``forward_argument``, it returns a distribution,
    i.e., a dict from int to float based on its weights.

    All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
    saved as attributes of ``operation``, rather than ``self``,
    because this class itself is not a ``nn.Module``, and saved parameters here
    won't be optimized.
    """

    # name prefixes that mark a parameter as an architecture parameter
    _arch_parameter_names: List[str] = ['_arch_alpha']

    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`
        operation._arch_alpha = nn.ParameterDict()
        for name, spec in operation.search_space_spec().items():
            if name in memo:
                # reuse the alpha of an already-seen label so tied choices stay consistent
                alpha = memo[name]
                if len(alpha) != spec.size:
                    raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
            operation._arch_alpha[name] = alpha

        # bind the parameter filters below onto ``operation`` (see class docstring for why)
        operation.parameters = functools.partial(self.parameters, self=operation)  # bind self
        operation.named_parameters = functools.partial(self.named_parameters, self=operation)

        operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))

    @staticmethod
    def parameters(self, *args, **kwargs):
        # mirrors ``named_parameters`` below; ``self`` is the bound operation, not the policy
        for _, p in self.named_parameters(*args, **kwargs):
            yield p

    @staticmethod
    def named_parameters(self, *args, **kwargs):
        # pass ``arch=True`` to get only architecture parameters; default yields only weights
        arch = kwargs.pop('arch', False)
        for name, p in super(self.__class__, self).named_parameters(*args, **kwargs):  # pylint: disable=bad-super-call
            if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
                if arch:
                    yield name, p
            else:
                if not arch:
                    yield name, p

    def resample(self, operation: MixedOperation, memo: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Differentiable. Do nothing in resample."""
        return {}

    def export(self, operation: MixedOperation, memo: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Export the value with the maximum architecture weight (argmax) for each leaf value choice."""
        result = {}
        for name, spec in operation.search_space_spec().items():
            if name in result:
                continue
            chosen_index = torch.argmax(operation._arch_alpha[name]).item()
            result[name] = spec.values[chosen_index]
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> Union[Dict[Any, float], Any]:
        # mutable arguments are forwarded as {outcome: probability}; others pass through unchanged
        if name in operation.mutable_arguments:
            weights = {label: operation._softmax(alpha) for label, alpha in operation._arch_alpha.items()}
            return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
        return operation.init_arguments[name]
This diff is collapsed.
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Implementation of ProxylessNAS: a hyrbid approach between differentiable and sampling.
The support remains limited. Known limitations include:
- No support for multiple arguments in forward.
- No support for mixed-operation (value choice).
- The code contains duplicates. Needs refactor.
"""
from typing import List, Tuple, Optional
import torch
import torch.nn as nn
from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
class _ArchGradientFunction(torch.autograd.Function):
    """Custom autograd function that runs ``run_func`` on a detached copy of the input,
    so that ``backward_func`` can estimate gradients w.r.t. the binary gates."""

    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        # detach so run_func's sub-graph is isolated from the outer graph;
        # gradients are re-attached manually in backward()
        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors

        # gradient w.r.t. the input, through the saved sub-graph
        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)

        # run_func / backward_func take no gradient, hence the trailing Nones
        return grad_x[0], binary_grads, None, None
class ProxylessMixedLayer(DifferentiableMixedLayer):
    """Proxyless version of differentiable mixed layer.
    It resamples a single-path every time, rather than go through the softmax.
    """

    # both alpha and the binary gates are architecture parameters
    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, paths: List[Tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
        # one-hot gate over the candidate paths; its ``.grad`` (filled via
        # ``_ArchGradientFunction``) is converted into alpha gradients in ``finalize_grad``
        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)

        # like sampling-based methods, it has a ``_sampled``.
        self._sampled: Optional[str] = None       # name of the sampled op
        self._sample_idx: Optional[int] = None    # its index in ``op_names``

    def forward(self, *args, **kwargs):
        """Run only the sampled path, wrapped so binary-gate gradients can be estimated."""
        def run_function(ops, active_id, **kwargs):
            # forward pass of only the active op
            def forward(_x):
                return ops[active_id](_x, **kwargs)
            return forward

        def backward_function(ops, active_id, binary_gates, **kwargs):
            # estimate d(loss)/d(gate_k) = sum(out_k * grad_output) for every candidate k;
            # inactive ops are re-run without gradient tracking
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(len(ops)):
                        if k != active_id:
                            out_k = ops[k](_x.data, **kwargs)
                        else:
                            out_k = _output.data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
        x = args[0]
        assert self._sampled is not None, 'Need to call resample() before running fprop.'

        list_ops = [getattr(self, op) for op in self.op_names]

        return _ArchGradientFunction.apply(
            x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
            self._sample_idx = self.op_names.index(self._sampled)
        else:
            probs = self._softmax(self._arch_alpha)
            self._sample_idx = torch.multinomial(probs, 1)[0].item()
            self._sampled = self.op_names[self._sample_idx]

        # set binary gates: one-hot on the sampled path (also runs on the memo path)
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[self._sample_idx] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if the label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self.op_names[torch.argmax(self._arch_alpha).item()]}

    def finalize_grad(self):
        # Turn gate gradients into alpha gradients through the softmax Jacobian:
        # d p_j / d a_i = p_j * (delta_ij - p_i).
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(len(self._arch_alpha)):
                for j in range(len(self._arch_alpha)):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessMixedInput(DifferentiableMixedInput):
    """Proxyless version of differentiable input choice.
    See :class:`ProxylessMixedLayer` for implementation details.
    """

    # both alpha and the binary gates are architecture parameters
    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, n_candidates: int, n_chosen: Optional[int], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
        # one-hot gate over the candidates; its ``.grad`` is converted to alpha
        # gradients in ``finalize_grad``
        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
        self._sampled: Optional[int] = None  # index of the sampled input

    def forward(self, inputs):
        """Forward the sampled input only, wrapped so gate gradients can be estimated."""
        def run_function(active_sample):
            # pick the active candidate from the stacked inputs
            return lambda x: x[active_sample]

        def backward_function(binary_gates):
            # estimate d(loss)/d(gate_k) = sum(input_k * grad_output) for every candidate
            def backward(_x, _output, grad_output):
                binary_grads = torch.zeros_like(binary_gates.data)
                with torch.no_grad():
                    for k in range(self.n_candidates):
                        out_k = _x[k].data
                        grad_k = torch.sum(out_k * grad_output)
                        binary_grads[k] = grad_k
                return binary_grads
            return backward

        inputs = torch.stack(inputs, 0)
        assert self._sampled is not None, 'Need to call resample() before running fprop.'

        return _ArchGradientFunction.apply(
            inputs, self._binary_gates, run_function(self._sampled),
            backward_function(self._binary_gates)
        )

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            probs = self._softmax(self._arch_alpha)
            self._sampled = torch.multinomial(probs, 1)[0].item()

        # Set the binary gates to one-hot on the sampled index. BUGFIX: indexing previously
        # used the local ``sample``, which is unbound when the value came from ``memo``
        # (NameError); index ``self._sampled`` instead, matching ProxylessMixedLayer's
        # use of ``self._sample_idx``.
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[self._sampled] = 1.0

        return {self.label: self._sampled}

    def export(self, memo):
        """Choose the argmax if the label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: torch.argmax(self._arch_alpha).item()}

    def finalize_grad(self):
        # Turn gate gradients into alpha gradients through the softmax Jacobian:
        # d p_j / d a_i = p_j * (delta_ij - p_i).
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self._arch_alpha.grad is None:
                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
            probs = self._softmax(self._arch_alpha)
            for i in range(self.n_candidates):
                for j in range(self.n_candidates):
                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
from typing import Optional, List, Tuple, Union, Dict, Any
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict
from .operation import MixedOperationSamplingPolicy, MixedOperation
class PathSamplingLayer(BaseSuperNetModule):
    """
    Mixed layer whose forward pass is delegated to exactly one sampled inner
    layer, or to the sum of several sampled layers.

    Attributes
    ----------
    _sampled : int or list of str
        Sampled module indices.
    label : str
        Name of the choice.
    """

    def __init__(self, paths: List[Tuple[str, nn.Module]], label: str):
        super().__init__()
        self.op_names = []
        for op_name, op_module in paths:
            self.add_module(op_name, op_module)
            self.op_names.append(op_name)
        assert self.op_names, 'There has to be at least one op to choose from.'
        # Either a single name/index or a list of them.
        self._sampled: Optional[Union[List[str], str]] = None
        self.label = label

    def resample(self, memo):
        """Random choose one path if label is not found in memo."""
        self._sampled = memo[self.label] if self.label in memo else random.choice(self.op_names)
        return {self.label: self._sampled}

    def export(self, memo):
        """Random choose one name if label isn't found in memo."""
        return {} if self.label in memo else {self.label: random.choice(self.op_names)}

    def search_space_spec(self):
        spec = ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                             True, size=len(self.op_names))
        return {self.label: spec}

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, LayerChoice):
            return cls(list(module.named_children()), module.label)

    def forward(self, *args, **kwargs):
        if self._sampled is None:
            raise RuntimeError('At least one path needs to be sampled before fprop.')
        chosen = self._sampled if isinstance(self._sampled, list) else [self._sampled]
        # str(c) is needed here because sampled entries can sometimes be
        # integers, but attribute names are always str.
        outputs = [getattr(self, str(c))(*args, **kwargs) for c in chosen]
        return outputs[0] if len(outputs) == 1 else sum(outputs)
class PathSamplingInput(BaseSuperNetModule):
    """
    Mixed input. Take a list of tensor as input, select some of them and return the reduction.

    Attributes
    ----------
    _sampled : int or list of int
        Sampled input indices.
    """

    def __init__(self, n_candidates: int, n_chosen: int, reduction: str, label: str):
        super().__init__()
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction = reduction
        self._sampled: Optional[Union[List[int], int]] = None
        self.label = label

    def _random_choose_n(self):
        # Shuffle-then-prefix yields n_chosen distinct indices, uniformly.
        sampling = list(range(self.n_candidates))
        random.shuffle(sampling)
        sampling = sorted(sampling[:self.n_chosen])
        if len(sampling) == 1:
            return sampling[0]
        else:
            return sampling

    def resample(self, memo):
        """Random choose one path / multiple paths if label is not found in memo.
        If one path is selected, only one integer will be in ``self._sampled``.
        If multiple paths are selected, a list will be in ``self._sampled``.
        """
        if self.label in memo:
            self._sampled = memo[self.label]
        else:
            self._sampled = self._random_choose_n()
        return {self.label: self._sampled}

    def export(self, memo):
        """Random choose one name if label isn't found in memo."""
        if self.label in memo:
            return {}  # nothing new to export
        return {self.label: self._random_choose_n()}

    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
                                      (self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
        }

    @classmethod
    def mutate(cls, module, name, memo, mutate_kwargs):
        if isinstance(module, InputChoice):
            if module.reduction not in ['sum', 'mean', 'concat']:
                raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
            return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)

    def forward(self, input_tensors):
        if self._sampled is None:
            raise RuntimeError('At least one path needs to be sampled before fprop.')
        if len(input_tensors) != self.n_candidates:
            raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
        sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
        res = [input_tensors[samp] for samp in sampled]
        if len(res) == 1:
            return res[0]
        if self.reduction == 'sum':
            return sum(res)
        elif self.reduction == 'mean':
            return sum(res) / len(res)
        elif self.reduction == 'concat':
            return torch.cat(res, 1)
        # FIX: previously an unknown reduction silently returned None here;
        # fail loudly instead (``mutate`` validates, but direct construction doesn't).
        raise ValueError(f'Unsupported reduction type: {self.reduction}')
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
    """Implements the path sampling in mixed operation.

    One mixed operation can have multiple value choices in its arguments.
    Each value choice can be further decomposed into "leaf value choices".
    We sample the leaf nodes, and composits them into the values on arguments.
    """

    def __init__(self, operation: MixedOperation, memo: Dict[str, Any], mutate_kwargs: Dict[str, Any]) -> None:
        # Sampling arguments. This should have the same keys with `operation.mutable_arguments`.
        self._sampled: Optional[Dict[str, Any]] = None

    def resample(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Random sample for each leaf value choice."""
        # FIX: the default is None but memo was indexed unconditionally,
        # raising TypeError when called without a memo.
        memo = {} if memo is None else memo
        result = {}
        space_spec = operation.search_space_spec()
        for label in space_spec:
            if label in memo:
                result[label] = memo[label]
            else:
                result[label] = random.choice(space_spec[label].values)

        # composits to kwargs
        # example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
        self._sampled = {}
        for key, value in operation.mutable_arguments.items():
            self._sampled[key] = evaluate_value_choice_with_dict(value, result)

        return result

    def export(self, operation: MixedOperation, memo: Dict[str, Any] = None) -> Dict[str, Any]:
        """Export is also random for each leaf value choice."""
        memo = {} if memo is None else memo  # same None guard as in resample
        result = {}
        space_spec = operation.search_space_spec()
        for label in space_spec:
            if label not in memo:
                result[label] = random.choice(space_spec[label].values)
        return result

    def forward_argument(self, operation: MixedOperation, name: str) -> Any:
        """Return the sampled value for a mutable argument, or the frozen init argument otherwise."""
        if self._sampled is None:
            raise ValueError('Need to call resample() before running forward')
        if name in operation.mutable_arguments:
            return self._sampled[name]
        return operation.init_arguments[name]
...@@ -7,4 +7,4 @@ from .evolution import RegularizedEvolution ...@@ -7,4 +7,4 @@ from .evolution import RegularizedEvolution
from .tpe_strategy import TPEStrategy from .tpe_strategy import TPEStrategy
from .local_debug_strategy import _LocalDebugStrategy from .local_debug_strategy import _LocalDebugStrategy
from .rl import PolicyBasedRL from .rl import PolicyBasedRL
from .oneshot import DARTS, Proxyless, SNAS, ENAS, RandomOneShot from .oneshot import DARTS, Proxyless, GumbelDARTS, ENAS, RandomOneShot
...@@ -5,7 +5,7 @@ from .base import BaseStrategy ...@@ -5,7 +5,7 @@ from .base import BaseStrategy
try: try:
from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import from nni.retiarii.oneshot.pytorch.strategy import ( # pylint: disable=unused-import
DARTS, SNAS, Proxyless, ENAS, RandomOneShot DARTS, GumbelDARTS, Proxyless, ENAS, RandomOneShot
) )
except ImportError as import_err: except ImportError as import_err:
_import_err = import_err _import_err = import_err
...@@ -16,7 +16,7 @@ except ImportError as import_err: ...@@ -16,7 +16,7 @@ except ImportError as import_err:
# otherwise typing check will pointing to the wrong location # otherwise typing check will pointing to the wrong location
globals()['DARTS'] = ImportFailedStrategy globals()['DARTS'] = ImportFailedStrategy
globals()['SNAS'] = ImportFailedStrategy globals()['GumbelDARTS'] = ImportFailedStrategy
globals()['Proxyless'] = ImportFailedStrategy globals()['Proxyless'] = ImportFailedStrategy
globals()['ENAS'] = ImportFailedStrategy globals()['ENAS'] = ImportFailedStrategy
globals()['RandomOneShot'] = ImportFailedStrategy globals()['RandomOneShot'] = ImportFailedStrategy
import argparse import argparse
import torch import torch
import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import pytorch_lightning as pl import pytorch_lightning as pl
import pytest import pytest
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler from torch.utils.data import Dataset, RandomSampler
from nni.retiarii import strategy, model_wrapper import nni.retiarii.nn.pytorch as nn
from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
class DepthwiseSeparableConv(nn.Module): class DepthwiseSeparableConv(nn.Module):
...@@ -25,107 +25,261 @@ class DepthwiseSeparableConv(nn.Module): ...@@ -25,107 +25,261 @@ class DepthwiseSeparableConv(nn.Module):
@model_wrapper @model_wrapper
class Net(pl.LightningModule): class SimpleNet(nn.Module):
def __init__(self): def __init__(self, value_choice=True):
super().__init__() super().__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 1) self.conv1 = nn.Conv2d(1, 32, 3, 1)
self.conv2 = LayerChoice([ self.conv2 = LayerChoice([
nn.Conv2d(32, 64, 3, 1), nn.Conv2d(32, 64, 3, 1),
DepthwiseSeparableConv(32, 64) DepthwiseSeparableConv(32, 64)
]) ])
self.dropout1 = nn.Dropout(.25) self.dropout1 = LayerChoice([
self.dropout2 = nn.Dropout(0.5) nn.Dropout(.25),
self.dropout_choice = InputChoice(2, 1) nn.Dropout(.5),
self.fc = LayerChoice([ nn.Dropout(.75)
nn.Sequential(
nn.Linear(9216, 64),
nn.ReLU(),
nn.Linear(64, 10),
),
nn.Sequential(
nn.Linear(9216, 128),
nn.ReLU(),
nn.Linear(128, 10),
),
nn.Sequential(
nn.Linear(9216, 256),
nn.ReLU(),
nn.Linear(256, 10),
)
]) ])
self.dropout2 = nn.Dropout(0.5)
if value_choice:
hidden = nn.ValueChoice([32, 64, 128])
else:
hidden = 64
self.fc1 = nn.Linear(9216, hidden)
self.fc2 = nn.Linear(hidden, 10)
self.rpfc = nn.Linear(10, 10) self.rpfc = nn.Linear(10, 10)
self.input_ch = InputChoice(2, 1)
def forward(self, x): def forward(self, x):
x = F.relu(self.conv1(x)) x = F.relu(self.conv1(x))
x = F.max_pool2d(self.conv2(x), 2) x = F.max_pool2d(self.conv2(x), 2)
x1 = torch.flatten(self.dropout1(x), 1) x = torch.flatten(self.dropout1(x), 1)
x2 = torch.flatten(self.dropout2(x), 1) x = self.fc1(x)
x = self.dropout_choice([x1, x2]) x = F.relu(x)
x = self.fc(x) x = self.dropout2(x)
x = self.rpfc(x) x = self.fc2(x)
x1 = self.rpfc(x)
x = self.input_ch([x, x1])
output = F.log_softmax(x, dim=1) output = F.log_softmax(x, dim=1)
return output return output
def prepare_model_data(): @model_wrapper
base_model = Net() class MultiHeadAttentionNet(nn.Module):
def __init__(self, head_count):
super().__init__()
embed_dim = ValueChoice(candidates=[32, 64])
self.linear1 = nn.Linear(128, embed_dim)
self.mhatt = nn.MultiheadAttention(embed_dim, head_count)
self.linear2 = nn.Linear(embed_dim, 1)
def forward(self, batch):
query, key, value = batch
q, k, v = self.linear1(query), self.linear1(key), self.linear1(value)
output, _ = self.mhatt(q, k, v, need_weights=False)
y = self.linear2(output)
return F.relu(y)
@model_wrapper
class ValueChoiceConvNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = nn.BatchNorm2d(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
@model_wrapper
class RepeatNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = nn.BatchNorm2d(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
self.rpfc = nn.Repeat(nn.Linear(10, 10), (1, 4))
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
x = self.rpfc(x)
return F.log_softmax(x, dim=1)
@basic_unit
class MyOp(nn.Module):
def __init__(self, some_ch):
super().__init__()
self.some_ch = some_ch
self.batch_norm = nn.BatchNorm2d(some_ch)
def forward(self, x):
return self.batch_norm(x)
@model_wrapper
class CustomOpValueChoiceNet(nn.Module):
def __init__(self):
super().__init__()
ch1 = ValueChoice([16, 32])
kernel = ValueChoice([3, 5])
self.conv1 = nn.Conv2d(1, ch1, kernel, padding=kernel // 2)
self.batch_norm = MyOp(ch1)
self.conv2 = nn.Conv2d(ch1, 64, 3, padding=1)
self.dropout1 = LayerChoice([
nn.Dropout(.25),
nn.Dropout(.5),
nn.Dropout(.75)
])
self.fc = nn.Linear(64, 10)
def forward(self, x):
x = self.conv1(x)
x = self.batch_norm(x)
x = F.relu(x)
x = F.max_pool2d(self.conv2(x), 2)
x = torch.mean(x, (2, 3))
x = self.fc(x)
return F.log_softmax(x, dim=1)
def _mnist_net(type_):
if type_ == 'simple':
base_model = SimpleNet(False)
elif type_ == 'simple_value_choice':
base_model = SimpleNet()
elif type_ == 'value_choice':
base_model = ValueChoiceConvNet()
elif type_ == 'repeat':
base_model = RepeatNet()
elif type_ == 'custom_op':
base_model = CustomOpValueChoiceNet()
else:
raise ValueError(f'Unsupported type: {type_}')
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]) transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
train_dataset = MNIST('data/mnist', train = True, download=True, transform=transform) train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 10)) train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 20))
train_loader = DataLoader(train_dataset, 64, sampler = train_random_sampler) train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
valid_dataset = MNIST('data/mnist', train = False, download=True, transform=transform) valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 10)) valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 20))
valid_loader = DataLoader(valid_dataset, 64, sampler = valid_random_sampler) valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
evaluator = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, max_epochs=1)
return base_model, evaluator
def _multihead_attention_net():
base_model = MultiHeadAttentionNet(1)
class AttentionRandDataset(Dataset):
def __init__(self, data_shape, gt_shape, len) -> None:
super().__init__()
self.datashape = data_shape
self.gtshape = gt_shape
self.len = len
def __getitem__(self, index):
q = torch.rand(self.datashape)
k = torch.rand(self.datashape)
v = torch.rand(self.datashape)
gt = torch.rand(self.gtshape)
return (q, k, v), gt
def __len__(self):
return self.len
trainer_kwargs = { train_set = AttentionRandDataset((1, 128), (1, 1), 1000)
'max_epochs' : 1 val_set = AttentionRandDataset((1, 128), (1, 1), 500)
} train_loader = DataLoader(train_set, batch_size=32)
val_loader = DataLoader(val_set, batch_size=32)
return base_model, train_loader, valid_loader, trainer_kwargs evaluator = Regression(train_dataloader=train_loader, val_dataloaders=val_loader, max_epochs=1)
return base_model, evaluator
def _test_strategy(strategy_): def _test_strategy(strategy_, support_value_choice=True):
base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data() to_test = [
cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs) # (model, evaluator), support_or_net
experiment = RetiariiExperiment(base_model, cls, strategy=strategy_) (_mnist_net('simple'), True),
(_mnist_net('simple_value_choice'), support_value_choice),
(_mnist_net('value_choice'), support_value_choice),
(_mnist_net('repeat'), False), # no strategy supports repeat currently
(_mnist_net('custom_op'), False), # this is definitely a NO
(_multihead_attention_net(), support_value_choice),
]
config = RetiariiExeConfig() for (base_model, evaluator), support_or_not in to_test:
config.execution_engine = 'oneshot' print('Testing:', type(strategy_).__name__, type(base_model).__name__, type(evaluator).__name__, support_or_not)
experiment = RetiariiExperiment(base_model, evaluator, strategy=strategy_)
experiment.run(config) config = RetiariiExeConfig()
config.execution_engine = 'oneshot'
assert isinstance(experiment.export_top_models()[0], dict) if support_or_not:
experiment.run(config)
assert isinstance(experiment.export_top_models()[0], dict)
else:
with pytest.raises(TypeError, match='not supported'):
experiment.run(config)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts(): def test_darts():
_test_strategy(strategy.DARTS()) _test_strategy(strategy.DARTS())
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_proxyless(): def test_proxyless():
_test_strategy(strategy.Proxyless()) _test_strategy(strategy.Proxyless(), False)
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas(): def test_enas():
_test_strategy(strategy.ENAS()) _test_strategy(strategy.ENAS())
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_random(): def test_random():
_test_strategy(strategy.RandomOneShot()) _test_strategy(strategy.RandomOneShot())
@pytest.mark.skipif(pl.__version__< '1.0', reason='Incompatible APIs') @pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_snas(): def test_gumbel_darts():
_test_strategy(strategy.SNAS()) _test_strategy(strategy.GumbelDARTS())
if __name__ == '__main__': if __name__ == '__main__':
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='all', metavar='E', parser.add_argument('--exp', type=str, default='all', metavar='E',
help='exp to run, default = all' ) help='experiment to run, default = all')
args = parser.parse_args() args = parser.parse_args()
if args.exp == 'all': if args.exp == 'all':
...@@ -133,6 +287,6 @@ if __name__ == '__main__': ...@@ -133,6 +287,6 @@ if __name__ == '__main__':
test_proxyless() test_proxyless()
test_enas() test_enas()
test_random() test_random()
test_snas() test_gumbel_darts()
else: else:
globals()[f'test_{args.exp}']() globals()[f'test_{args.exp}']()
import pytest
import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax
)
from nni.retiarii.oneshot.pytorch.supermodule.sampling import (
MixedOpPathSamplingPolicy, PathSamplingLayer, PathSamplingInput
)
from nni.retiarii.oneshot.pytorch.supermodule.operation import MixedConv2d, NATIVE_MIXED_OPERATIONS
from nni.retiarii.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput
from nni.retiarii.oneshot.pytorch.supermodule._operation_utils import Slicable as S, MaybeWeighted as W
from nni.retiarii.oneshot.pytorch.supermodule._valuechoice_utils import *
def test_slice():
    """Check ``Slicable``/``MaybeWeighted`` slicing: plain, list-of-slices and weighted forms."""
    weight = np.ones((3, 7, 24, 23))
    assert S(weight)[:, 1:3, :, 9:13].shape == (3, 2, 24, 4)
    assert S(weight)[:, 1:W(3)*2+1, :, 9:13].shape == (3, 6, 24, 4)
    assert S(weight)[:, 1:W(3)*2+1].shape == (3, 6, 24, 23)

    # no effect: full slice must return the original object, not a copy
    assert S(weight)[:] is weight

    # list of slices selects and concatenates multiple ranges
    assert S(weight)[[slice(1), slice(2, 3)]].shape == (2, 7, 24, 23)
    assert S(weight)[[slice(1), slice(2, W(2) + 1)], W(2):].shape == (2, 5, 24, 23)

    # weighted: each candidate length contributes its weight to the sliced-out region
    weight = S(weight)[:W({1: 0.5, 2: 0.3, 3: 0.2})]
    weight = weight[:, 0, 0, 0]
    assert weight[0] == 1 and weight[1] == 0.5 and weight[2] == 0.2

    weight = np.ones((3, 6, 6))
    value = W({1: 0.5, 3: 0.5})
    weight = S(weight)[:, 3 - value:3 + value, 3 - value:3 + value]
    # center 2x2 region is covered by both candidates (weight 1), the rest only by the larger one
    for i in range(0, 6):
        for j in range(0, 6):
            if 2 <= i <= 3 and 2 <= j <= 3:
                assert weight[0, i, j] == 1
            else:
                assert weight[1, i, j] == 0.5

    # weighted + list
    value = W({1: 0.5, 3: 0.5})
    weight = np.ones((8, 4))
    weight = S(weight)[[slice(value), slice(4, value + 4)]]
    assert weight.sum(1).tolist() == [4, 2, 2, 0, 4, 2, 2, 0]

    with pytest.raises(ValueError, match='one distinct'):
        # has to be exactly the same instance, equal is not enough
        weight = S(weight)[:W({1: 0.5}), : W({1: 0.5})]
def test_valuechoice_utils():
    """Check evaluation, dedup and exhaustive traversal of composed value choices."""
    chosen = {"exp": 3, "add": 1}
    vc0 = ValueChoice([3, 4, 6], label='exp') * 2 + ValueChoice([0, 1], label='add')

    # 3 * 2 + 1 == 7
    assert evaluate_value_choice_with_dict(vc0, chosen) == 7

    vc = vc0 + ValueChoice([3, 4, 6], label='exp')
    # 3 * 2 + 1 + 3 == 10 (the shared 'exp' label resolves to the same value)
    assert evaluate_value_choice_with_dict(vc, chosen) == 10

    # duplicate labels across expressions are collapsed to one entry each
    assert list(dedup_inner_choices([vc0, vc]).keys()) == ['exp', 'add']

    assert traverse_all_options(vc) == [9, 10, 12, 13, 18, 19]
    # each option's weight is the product of its leaf-choice weights
    weights = dict(traverse_all_options(vc, weights={'exp': [0.5, 0.3, 0.2], 'add': [0.4, 0.6]}))
    ans = dict([(9, 0.2), (10, 0.3), (12, 0.12), (13, 0.18), (18, 0.08), (19, 0.12)])
    assert len(weights) == len(ans)
    for value, weight in ans.items():
        assert abs(weight - weights[value]) < 1e-6
def test_pathsampling_valuechoice():
    """Resampling a path-sampled mixed conv must switch its effective out_channels."""
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='123'), kernel_size=3)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    conv.resample(memo={'123': 5})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 5
    conv.resample(memo={'123': 7})
    assert conv(torch.zeros((1, 3, 5, 5))).size(1) == 7
    # export without a memo picks randomly among the candidates
    assert conv.export({})['123'] in [3, 5, 7]
def test_differentiable_valuechoice():
    """Differentiable mixed conv with a shared label ('123') on kernel_size and padding."""
    orig_conv = Conv2d(3, ValueChoice([3, 5, 7], label='456'), kernel_size=ValueChoice(
        [3, 5, 7], label='123'), padding=ValueChoice([3, 5, 7], label='123') // 2)
    conv = MixedConv2d.mutate(orig_conv, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
    # padding == kernel_size // 2, so the spatial size is preserved
    assert conv(torch.zeros((1, 3, 7, 7))).size(2) == 7

    # both leaf labels must appear in the export
    assert set(conv.export({}).keys()) == {'123', '456'}
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
    """Mutate ``operation`` into its path-sampling mixed op, resample with ``memo``,
    and run one forward pass with ``*input``.
    """
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
            break
    else:
        # FIX: previously an unmatched operation fell through and hit an
        # UnboundLocalError on ``mutate_op``; fail with a clear message instead.
        raise TypeError(f'No native mixed operation found for {type(operation)!r}.')

    mutate_op.resample(memo=memo)
    return mutate_op(*input)
def _mixed_operation_differentiable_sanity_check(operation, *input):
    """Mutate ``operation`` into its differentiable mixed op and run one forward pass."""
    for native_op in NATIVE_MIXED_OPERATIONS:
        if native_op.bound_type == type(operation):
            mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
            break
    else:
        # FIX: previously an unmatched operation fell through and hit an
        # UnboundLocalError on ``mutate_op``; fail with a clear message instead.
        raise TypeError(f'No native mixed operation found for {type(operation)!r}.')

    return mutate_op(*input)
def test_mixed_linear():
    """Mixed Linear: shared in-feature choice works; a ValueChoice bias is rejected."""
    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
    _mixed_operation_sampling_sanity_check(linear, {'shared': 9}, torch.randn(2, 9))

    _mixed_operation_differentiable_sanity_check(linear, torch.randn(2, 9))

    linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=False)
    _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))

    # bias is not a slicable tensor dimension, so a ValueChoice there must raise
    with pytest.raises(TypeError):
        linear = Linear(ValueChoice([3, 6, 9], label='shared'), ValueChoice([2, 4, 8]), bias=ValueChoice([False, True]))
        _mixed_operation_sampling_sanity_check(linear, {'shared': 3}, torch.randn(2, 3))
def test_mixed_conv2d():
    """Mixed Conv2d: channels, expressions over choices, stride, groups and kernel slicing."""
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out') * 2, 1)
    # out_channels is an expression (choice * 2), so sampling out=4 yields 8 channels
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'out': 4}, torch.randn(2, 3, 9, 9)).size(1) == 8

    _mixed_operation_differentiable_sanity_check(conv, torch.randn(2, 9, 3, 3))

    # stride
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([2, 4, 8], label='out'), 1, stride=ValueChoice([1, 2], label='stride'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 2}, torch.randn(2, 3, 10, 10)).size(2) == 5
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 3, 'stride': 1}, torch.randn(2, 3, 10, 10)).size(2) == 10

    # groups, dw conv (in == out == groups share one label)
    conv = Conv2d(ValueChoice([3, 6, 9], label='in'), ValueChoice([3, 6, 9], label='in'), 1, groups=ValueChoice([3, 6, 9], label='in'))
    assert _mixed_operation_sampling_sanity_check(conv, {'in': 6}, torch.randn(2, 6, 10, 10)).size() == torch.Size([2, 6, 10, 10])

    # make sure kernel is sliced correctly
    conv = Conv2d(1, 1, ValueChoice([1, 3], label='k'), bias=False)
    conv = MixedConv2d.mutate(conv, 'dummy', {}, {'mixed_op_sampling': MixedOpPathSamplingPolicy})
    with torch.no_grad():
        conv.weight.zero_()
        # only center is 1, must pick center to pass this test
        conv.weight[0, 0, 1, 1] = 1
    conv.resample({'k': 1})
    assert conv(torch.ones((1, 1, 3, 3))).sum().item() == 9
def test_mixed_batchnorm2d():
    """Mixed BatchNorm2d must adapt to the sampled feature dimension."""
    bn = BatchNorm2d(ValueChoice([32, 64], label='dim'))

    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 32}, torch.randn(2, 32, 3, 3)).size(1) == 32
    assert _mixed_operation_sampling_sanity_check(bn, {'dim': 64}, torch.randn(2, 64, 3, 3)).size(1) == 64

    _mixed_operation_differentiable_sanity_check(bn, torch.randn(2, 64, 3, 3))
def test_mixed_mhattn():
    """Mixed MultiheadAttention with embed_dim / num_heads / kdim / vdim as value choices."""
    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4)

    # inputs are (seq, batch, dim); output last dim follows the sampled embed_dim
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 8))

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), ValueChoice([2, 3, 4], label='heads'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 2},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4
    # embed_dim must be divisible by num_heads; 4 / 3 is rejected
    with pytest.raises(AssertionError, match='divisible'):
        assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'heads': 3},
            torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 4))[0].size(-1) == 4

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, kdim=ValueChoice([5, 7], label='kdim'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7},
        torch.randn(7, 2, 4), torch.randn(7, 2, 7), torch.randn(7, 2, 4))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 5), torch.randn(7, 2, 8))[0].size(-1) == 8

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 4, vdim=ValueChoice([5, 8], label='vdim'))
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'vdim': 8},
        torch.randn(7, 2, 4), torch.randn(7, 2, 4), torch.randn(7, 2, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'vdim': 5},
        torch.randn(7, 2, 8), torch.randn(7, 2, 8), torch.randn(7, 2, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(5, 3, 8), torch.randn(5, 3, 8), torch.randn(5, 3, 8))
@pytest.mark.skipif(torch.__version__.startswith('1.7'), reason='batch_first is not supported for legacy PyTorch')
def test_mixed_mhattn_batch_first():
    """Mixed MultiheadAttention with batch_first=True: inputs are (batch, seq, dim)."""
    # batch_first is not supported for legacy pytorch versions
    # mark 1.7 because 1.7 is used on legacy pipeline

    mhattn = MultiheadAttention(ValueChoice([4, 8], label='emb'), 2, kdim=(ValueChoice([3, 7], label='kdim')), vdim=ValueChoice([5, 8], label='vdim'),
                                bias=False, add_bias_kv=True, batch_first=True)
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 4, 'kdim': 7, 'vdim': 8},
        torch.randn(2, 7, 4), torch.randn(2, 7, 7), torch.randn(2, 7, 8))[0].size(-1) == 4
    assert _mixed_operation_sampling_sanity_check(mhattn, {'emb': 8, 'kdim': 3, 'vdim': 5},
        torch.randn(2, 7, 8), torch.randn(2, 7, 3), torch.randn(2, 7, 5))[0].size(-1) == 8

    _mixed_operation_differentiable_sanity_check(mhattn, torch.randn(1, 7, 8), torch.randn(1, 7, 7), torch.randn(1, 7, 8))
def test_pathsampling_layer_input():
    """PathSamplingLayer needs resample() before forward; PathSamplingInput picks n_chosen inputs."""
    op = PathSamplingLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], label='ccc')
    with pytest.raises(RuntimeError, match='sample'):
        op(torch.randn(4, 2))

    op.resample({})
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.search_space_spec()['ccc'].values == ['a', 'b']
    assert op.export({})['ccc'] in ['a', 'b']

    input = PathSamplingInput(5, 2, 'concat', 'ddd')
    sample = input.resample({})
    assert 'ddd' in sample
    assert len(sample['ddd']) == 2
    # concat of 2 chosen width-2 tensors along dim 1 -> width 4
    assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 4
    assert len(input.export({})['ddd']) == 2
def test_differentiable_layer_input():
    """Differentiable layer/input mix all candidates, so no resample is needed before forward."""
    op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.export({})['eee'] in ['a', 'b']
    # 2 linear weights/biases plus the architecture alpha
    assert len(list(op.parameters())) == 3

    input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
    assert len(input.export({})['ddd']) == 2
def test_proxyless_layer_input():
    """Proxyless layer/input sample a single active path via resample() before forward."""
    op = ProxylessMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
    assert op.resample({})['eee'] in ['a', 'b']
    assert op(torch.randn(4, 2)).size(-1) == 3
    assert op.export({})['eee'] in ['a', 'b']
    # 2 linear weights/biases plus the architecture alpha (binary gates excluded here)
    assert len(list(op.parameters())) == 3

    input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
    assert input.resample({})['ddd'] in list(range(5))
    # proxyless input forwards exactly one of the candidate tensors
    assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
    assert input.export({})['ddd'] in list(range(5))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment