Unverified Commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
# pylint: disable=wildcard-import,unused-wildcard-import
# To make auto-completion happy, we generate a _nn.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_nn.py'
# Update this when cache format changes, to enforce an update.
cache_version = 2
def validate_cache() -> bool:
import torch
cache_valid = []
if nn_cache_file_path.exists():
lines = nn_cache_file_path.read_text().splitlines()
for line in lines:
if line.startswith('# _torch_version'):
_cached_torch_version = line[line.find('=') + 1:].strip()
if _cached_torch_version == torch.__version__:
cache_valid.append(True)
if line.startswith('# _torch_nn_cache_version'):
_cached_cache_version = int(line[line.find('=') + 1:].strip())
if _cached_cache_version == cache_version:
cache_valid.append(True)
return len(cache_valid) >= 2 and all(cache_valid)
def generate_stub_file() -> str:
import inspect
import warnings
import torch
import torch.nn as nn
_NO_WRAP_CLASSES = [
# not an nn.Module
'Parameter',
'ParameterList',
'UninitializedBuffer',
'UninitializedParameter',
# arguments are special
'Module',
'Sequential',
# utilities
'Container',
'DataParallel',
]
_WRAP_WITHOUT_TAG_CLASSES = [
# special support on graph engine
'ModuleList',
'ModuleDict',
]
code = [
'# Copyright (c) Microsoft Corporation.',
'# Licensed under the MIT license.',
'# This file is auto-generated to make auto-completion work.',
'# When the PyTorch version does not match, it will be automatically regenerated.',
'# pylint: skip-file',
'# pyright: reportGeneralTypeIssues=false',
f'# _torch_version = {torch.__version__}',
f'# _torch_nn_cache_version = {cache_version}',
'import typing',
'import torch.nn as nn',
'from nni.retiarii.serializer import basic_unit',
]
all_names = []
# Add modules, classes, functions in torch.nn into this module.
for name, obj in inspect.getmembers(torch.nn):
if inspect.isclass(obj):
if name in _NO_WRAP_CLASSES:
code.append(f'{name} = nn.{name}')
elif not issubclass(obj, nn.Module):
# It should never reach here;
# we handle it anyway to be safe.
warnings.warn(f'{obj} is unexpectedly not an nn.Module. '
'It means your PyTorch version might not be supported.', RuntimeWarning)
code.append(f'{name} = nn.{name}')
elif name in _WRAP_WITHOUT_TAG_CLASSES:
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
else:
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
all_names.append(name)
elif inspect.isfunction(obj) or inspect.ismodule(obj):
code.append(f'{name} = nn.{name}') # no modification
all_names.append(name)
code.append(f'__all__ = {all_names}')
return '\n'.join(code)
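# For reference, the generated ``_nn.py`` begins (after the copyright banner)
# with the header that ``validate_cache`` parses, followed by one wrapped
# symbol per line. A minimal sketch of the output (the version string is
# illustrative; the exact symbol list depends on the installed PyTorch):
#
#     # _torch_version = 1.10.0
#     # _torch_nn_cache_version = 2
#     import typing
#     import torch.nn as nn
#     from nni.retiarii.serializer import basic_unit
#     Parameter = nn.Parameter
#     Conv2d = typing.cast(typing.Type[nn.Conv2d], basic_unit(nn.Conv2d))
#     ModuleList = typing.cast(typing.Type[nn.ModuleList], basic_unit(nn.ModuleList, basic_unit_tag=False))
#     __all__ = ['Parameter', 'Conv2d', 'ModuleList']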
def write_cache(code: str) -> None:
with nn_cache_file_path.open('w') as fp:
fp.write(code)
code = generate_stub_file()
if not validate_cache():
write_cache(code)
del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
from ._nn import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import
from nni.nas.nn.pytorch.layers import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import tensorflow as tf
class LayerChoice(tf.keras.layers.Layer):
# FIXME: This is only a draft to test multi-framework support. It's not implemented at all.
pass
from nni.nas.nn.tensorflow.api import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import warnings
from itertools import chain
from typing import Callable, Any, Dict, Union, Tuple, List, cast
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import nni.retiarii.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule
__all__ = [
'MutationHook',
'BaseSuperNetModule',
'BaseOneShotLightningModule',
'traverse_and_mutate_submodules',
'no_default_hook'
]
MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
def traverse_and_mutate_submodules(
root_module: nn.Module, hooks: list[MutationHook], mutate_kwargs: dict[str, Any], topdown: bool = True
) -> list[BaseSuperNetModule]:
"""
Traverse the module-tree of ``root_module``, and call ``hooks`` on every tree node.
Parameters
----------
root_module : nn.Module
User-defined model space.
Since this method is called in the ``__init__`` of :class:`BaseOneShotLightningModule`,
it's usually a ``pytorch_lightning.LightningModule``.
The mutation will be in-place on ``root_module``.
hooks : list[MutationHook]
List of mutation hooks. See :class:`BaseOneShotLightningModule` for how to write hooks.
When a hook returns a module, the original module will be replaced (mutated) with the new one.
mutate_kwargs : dict
Extra keyword arguments passed to hooks.
topdown : bool, default = True
If topdown is true, hooks are called first, before traversing sub-modules (i.e., pre-order DFS).
Otherwise, sub-modules are traversed first, before calling hooks on this node (i.e., post-order DFS).
Returns
-------
modules : list[BaseSuperNetModule]
The list of mutated modules.
"""
memo = {}
module_list = []
def apply(m):
# Need to call list() here because the loop body might replace some children in-place.
for name, child in list(m.named_children()):
# post-order DFS
if not topdown:
apply(child)
mutate_result = None
for hook in hooks:
hook_suggest = hook(child, name, memo, mutate_kwargs)
# parse the mutate result
if isinstance(hook_suggest, tuple):
hook_suggest, suppress = hook_suggest
elif hook_suggest is True:
hook_suggest, suppress = None, True
elif not hook_suggest: # none / false
hook_suggest, suppress = None, False
elif isinstance(hook_suggest, nn.Module):
suppress = True
else:
raise TypeError(f'Mutation hook returned {hook_suggest} of unsupported type: {type(hook_suggest)}.')
if hook_suggest is not None:
if not isinstance(hook_suggest, BaseSuperNetModule):
warnings.warn("Mutation hook didn't return a BaseSuperNetModule. It will be ignored in hooked module list.",
RuntimeWarning)
setattr(m, name, hook_suggest)
mutate_result = hook_suggest
# if suppress, no further mutation hooks are called
if suppress:
break
if isinstance(mutate_result, BaseSuperNetModule):
# Replace child with the mutate result, and DFS this one
child = mutate_result
module_list.append(mutate_result)
# pre-order DFS
if topdown:
apply(child)
apply(root_module)
return module_list
def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> bool:
"""Add this hook at the end of your hook list to raise error for unsupported mutation primitives."""
# These primitives' forward is NOT a supernet forward, so they must be explicitly mutated.
primitive_list = (
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.Repeat,
nas_nn.NasBench101Cell,
# nas_nn.ValueChoice, # could be false positive
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
)
if isinstance(module, primitive_list):
raise TypeError(f'{type(module).__name__} is not supported')
if isinstance(module, nas_nn.Cell) and module.merge_op != 'all':
# need output_node_indices, which depends on super-net
raise TypeError(f'Cell with merge_op `{module.merge_op}` is not supported')
if is_traceable(module):
# check whether there is a value-choice in its arguments
has_valuechoice = False
for arg in chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
break
if has_valuechoice:
raise TypeError(f'`basic_unit` {type(module).__name__} with value choice in its arguments is not supported. '
'Please try to remove `basic_unit` to see if that works, or support this type with value choice manually.')
return True # suppress all other hooks
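# Example (a minimal sketch, not part of the shipped API): a trivial custom
# mutation hook. Returning True keeps the module unchanged but suppresses
# subsequent hooks; returning False falls through to the next hook.
def _example_keep_batchnorm_hook(module, name, memo, mutate_kwargs):
    if isinstance(module, nn.BatchNorm2d):
        return True   # keep as-is, skip remaining hooks (incl. no_default_hook)
    return False      # let later hooks (or no_default_hook) handle it
# A hook list is then assembled with ``no_default_hook`` last, so that
# unsupported primitives raise instead of being silently skipped:
# hooks = [_example_keep_batchnorm_hook, no_default_hook]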
class BaseOneShotLightningModule(pl.LightningModule):
_mutation_hooks_note = """mutation_hooks : list[MutationHook]
Extra mutation hooks to support customized mutation on primitives other than built-ins.
Mutation hooks are callables that take a module as input and return a
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodule.
For each submodule, the hooks in the list are invoked sequentially;
later hooks can see the results of previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
stored in :attr:`nas_modules`, and become the focus of the NAS algorithm.
The hook list is extended with the ``default_mutation_hooks`` of each one-shot module.
To be more specific, a hook accepts four arguments:
1. a module that might be processed,
2. name of the module in its parent module,
3. a memo dict whose usage depends on the particular algorithm,
4. keyword arguments (configurations).
Note that the memo is meant to be read and written by hooks.
No hooks are called on the root module.
The return value can be one of three kinds:
1. a tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and a boolean,
2. a boolean,
3. a :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
The boolean value is ``suppress``, which indicates whether the subsequent hooks should be called.
When it's true, subsequent hooks are suppressed and will never be invoked.
If no boolean value is specified, it's assumed to be false.
If ``None`` appears in place of the
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
the hook suggests keeping the module unchanged, and nothing will happen.
An example of a mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
However, it's recommended to implement mutation hooks by subclassing
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
and adding its classmethod ``mutate`` to this list.
"""
_inner_module_note = """inner_module : pytorch_lightning.LightningModule
It's a `LightningModule <https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html>`__
that defines computations, train/val loops, optimizers in a single class.
When used in NNI, the ``inner_module`` is the combination of an evaluator and the base model
(to be precise, a base model wrapped with the LightningModule in the evaluator).
"""
__doc__ = """
The base class for all one-shot NAS modules.
In NNI, we try to separate the "search" part and "training" part in one-shot NAS.
The "training" part is defined with evaluator interface (has to be lightning evaluator interface to work with oneshot).
Since the lightning evaluator has already broken down the training into minimal building blocks,
we can re-assemble them after combining them with the "search" part of a particular algorithm.
After the re-assembling, this module has defined all the search + training. The experiment can use a lightning trainer
(which is another part in the evaluator) to train this module, so as to complete the search process.
Essential function such as preprocessing user's model, redirecting lightning hooks for user's model,
configuring optimizers and exporting NAS result are implemented in this class.
Attributes
----------
nas_modules : list[BaseSuperNetModule]
Modules that have been mutated, which the search algorithms should care about.
model : pl.LightningModule
PyTorch lightning module. A model space with training recipe defined (wrapped by LightningModule in evaluator).
Parameters
----------
""" + _inner_module_note + _mutation_hooks_note
trainer: pl.Trainer
@property
def automatic_optimization(self) -> bool:
return False
def default_mutation_hooks(self) -> list[MutationHook]:
"""Override this to define class-default mutation hooks."""
return [no_default_hook]
def mutate_kwargs(self) -> dict[str, Any]:
"""Extra keyword arguments passed to mutation hooks. Usually algo-specific."""
return {}
def __init__(self, model: pl.LightningModule, mutation_hooks: list[MutationHook] | None = None):
super().__init__()
assert isinstance(model, pl.LightningModule)
self.model = model
# append the default hooks
mutation_hooks = (mutation_hooks or []) + self.default_mutation_hooks()
# traverse the model, calling hooks on every submodule
self.nas_modules: list[BaseSuperNetModule] = traverse_and_mutate_submodules(
self.model, mutation_hooks, self.mutate_kwargs(), topdown=True)
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""Get the search space specification from :attr:`nas_modules`.
Returns
-------
dict
Key is the name of the choice, value is the corresponding :class:`ParameterSpec`.
"""
result = {}
for module in self.nas_modules:
result.update(module.search_space_spec())
return result
def resample(self) -> dict[str, Any]:
"""Trigger the resample for each :attr:`nas_modules`.
Sometimes (e.g., in differentiable cases), it does nothing.
Returns
-------
dict
Sampled architecture.
"""
result = {}
for module in self.nas_modules:
result.update(module.resample(memo=result))
return result
def export(self) -> dict[str, Any]:
"""
Export the NAS result, ideally the best choice of each :attr:`nas_modules`.
You may implement an ``export`` method for your customized :attr:`nas_modules`.
Returns
-------
dict
Keys are names of ``nas_modules``, and values are the chosen indices.
"""
result = {}
for module in self.nas_modules:
result.update(module.export(memo=result))
return result
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
"""This is the implementation of what happens in training loops of one-shot algos.
It usually calls ``self.model.training_step``, which implements the real training recipe of the user's model.
"""
return self.model.training_step(batch, batch_idx)
def configure_optimizers(self):
"""
Combine architecture optimizers and user's model optimizers.
You can override :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
For now :attr:`model` is tested against evaluators in :mod:`nni.retiarii.evaluator.pytorch.lightning`
and it only returns 1 optimizer.
But for extensibility, code for other return value types is also implemented.
"""
# pylint: disable=assignment-from-none
arc_optimizers = self.configure_architecture_optimizers()
if arc_optimizers is None:
return self.model.configure_optimizers()
if isinstance(arc_optimizers, optim.Optimizer):
arc_optimizers = [arc_optimizers]
self.arc_optim_count = len(arc_optimizers)
# FIXME: this part uses non-official lightning API.
# The return values ``frequency`` and ``monitor`` are ignored because lightning requires
# ``len(optimizers) == len(frequency)``, and the gradient backward pass is handled manually.
# For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
try:
# above v1.6
from pytorch_lightning.core.optimizer import ( # pylint: disable=import-error
_configure_optimizers, # type: ignore
_configure_schedulers_automatic_opt, # type: ignore
_configure_schedulers_manual_opt # type: ignore
)
w_optimizers, lr_schedulers, self.frequencies, monitor = \
_configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = (
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
if self.automatic_optimization
else _configure_schedulers_manual_opt(lr_schedulers)
)
except ImportError:
# under v1.5
w_optimizers, lr_schedulers, self.frequencies, monitor = \
self.trainer._configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization) # type: ignore
if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers): # type: ignore
raise Exception(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)
# variables used to handle optimizer frequency
self.cur_optimizer_step = 0
self.cur_optimizer_index = 0
return arc_optimizers + w_optimizers, lr_schedulers
def on_train_start(self):
return self.model.on_train_start()
def on_train_end(self):
return self.model.on_train_end()
def on_fit_start(self):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_fit_start()
def on_fit_end(self):
return self.model.on_fit_end()
def on_train_batch_start(self, batch, batch_idx, unused=0):
return self.model.on_train_batch_start(batch, batch_idx, unused)
def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)
# Deprecated hooks in pytorch-lightning
def on_epoch_start(self):
return self.model.on_epoch_start()
def on_epoch_end(self):
return self.model.on_epoch_end()
def on_train_epoch_start(self):
return self.model.on_train_epoch_start()
def on_train_epoch_end(self):
return self.model.on_train_epoch_end()
def on_before_backward(self, loss):
return self.model.on_before_backward(loss)
def on_after_backward(self):
return self.model.on_after_backward()
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)
def configure_architecture_optimizers(self):
"""
Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
if architecture parameters are needed. Note that lr schedulers are currently not supported for architecture optimizers.
Returns
-------
arc_optimizers : list[Optimizer], Optimizer, None
Optimizers used by a specific NAS algorithm. Returns None if no architecture optimizers are needed.
"""
return None
def call_lr_schedulers(self, batch_index):
"""
Function that imitates the lightning trainer's behavior of calling user's lr schedulers. Since automatic optimization is turned off
by this class, you can use this function to make schedulers behave as if they were automatically handled by the lightning trainer.
Parameters
----------
batch_index : int
Batch index.
"""
def apply(lr_scheduler):
# single scheduler is called every epoch
if isinstance(lr_scheduler, _LRScheduler):
if self.trainer.is_last_batch:
lr_scheduler.step()
# lr_scheduler_config is called as configured
elif isinstance(lr_scheduler, dict):
interval = lr_scheduler['interval']
frequency = lr_scheduler['frequency']
if (
interval == 'step' and
batch_index % frequency == 0
) or \
(
interval == 'epoch' and
self.trainer.is_last_batch and
(self.trainer.current_epoch + 1) % frequency == 0
):
lr_scheduler['scheduler'].step()
lr_schedulers = self.lr_schedulers()
if isinstance(lr_schedulers, list):
for lr_scheduler in lr_schedulers:
apply(lr_scheduler)
else:
apply(lr_schedulers)
def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
"""
Function that imitates the lightning trainer's behavior of calling user's optimizers. Since automatic optimization is turned off by this
class, you can use this function to make user optimizers behave as if they were automatically handled by the lightning trainer.
Parameters
----------
method : str
Method to call. Only ``step`` and ``zero_grad`` are supported now.
"""
def apply_method(optimizer, method):
if method == 'step':
optimizer.step()
elif method == 'zero_grad':
optimizer.zero_grad()
optimizers = self.weight_optimizers()
if optimizers is None:
return
assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'
if len(self.frequencies) > 0:
self.cur_optimizer_step += 1
if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
self.cur_optimizer_step = 0
self.cur_optimizer_index = self.cur_optimizer_index + 1 \
if self.cur_optimizer_index + 1 < len(optimizers) \
else 0
apply_method(optimizers[self.cur_optimizer_index], method)
else:
for optimizer in optimizers:
apply_method(optimizer, method)
def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
"""
Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.
Returns
-------
opts : list[Optimizer], Optimizer, None
Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there are no
architecture optimizers.
"""
opts = self.optimizers()
if isinstance(opts, list):
# pylint: disable=unsubscriptable-object
arc_opts = opts[:self.arc_optim_count]
if len(arc_opts) == 1:
return cast(Optimizer, arc_opts[0])
return cast(List[Optimizer], arc_opts)
# If there is only 1 optimizer and it is the architecture optimizer
if self.arc_optim_count == 1:
return cast(Union[List[Optimizer], Optimizer], opts)
return None
def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
"""
Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.
Returns
-------
opts : list[Optimizer], Optimizer, None
Optimizers defined by the user's model. This will be None if there are no user optimizers.
"""
# Since use_pl_optimizer is set to True (by default) here,
# opts always returns a list.
opts = self.optimizers()
if isinstance(opts, list):
# pylint: disable=unsubscriptable-object
return cast(List[Optimizer], opts[self.arc_optim_count:])
# FIXME: this case is actually not correctly handled
# If there is only 1 optimizer and no architecture optimizer
if self.arc_optim_count == 0:
return cast(Union[List[Optimizer], Optimizer], opts)
return None
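# Example (a sketch, not part of the shipped API): a subclass typically drives
# manual optimization with the helpers above. This mirrors the pattern used by
# the DARTS implementation in ``differentiable.py``.
class _ExampleOneShotModule(BaseOneShotLightningModule):
    def training_step(self, batch, batch_idx):
        self.resample()                           # sample a new architecture
        self.call_weight_optimizers('zero_grad')  # zero grads of user optimizers
        loss = self.model.training_step(batch, batch_idx)
        if isinstance(loss, dict):
            loss = loss['loss']
        self.manual_backward(loss)
        self.call_weight_optimizers('step')       # step user optimizers
        self.call_lr_schedulers(batch_idx)        # step schedulers as configured
        return loss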
from nni.nas.oneshot.pytorch.base_lightning import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
from typing import Any
from pytorch_lightning.trainer.supporters import CombinedLoader, CombinedLoaderIterator
__all__ = ['ConcatLoader']
class ConcatLoader(CombinedLoader):
"""This loader is same as CombinedLoader in PyTorch-Lightning, but concatenate sub-loaders
instead of loading them in parallel.
Parameters
----------
loaders
For example, ::
{
"train": DataLoader(train_dataset),
"val": DataLoader(val_dataset)
}
In this example, the loader will first produce the batches from "train", then "val".
mode
Only "min_size" is supported for now.
"""
def __init__(self, loaders: dict[str, Any], mode: str = 'min_size'):
# FIXME: max_cycle will make dataloaders cycle iterators,
# causing extra problems.
if mode != 'min_size':
raise ValueError('Only min_size mode is supported now.')
super().__init__(loaders, mode)
def __iter__(self) -> Any:
"""Replace the super-class iterator with ours."""
self._try_to_patch_pytorch_dataloader()
iterator = ConcatLoaderIterator(self.loaders)
# handle fault tolerant restart.
self.on_restart(iterator)
self._iterator = iterator
return iterator
@staticmethod
def _try_to_patch_pytorch_dataloader():
"""Copied from CombinedLoader."""
from torch.utils.data.dataloader import _BaseDataLoaderIter
# prevent `NotImplementedError` from PyTorch:
# https://github.com/pytorch/pytorch/blob/v1.9.0/torch/utils/data/dataloader.py#L541
def __getstate__patch__(*_):
return {}
_BaseDataLoaderIter.__getstate__ = __getstate__patch__ # type: ignore
def __len__(self) -> int:
return int(sum(self._calc_num_batches(loader) for loader in self.loaders.values()))
class ConcatLoaderIterator(CombinedLoaderIterator):
"""Similar to CombinedLoaderIterator in Lightning, but in a concat manner."""
def __next__(self) -> Any:
"""Fetches the next batch from multiple data loaders,
by looking for the first iterator that isn't exhausted yet.
"""
if not len(self.loader_iters) == len(self.loaders):
raise RuntimeError('loader_iters must have the same length as loaders.')
for i, (loader_name, iterator) in enumerate(self.loader_iters.items()):
try:
return (self.request_next_batch(iterator), loader_name)
except StopIteration:
if i + 1 == len(self.loader_iters):
raise
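# Example (sketch, assuming ``train_dataset`` / ``val_dataset`` exist): each
# element yielded by ConcatLoader is a ``(batch, loader_name)`` pair; "train"
# batches come first, then "val". This is how the ENAS training loop tells
# the two phases apart:
#
#     loader = ConcatLoader({
#         'train': DataLoader(train_dataset, batch_size=32),
#         'val': DataLoader(val_dataset, batch_size=32),
#     })
#     for batch, name in loader:
#         ...   # name == 'train' until the train loader is exhausted, then 'val'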
from nni.nas.oneshot.pytorch.dataloader import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Experimental version of differentiable one-shot implementation."""
# pylint: disable=wildcard-import,unused-wildcard-import
from __future__ import annotations
import pytorch_lightning as pl
import torch
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.differentiable import (
DifferentiableMixedLayer, DifferentiableMixedInput,
MixedOpDifferentiablePolicy, GumbelSoftmax,
DifferentiableMixedCell, DifferentiableMixedRepeat
)
from .supermodule.proxyless import ProxylessMixedInput, ProxylessMixedLayer
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
class DartsLightningModule(BaseOneShotLightningModule):
_darts_note = """
Continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.
`Reference <https://arxiv.org/abs/1806.09055>`__.
The DARTS algorithm is one of the most fundamental one-shot algorithms.
DARTS repeats iterations, where each iteration consists of 2 training phases.
Phase 1 is the architecture step, in which model parameters are frozen and the architecture parameters are trained.
Phase 2 is the model step, in which architecture parameters are frozen and model parameters are trained.
The current implementation corresponds to DARTS (1st order) in the paper.
Second order (unrolled 2nd-order derivatives) is not supported yet.
.. versionadded:: 2.8
Supports searching for ValueChoices on operations, with the technique described in
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
One difference is that, in DARTS, we are using Softmax instead of GumbelSoftmax.
The supported mutation primitives of DARTS are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
{{module_notes}}
Parameters
----------
{{module_params}}
{base_params}
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
)
__doc__ = _darts_note.format(
module_notes='The DARTS Module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
DifferentiableMixedLayer.mutate,
DifferentiableMixedInput.mutate,
DifferentiableMixedCell.mutate,
DifferentiableMixedRepeat.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use differentiable strategy for mixed operations."""
return {
'mixed_op_sampling': MixedOpDifferentiablePolicy
}
def __init__(self, inner_module: pl.LightningModule,
mutation_hooks: list[MutationHook] | None = None,
arc_learning_rate: float = 3.0E-4):
self.arc_learning_rate = arc_learning_rate
super().__init__(inner_module, mutation_hooks=mutation_hooks)
def training_step(self, batch, batch_idx):
# grad manually
arc_optim = self.architecture_optimizers()
if not isinstance(arc_optim, optim.Optimizer):
raise TypeError(f'Expect arc_optim to be a single Optimizer, but found: {arc_optim}')
# DARTS strategy makes sure that ``train`` and ``val`` are both present in the batch
trn_batch = batch['train']
val_batch = batch['val']
# phase 1: architecture step
# The resample hook is kept for some DARTS-based NAS methods like Proxyless.
# See the code of those methods for details.
self.resample()
arc_optim.zero_grad()
arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
if isinstance(arc_step_loss, dict):
arc_step_loss = arc_step_loss['loss']
self.manual_backward(arc_step_loss)
self.finalize_grad()
arc_optim.step()
# phase 2: model step
self.resample()
self.call_weight_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
w_step_loss = loss_and_metrics['loss'] \
if isinstance(loss_and_metrics, dict) else loss_and_metrics
self.manual_backward(w_step_loss)
self.call_weight_optimizers('step')
self.call_lr_schedulers(batch_idx)
return loss_and_metrics
def finalize_grad(self):
# Note: This hook is currently kept for Proxyless NAS.
pass
def configure_architecture_optimizers(self):
# The alphas in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
ctrl_params = []
for m in self.nas_modules:
ctrl_params += list(m.parameters(arch=True)) # type: ignore
ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), self.arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
return ctrl_optim
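# Example (sketch): DARTS expects each training batch to be a dict produced by
# a CombinedLoader (see ``DARTS.preprocess_dataloader`` in the strategy module
# below), e.g.:
#
#     batch = {'train': (x_trn, y_trn), 'val': (x_val, y_val)}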
class ProxylessLightningModule(DartsLightningModule):
_proxyless_note = """
A low-memory-consuming optimized version of differentiable architecture search. See `reference <https://arxiv.org/abs/1812.00332>`__.
This is a DARTS-based method that resamples the architecture to reduce memory consumption.
Essentially, it samples one path on forward,
and implements its own backward to update the architecture parameters based on only one path.
The supported mutation primitives of Proxyless are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
{{module_notes}}
Parameters
----------
{{module_params}}
{base_params}
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
__doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with gumbel-differentiable versions"""
hooks = [
ProxylessMixedLayer.mutate,
ProxylessMixedInput.mutate,
no_default_hook,
]
# FIXME: no support for mixed operation currently
return hooks
def finalize_grad(self):
for m in self.nas_modules:
m.finalize_grad() # type: ignore
class GumbelDartsLightningModule(DartsLightningModule):
_gumbel_darts_note = """
Choose the best block by using Gumbel Softmax random sampling and differentiable training.
See `FBNet <https://arxiv.org/abs/1812.03443>`__ and `SNAS <https://arxiv.org/abs/1812.09926>`__.
This is a DARTS-based method that uses gumbel-softmax to simulate one-hot distribution.
Essentially, it tries to mimic the behavior of sampling one path on forward by gradually
cooling down the temperature, aiming to bridge the gap between differentiable architecture weights and
the discretization of architectures.
.. versionadded:: 2.8
Supports searching for ValueChoices on operations, with the technique described in
`FBNetV2: Differentiable Neural Architecture Search for Spatial and Channel Dimensions <https://arxiv.org/abs/2004.05565>`__.
The supported mutation primitives of GumbelDARTS are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
{{module_notes}}
Parameters
----------
{{module_params}}
{base_params}
gumbel_temperature : float
The initial temperature used in gumbel-softmax.
use_temp_anneal : bool
If true, a linear annealing will be applied to ``gumbel_temperature``.
Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
min_temp : float
The minimal temperature for annealing. No need to set this if ``use_temp_anneal`` is False.
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
)
def mutate_kwargs(self):
"""Use gumbel softmax."""
return {
'mixed_op_sampling': MixedOpDifferentiablePolicy,
'softmax': GumbelSoftmax(),
}
def __init__(self, inner_module,
mutation_hooks: list[MutationHook] | None = None,
arc_learning_rate: float = 3.0e-4,
gumbel_temperature: float = 1.,
use_temp_anneal: bool = False,
min_temp: float = .33):
super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
self.temp = gumbel_temperature
self.init_temp = gumbel_temperature
self.use_temp_anneal = use_temp_anneal
self.min_temp = min_temp
def on_train_epoch_end(self):
if self.use_temp_anneal:
self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
self.temp = max(self.temp, self.min_temp)
for module in self.nas_modules:
if hasattr(module, '_softmax'):
module._softmax.temp = self.temp # type: ignore
return self.model.on_train_epoch_end()
from nni.nas.oneshot.pytorch.differentiable import *
@@ -3,14 +3,14 @@
import logging
import warnings
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import SubsetRandomSampler, DataLoader
from nni.nas.oneshot.pytorch.enas import ReinforceController, ReinforceField
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
@@ -18,148 +18,6 @@ from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# The current implementation only supports batch size equal to 1,
# but the algorithm does not necessarily have this limitation.
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
A field with ``name`` and ``total`` choices. ``choose_one`` is true if one and only one choice is meant to be
selected. Otherwise, any number of choices can be chosen.
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` is not applied if this value is ``None``.
skip_target : float
Target probability that skip-connect (chosen by InputChoice) will appear.
If the sampled probability of choosing inputs deviates from ``skip_target``,
a sample skip penalty (a KL divergence term) is added.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob: torch.Tensor = cast(torch.Tensor, 0)
self.sample_entropy: torch.Tensor = cast(torch.Tensor, 0)
self.sample_skip_penalty: torch.Tensor = cast(torch.Tensor, 0)
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) # type: ignore
sampled = sampled.detach().cpu().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled
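# Example (sketch): sampling from the controller. ``resample()`` returns a
# dict mapping field names to sampled indices, and updates ``sample_log_prob``
# / ``sample_entropy`` / ``sample_skip_penalty``, which feed the RL loss
# (``reward`` and ``baseline`` are assumed given scalars here):
#
#     fields = [ReinforceField('layer1', 4, True),   # choose exactly one of 4
#               ReinforceField('skip2', 2, False)]   # choose any subset of 2
#     controller = ReinforceController(fields)
#     sample = controller.resample()  # e.g. {'layer1': 2, 'skip2': [0, 1]}
#     loss = controller.sample_log_prob * (reward - baseline)  # cf. EnasLightningModule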
class EnasTrainer(BaseOneShotTrainer):
"""
ENAS trainer.
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Experimental version of sampling-based one-shot implementation."""
# pylint: disable=wildcard-import,unused-wildcard-import
from __future__ import annotations
import warnings
from typing import Any
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
PathSamplingCell, PathSamplingRepeat
)
from .enas import ReinforceController, ReinforceField
class RandomSamplingLightningModule(BaseOneShotLightningModule):
_random_note = """
Train a super-net with uniform path sampling. See `reference <https://arxiv.org/abs/1904.00420>`__.
In each training step, model parameters are trained after a uniformly random sampling of each choice.
Notably, the exported result is **also a random sample** of the search space.
The supported mutation primitives of RandomOneShot are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
Parameters
----------
{{module_params}}
{base_params}
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
)
__doc__ = _random_note.format(
module_params=BaseOneShotLightningModule._inner_module_note,
)
# turn on automatic optimization because nothing interesting is going on here.
@property
def automatic_optimization(self) -> bool:
return True
def default_mutation_hooks(self) -> list[MutationHook]:
"""Replace modules with differentiable versions"""
hooks = [
PathSamplingLayer.mutate,
PathSamplingInput.mutate,
PathSamplingRepeat.mutate,
PathSamplingCell.mutate,
]
hooks += [operation.mutate for operation in NATIVE_MIXED_OPERATIONS]
hooks.append(no_default_hook)
return hooks
def mutate_kwargs(self):
"""Use path sampling strategy for mixed-operations."""
return {
'mixed_op_sampling': MixedOpPathSamplingPolicy
}
def training_step(self, batch, batch_idx):
self.resample()
return self.model.training_step(batch, batch_idx)
def export(self) -> dict[str, Any]:
"""
Export of Random one-shot. It will return an arbitrary architecture.
"""
warnings.warn(
'Direct export from RandomOneShot returns an arbitrary architecture. '
'Sampling the best architecture from this trained supernet is another search process. '
'Users need to do another search based on the checkpoint of the one-shot strategy.',
UserWarning
)
return super().export()
class EnasLightningModule(RandomSamplingLightningModule):
_enas_note = """
RL controller learns to generate the best network on a super-net. See `ENAS paper <https://arxiv.org/abs/1802.03268>`__.
There are 2 steps in an epoch.
- Firstly, training model parameters.
- Secondly, training the ENAS RL agent. The agent will produce a sample of model architecture to get the best reward.
.. note::
ENAS requires the evaluator to report metrics via ``self.log`` in its ``validation_step``.
See explanation of ``reward_metric_name`` for details.
The supported mutation primitives of ENAS are:
* :class:`nni.retiarii.nn.pytorch.LayerChoice`.
* :class:`nni.retiarii.nn.pytorch.InputChoice`.
* :class:`nni.retiarii.nn.pytorch.ValueChoice` (only when used in {supported_ops}).
* :class:`nni.retiarii.nn.pytorch.Repeat`.
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
{{module_notes}}
Parameters
----------
{{module_params}}
{base_params}
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController`.
entropy_weight : float
Weight of sample entropy loss in RL.
skip_weight : float
Weight of skip penalty loss. See :class:`~nni.retiarii.oneshot.pytorch.enas.ReinforceController` for details.
baseline_decay : float
Decay factor of reward baseline, which is used to normalize the reward in RL.
At each step, the new reward baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_steps_aggregate : int
Number of steps for which the gradients will be accumulated,
before updating the weights of RL controller.
ctrl_grad_clip : float
Gradient clipping value of controller.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This has no effect when there's only one metric returned from the evaluator.
If there are multiple, by default, it will look for the metric with key name ``default``.
If ``reward_metric_name`` is specified, it will look for ``reward_metric_name``.
Otherwise, it raises an exception indicating that multiple metrics are found.
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
)
__doc__ = _enas_note.format(
module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
@property
def automatic_optimization(self) -> bool:
return False
def __init__(self,
inner_module: pl.LightningModule,
*,
ctrl_kwargs: dict[str, Any] | None = None,
entropy_weight: float = 1e-4,
skip_weight: float = .8,
baseline_decay: float = .999,
ctrl_steps_aggregate: int = 20,
ctrl_grad_clip: float = 0,
reward_metric_name: str | None = None,
mutation_hooks: list[MutationHook] | None = None):
super().__init__(inner_module, mutation_hooks)
# convert parameter spec to legacy ReinforceField
# this part will be refactored
self.nas_fields: list[ReinforceField] = []
for name, param_spec in self.search_space_spec().items():
if param_spec.chosen_size not in (1, None):
raise ValueError('ENAS does not support n_chosen to be values other than 1 or None.')
self.nas_fields.append(ReinforceField(name, param_spec.size, param_spec.chosen_size == 1))
self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.ctrl_grad_clip = ctrl_grad_clip
self.reward_metric_name = reward_metric_name
def configure_architecture_optimizers(self):
return optim.Adam(self.controller.parameters(), lr=3.5e-4)
def training_step(self, batch_packed, batch_idx):
batch, mode = batch_packed
if mode == 'train':
# train model params
with torch.no_grad():
self.resample()
self.call_weight_optimizers('zero_grad')
step_output = self.model.training_step(batch, batch_idx)
w_step_loss = step_output['loss'] \
if isinstance(step_output, dict) else step_output
self.manual_backward(w_step_loss)
self.call_weight_optimizers('step')
else:
# train ENAS agent
arc_opt = self.architecture_optimizers()
if not isinstance(arc_opt, optim.Optimizer):
raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
arc_opt.zero_grad()
self.resample()
step_output = self.model.validation_step(batch, batch_idx)
# use the default metric of self.model as reward function
if len(self.trainer.callback_metrics) == 1:
_, metric = next(iter(self.trainer.callback_metrics.items()))
else:
metric_name = self.reward_metric_name or 'default'
if metric_name not in self.trainer.callback_metrics:
raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
'and remember to set on_step=True.')
metric = self.trainer.callback_metrics[metric_name]
reward: float = metric.item()
if self.entropy_weight:
reward = reward + self.entropy_weight * self.controller.sample_entropy.item() # type: ignore
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
rnn_step_loss = rnn_step_loss + self.skip_weight * self.controller.sample_skip_penalty
rnn_step_loss = rnn_step_loss / self.ctrl_steps_aggregate
self.manual_backward(rnn_step_loss)
if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
if self.ctrl_grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
arc_opt.step()
arc_opt.zero_grad()
return step_output
def resample(self):
"""Resample the architecture with ENAS controller."""
sample = self.controller.resample()
result = self._interpret_controller_sampling_result(sample)
for module in self.nas_modules:
module.resample(memo=result)
return result
def export(self):
"""Run one more inference of ENAS controller."""
self.controller.eval()
with torch.no_grad():
return self._interpret_controller_sampling_result(self.controller.resample())
def _interpret_controller_sampling_result(self, sample: dict[str, int]) -> dict[str, Any]:
"""Convert ``{label: index}`` to ``{label: name}``"""
space_spec = self.search_space_spec()
for key in list(sample.keys()):
sample[key] = space_spec[key].values[sample[key]]
return sample
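# Example (sketch): for ENAS, the evaluator's LightningModule should report
# the reward metric via ``self.log`` in its ``validation_step`` (see
# ``reward_metric_name`` above). ``self.accuracy`` is an assumed metric
# attribute, e.g. from torchmetrics:
#
#     def validation_step(self, batch, batch_idx):
#         x, y = batch
#         acc = self.accuracy(self(x), y)
#         self.log('default', acc, on_step=True)  # key must match reward_metric_name or 'default'
#         return acc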
from nni.nas.oneshot.pytorch.sampling import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Strategy integration of one-shot.
# pylint: disable=wildcard-import,unused-wildcard-import
This file is put here simply because it relies on "pytorch".
For consistency, please consider importing strategies from ``nni.retiarii.strategy``.
For example, ``nni.retiarii.strategy.DartsStrategy`` (this requires pytorch to be installed of course).
When adding/modifying a new strategy in this file, don't forget to link it in strategy/oneshot.py.
"""
from __future__ import annotations
import warnings
from typing import Any, Type
import torch.nn as nn
from nni.retiarii.graph import Model
from nni.retiarii.strategy.base import BaseStrategy
from nni.retiarii.evaluator.pytorch.lightning import Lightning, LightningModule
from .base_lightning import BaseOneShotLightningModule
from .differentiable import DartsLightningModule, ProxylessLightningModule, GumbelDartsLightningModule
from .sampling import EnasLightningModule, RandomSamplingLightningModule
class OneShotStrategy(BaseStrategy):
"""Wrap an one-shot lightning module as a one-shot strategy."""
def __init__(self, oneshot_module: Type[BaseOneShotLightningModule], **kwargs):
self.oneshot_module = oneshot_module
self.oneshot_kwargs = kwargs
self.model: BaseOneShotLightningModule | None = None
def preprocess_dataloader(self, train_dataloaders: Any, val_dataloaders: Any) -> tuple[Any, Any]:
"""
One-shot strategies typically require fusing train and validation dataloaders in an ad-hoc way.
As one-shot strategies don't try to open the black box of a batch,
theoretically, these dataloaders can be
`any dataloader types supported by Lightning <https://pytorch-lightning.readthedocs.io/en/stable/guides/data.html>`__.
Returns
-------
A tuple of preprocessed train dataloaders and validation dataloaders.
"""
return train_dataloaders, val_dataloaders
def run(self, base_model: Model, applied_mutators):
# one-shot strategy doesn't use ``applied_mutators``
# but gets the "mutators" on its own
_reason = 'The reason might be that you have used the wrong execution engine. Try to set engine to `oneshot` and try again.'
if not isinstance(base_model.python_object, nn.Module):
raise TypeError('Model is not a nn.Module. ' + _reason)
py_model: nn.Module = base_model.python_object
if applied_mutators:
raise ValueError('Mutator is not empty. ' + _reason)
if not isinstance(base_model.evaluator, Lightning):
raise TypeError('Evaluator needs to be a lightning evaluator to make one-shot strategy work.')
evaluator_module: LightningModule = base_model.evaluator.module
evaluator_module.running_mode = 'oneshot'
evaluator_module.set_model(py_model)
self.model = self.oneshot_module(evaluator_module, **self.oneshot_kwargs)
evaluator: Lightning = base_model.evaluator
if evaluator.train_dataloaders is None or evaluator.val_dataloaders is None:
raise TypeError('Training and validation dataloaders are both required to be set in the evaluator for one-shot strategy.')
train_loader, val_loader = self.preprocess_dataloader(evaluator.train_dataloaders, evaluator.val_dataloaders)
evaluator.trainer.fit(self.model, train_loader, val_loader)
def export_top_models(self, top_k: int = 1) -> list[Any]:
"""The behavior of export top models in strategy depends on the implementation of inner one-shot module."""
if self.model is None:
raise RuntimeError('One-shot strategy needs to be run before export.')
if top_k != 1:
warnings.warn('One-shot strategy currently only supports exporting top-1 model.', RuntimeWarning)
return [self.model.export()]
class DARTS(OneShotStrategy):
__doc__ = DartsLightningModule._darts_note.format(module_notes='', module_params='')
def __init__(self, **kwargs):
super().__init__(DartsLightningModule, **kwargs)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
# By returning a dict, we make a CombinedLoader (in Lightning)
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class Proxyless(OneShotStrategy):
__doc__ = ProxylessLightningModule._proxyless_note.format(module_notes='', module_params='')
def __init__(self, **kwargs):
super().__init__(ProxylessLightningModule, **kwargs)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class GumbelDARTS(OneShotStrategy):
__doc__ = GumbelDartsLightningModule._gumbel_darts_note.format(module_notes='', module_params='')
def __init__(self, **kwargs):
super().__init__(GumbelDartsLightningModule, **kwargs)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
return {
'train': train_dataloaders,
'val': val_dataloaders
}, None
class ENAS(OneShotStrategy):
__doc__ = EnasLightningModule._enas_note.format(module_notes='', module_params='')
def __init__(self, **kwargs):
super().__init__(EnasLightningModule, **kwargs)
def preprocess_dataloader(self, train_dataloaders, val_dataloaders):
# Import locally to avoid import error on legacy PL version
from .dataloader import ConcatLoader
return ConcatLoader({
'train': train_dataloaders,
'val': val_dataloaders
}), None
class RandomOneShot(OneShotStrategy):
__doc__ = RandomSamplingLightningModule._random_note.format(module_notes='', module_params='')
def __init__(self, **kwargs):
super().__init__(RandomSamplingLightningModule, **kwargs)
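# Example (sketch, assuming the one-shot execution engine and NNI's Retiarii
# experiment API): strategies defined here are used like any other strategy.
#
#     import nni.retiarii.strategy as strategy
#     from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
#
#     exp = RetiariiExperiment(base_model, evaluator, strategy=strategy.DARTS())
#     config = RetiariiExeConfig()
#     config.execution_engine = 'oneshot'
#     exp.run(config)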
from nni.nas.oneshot.pytorch.strategy import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Thie file handles "slice" commonly used in mixed-operation.
# pylint: disable=wildcard-import,unused-wildcard-import
The ``slice_type`` we support here, is "slice" or "list of slice".
The reason is that sometimes (e.g., in multi-head attention),
the tensor slice could be from multiple parts. This type is extensible.
We can support arbitrary masks in future if we need them.
To slice a tensor, we need ``multidim_slice``,
which is simply a tuple consists of ``slice_type``.
Usually in python programs, the variable put into slice's start, stop and step
should be integers (or NoneType).
But in our case, it could also be a dict from integer to float,
representing a distribution of integer. When that happens,
we convert a "slice with some weighted values", to a "weighted slice".
To this end, we track the computation with ``MaybeWeighted``,
and replay the computation with each possible value.
Meanwhile, we record their weights.
Note that ``MaybeWeighted`` is also extensible.
We can support more types of objects on slice in future.
The fixed/weighted slice is fed into ``_slice_weight``,
which interprets the slice and apply it on a tensor.
"""
from __future__ import annotations
import operator
from typing import Callable, Iterator, TypeVar, Any, Optional, Tuple, Union, List, Dict, Generic, cast
import numpy as np
import torch
__all__ = [
'slice_type',
'multidim_slice',
'scalar_or_scalar_dict',
'int_or_int_dict',
'zeros_like',
'Slicable',
'MaybeWeighted',
]
T = TypeVar('T')
slice_type = Union[slice, List[slice]]
multidim_slice = Tuple[slice_type, ...]
scalar_or_scalar_dict = Union[T, Dict[T, float]]
int_or_int_dict = scalar_or_scalar_dict[int]
_value_fn_type = Optional[Callable[[int_or_int_dict], int]]
def zeros_like(arr: T) -> T:
if isinstance(arr, np.ndarray):
return np.zeros_like(arr)
elif isinstance(arr, torch.Tensor):
return torch.zeros_like(arr)
else:
raise TypeError(f'Unsupported type for {arr}: {type(arr)}')
def _eliminate_list_slice(shape: tuple, slice_: multidim_slice) -> multidim_slice:
# get rid of list of slice
result = []
for i in range(len(slice_)):
if isinstance(slice_[i], list):
# convert list of slices to mask
mask = np.zeros(shape[i], dtype=bool)  # builtin bool; the ``np.bool`` alias is removed in recent NumPy
for sl in cast(List[slice], slice_[i]):
mask[sl] = 1
result.append(mask)
else:
result.append(slice_[i])
return tuple(result)
def _slice_weight(weight: T, slice_: multidim_slice | list[tuple[multidim_slice, float]]) -> T:
# slice_ can be a tuple of slices, e.g., ([3:6], [2:4]),
# or a list of (multidim_slice, weight) pairs, e.g., [(([3:6],), 0.6), (([2:4],), 0.3)]
if isinstance(slice_, list):
# for weighted case, we get the corresponding masks. e.g.,
# {([3:6],): 0.6, ([2:4],): 0.3} => [0, 0, 0.3, 0.9, 0.6, 0.6] (if the whole length is 6)
# this mask is broadcasted and multiplied onto the weight
masks = []
# the accepted argument is list of tuple here
# because slice can't be key of dict
for sl, wt in slice_:
# create a mask with weight w
with torch.no_grad():
mask = zeros_like(weight)
mask[_eliminate_list_slice(weight.shape, sl)] = 1 # type: ignore
# track gradients here
masks.append(mask * wt) # type: ignore
masks = sum(masks)
return masks * weight # type: ignore
else:
# for unweighted case, we slice it directly.
def _do_slice(arr, slice_):
return arr[_eliminate_list_slice(arr.shape, slice_)] # type: ignore
# sometimes, we don't need slice.
# this saves an op on computational graph, which will hopefully make training faster
# Use a dummy array to check this. Otherwise it would be too complex.
dummy_arr = np.zeros(weight.shape, dtype=bool) # type: ignore
no_effect = cast(Any, _do_slice(dummy_arr, slice_)).shape == dummy_arr.shape
if no_effect:
return weight
return _do_slice(weight, slice_)
class Slicable(Generic[T]):
"""Wraps the weight so that in can be sliced with a ``multidim_slice``.
The value within the slice can be instances of :class:`MaybeWeighted`.
Examples
--------
>>> weight = conv2d.weight
>>> Slicable(weight)[:MaybeWeighted({32: 0.4, 64: 0.6})]
Tensor of shape (64, 64, 3, 3)
"""
def __init__(self, weight: T):
if not isinstance(weight, np.ndarray) and not torch.is_tensor(weight):
raise TypeError(f'Unsupported weight type: {type(weight)}')
self.weight = weight
def __getitem__(self, index: slice_type | multidim_slice | Any) -> T:
if not isinstance(index, tuple):
index = (index, )
index = cast(multidim_slice, index)
# Get the dict value in index's leafs
# There can be at most one dict
leaf_dict: dict[int, float] | None = None
for maybe_weighted in _iterate_over_multidim_slice(index):
for d in maybe_weighted.leaf_values():
if isinstance(d, dict):
if leaf_dict is None:
leaf_dict = d
elif leaf_dict is not d:
raise ValueError('There can be at most one distinct dict in leaf values.')
if leaf_dict is None:
# in case of simple types with no dict
res_index = _evaluate_multidim_slice(index)
else:
# there is a dict, iterate over dict
res_index = []
for val, wt in leaf_dict.items():
res_index_item = _evaluate_multidim_slice(index, lambda _: val)
res_index.append((res_index_item, wt))
return _slice_weight(self.weight, res_index)
class MaybeWeighted:
"""Wrap a value (int or dict with int keys), so that the computation on it can be replayed.
It builds a binary tree. If ``value`` is not None, it's a leaf node.
Otherwise, it has a left sub-tree, a right sub-tree, and an operation.
Only basic arithmetic operations are supported: ``+``, ``-``, ``*``, ``//``.
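Example (a minimal sketch of how the replay works)::

    >>> w = MaybeWeighted({3: 0.5, 5: 0.5}) * 2
    >>> list(w.leaf_values())
    [{3: 0.5, 5: 0.5}]
    >>> w.evaluate(lambda _: 3)
    6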
"""
def __init__(self,
value: int_or_int_dict | None = None, *,
lhs: 'MaybeWeighted' | int | None = None,
rhs: 'MaybeWeighted' | int | None = None,
operation: Callable[[int_or_int_dict, int_or_int_dict], int_or_int_dict] | None = None):
if operation is None:
if not isinstance(value, (int, dict)):
raise TypeError(f'Unsupported value type: {type(value)}')
self.value = value
self.lhs = lhs
self.rhs = rhs
self.operation = operation
def leaf_values(self) -> Iterator[int_or_int_dict]:
"""Iterate over values on leaf nodes."""
if self.value is not None:
yield self.value
else:
if isinstance(self.lhs, MaybeWeighted):
yield from self.lhs.leaf_values()
if isinstance(self.rhs, MaybeWeighted):
yield from self.rhs.leaf_values()
def evaluate(self, value_fn: _value_fn_type = None) -> int_or_int_dict:
"""Evaluate the value on root node, after replacing every value on leaf node with ``value_fn``.
If ``value_fn`` is none, no replacement will happen and the raw value will be used.
"""
if self.value is not None:
if value_fn is not None:
return value_fn(self.value)
return self.value
else:
if isinstance(self.lhs, MaybeWeighted):
eval_lhs = self.lhs.evaluate(value_fn)
else:
eval_lhs = cast(int, self.lhs)
if isinstance(self.rhs, MaybeWeighted):
eval_rhs = self.rhs.evaluate(value_fn)
else:
eval_rhs = cast(int, self.rhs)
assert self.operation is not None
return self.operation(eval_lhs, eval_rhs)
def __repr__(self):
if self.value is not None:
return f'{self.__class__.__name__}({self.value})'
return f'{self.__class__.__name__}(lhs={self.lhs}, rhs={self.rhs}, op={self.operation})'
def __add__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=self, rhs=other, operation=operator.add)
def __radd__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=other, rhs=self, operation=operator.add)
def __sub__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=self, rhs=other, operation=operator.sub)
def __rsub__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=other, rhs=self, operation=operator.sub)
def __mul__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=self, rhs=other, operation=operator.mul)
def __rmul__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=other, rhs=self, operation=operator.mul)
def __floordiv__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=self, rhs=other, operation=operator.floordiv)
def __rfloordiv__(self, other: Any) -> 'MaybeWeighted':
return MaybeWeighted(lhs=other, rhs=self, operation=operator.floordiv)
def _iterate_over_slice_type(s: slice_type):
if isinstance(s, list):
for se in s:
yield from _iterate_over_slice_type(se)
else:
# s must be a "slice" now
if isinstance(s.start, MaybeWeighted):
yield s.start
if isinstance(s.stop, MaybeWeighted):
yield s.stop
if isinstance(s.step, MaybeWeighted):
yield s.step
def _iterate_over_multidim_slice(ms: multidim_slice):
"""Get :class:`MaybeWeighted` instances in ``ms``."""
for s in ms:
if s is not None and s is not Ellipsis:
yield from _iterate_over_slice_type(s)
def _evaluate_slice_type(s: slice_type, value_fn: _value_fn_type = None):
if isinstance(s, list):
return [_evaluate_slice_type(se, value_fn) for se in s]
else:
return slice(
s.start.evaluate(value_fn) if isinstance(s.start, MaybeWeighted) else s.start,
s.stop.evaluate(value_fn) if isinstance(s.stop, MaybeWeighted) else s.stop,
s.step.evaluate(value_fn) if isinstance(s.step, MaybeWeighted) else s.step
)
def _evaluate_multidim_slice(ms: multidim_slice, value_fn: _value_fn_type = None):
"""Wraps :meth:`MaybeWeighted.evaluate` to evaluate the whole ``multidim_slice``."""
res = []
for s in ms:
if s is not None and s is not Ellipsis:
res.append(_evaluate_slice_type(s, value_fn))
else:
res.append(s)
return tuple(res)
from nni.nas.oneshot.pytorch.supermodule._operation_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# pylint: skip-file
# type: ignore
# pylint: disable=wildcard-import,unused-wildcard-import
"""This file is an incomplete implementation of `Single-path NAS <https://arxiv.org/abs/1904.02877>`__.
These are merely some components of the algorithm. Complete support is an ongoing work item.
Keep this file here so that it can be "blamed".
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.nn.pytorch import ValueChoice
class DifferentiableSuperConv2d(nn.Conv2d):
"""
Only ``kernel_size``, ``in_channels`` and ``out_channels`` are supported. For any two kernel-size candidates,
one must be no smaller than the other on both dimensions. See examples below:
the following example is not allowed:
>>> ValueChoice(candidates = [(5, 3), (3, 5)])
□ ■ ■ ■ □ □ □ □ □ □
□ ■ ■ ■ □ ■ ■ ■ ■ ■ # candidates are not bigger or smaller on both dimensions
□ ■ ■ ■ □ ■ ■ ■ ■ ■
□ ■ ■ ■ □ ■ ■ ■ ■ ■
□ ■ ■ ■ □ □ □ □ □ □
the following 3 examples are valid:
>>> ValueChoice(candidates = [5, 3, 1])
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ ■ □ □
■ ■ ■ ■ ■ □ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
>>> ValueChoice(candidates = [(5, 7), (3, 5), (1, 3)])
■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ ■ ■ ■ □ □
■ ■ ■ ■ ■ ■ ■ □ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □ □ □ □ □
>>> # when the size difference between two candidates is odd, the smaller one is aligned to the upper left:
>>> ValueChoice(candidates = [(5, 5), (4, 4), (3, 3)])
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ □ □ □ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ ■ ■ ■ ■ □ □ ■ ■ ■ □
■ ■ ■ ■ ■ □ □ □ □ □ □ □ □ □ □
"""
def __init__(self, module, name):
self.label = name
args = module.trace_kwargs
# compulsory params
if isinstance(args['in_channels'], ValueChoice):
args['in_channels'] = max(args['in_channels'].candidates)
self.out_channel_candidates = None
if isinstance(args['out_channels'], ValueChoice):
self.out_channel_candidates = sorted(args['out_channels'].candidates, reverse=True)
args['out_channels'] = self.out_channel_candidates[0]
# kernel_size may be an int or tuple, we turn it into a tuple for simplicity
self.kernel_size_candidates = None
if isinstance(args['kernel_size'], ValueChoice):
# unify kernel size as tuple
candidates = args['kernel_size'].candidates
if not isinstance(candidates[0], tuple):
candidates = [(k, k) for k in candidates]
# sort kernel size in descending order
self.kernel_size_candidates = sorted(candidates, key=lambda t: t[0], reverse=True)
for i in range(0, len(self.kernel_size_candidates) - 1):
bigger = self.kernel_size_candidates[i]
smaller = self.kernel_size_candidates[i + 1]
assert bigger[1] > smaller[1] or (bigger[1] == smaller[1] and bigger[0] > smaller[0]), f'Kernel_size candidates ' \
f'should be larger or smaller than each other on both dimensions, but found {bigger} and {smaller}.'
args['kernel_size'] = self.kernel_size_candidates[0]
super().__init__(**args)
self.generate_architecture_params()
def forward(self, input):
# Note that there is no need to handle ``in_channels`` here, since it is already handled by ``out_channels`` of the
# previous module. If we multiplied alpha with respect to ``in_channels`` here again, the alpha would effectively be
# counted twice, which is not what we expect.
weight = self.weight
def sum_weight(input_weight, masks, thresholds, indicator):
"""
This is to get the weighted sum of weight.
Parameters
----------
input_weight : Tensor
the weight to be weighted summed
masks : list[Tensor]
weight masks.
thresholds : list[float]
thresholds, should have a length of ``len(masks) - 1``
indicator : Callable[[Tensor, float], float]
take a tensor and a threshold as input, and output the weight
Returns
----------
weight : Tensor
weighted sum of ``input_weight``. This is of the same shape as ``input_weight``.
"""
# Note that ``masks`` and ``thresholds`` have different lengths. Their alignment is shown below:
# self.xxx_candidates = [ c_0 , c_1 , ... , c_n-2 , c_n-1 ] # descending order
# self.xxx_mask = [ mask_0 , mask_1 , ... , mask_n-2, mask_n-1]
# self.t_xxx = [ t_0 , t_1 , ... , t_n-2 ]
# So we zip the first n-1 items, and multiply masks[-1] in the end.
weight = torch.zeros_like(input_weight)
for mask, t in zip(masks[:-1], thresholds):
cur_part = input_weight * mask
alpha = indicator(cur_part, t)
weight = (weight + cur_part) * alpha
# we do not consider skip-op here for out_channel/expansion candidates, which means at least the smallest channel
# candidate is included
weight += input_weight * masks[-1]
return weight
if self.kernel_size_candidates is not None:
weight = sum_weight(weight, self.kernel_masks, self.t_kernel, self.Lasso_sigmoid)
if self.out_channel_candidates is not None:
weight = sum_weight(weight, self.channel_masks, self.t_expansion, self.Lasso_sigmoid)
output = self._conv_forward(input, weight, self.bias)
return output
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super().named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
"""
result = {
'kernel_size': i,
'out_channels': j
}
which means the best candidate for an argument is the i-th one if candidates are sorted in descending order
"""
result = {}
eps = 1e-5
with torch.no_grad():
if self.kernel_size_candidates is not None:
weight = torch.zeros_like(self.weight)
# ascending order
for i in range(len(self.kernel_size_candidates) - 2, -1, -1):
mask = self.kernel_masks[i]
t = self.t_kernel[i]
cur_part = self.weight * mask
alpha = self.Lasso_sigmoid(cur_part, t)
if alpha <= eps: # takes the smaller one
result['kernel_size'] = self.kernel_size_candidates[i + 1]
break
weight = (weight + cur_part) * alpha
if 'kernel_size' not in result:
result['kernel_size'] = self.kernel_size_candidates[0]
else:
weight = self.weight
if self.out_channel_candidates is not None:
for i in range(len(self.out_channel_candidates) - 2, -1, -1):
mask = self.channel_masks[i]
t = self.t_expansion[i]
alpha = self.Lasso_sigmoid(weight * mask, t)
if alpha <= eps:
result['out_channels'] = self.out_channel_candidates[i + 1]
break  # take the smaller one, mirroring the kernel-size loop above
if 'out_channels' not in result:
result['out_channels'] = self.out_channel_candidates[0]
return result
@staticmethod
def Lasso_sigmoid(matrix, t):
"""
A trick that can make use of both the value of bool(lasso > t) and the gradient of sigmoid(lasso - t)
Parameters
----------
matrix : Tensor
the matrix to calculate lasso norm
t : float
the threshold
"""
lasso = torch.norm(matrix) - t
indicator = (lasso > 0).float() # torch.sign(lasso)
with torch.no_grad():
# indicator = indicator / 2 + .5 # realign indicator from (-1, 1) to (0, 1)
indicator -= F.sigmoid(lasso)
indicator += F.sigmoid(lasso)
return indicator
def generate_architecture_params(self):
self.alpha = {}
if self.kernel_size_candidates is not None:
# kernel size arch params
self.t_kernel = nn.Parameter(torch.rand(len(self.kernel_size_candidates) - 1))
self.alpha['kernel_size'] = self.t_kernel
# kernel size mask
self.kernel_masks = []
for i in range(0, len(self.kernel_size_candidates) - 1):
big_size = self.kernel_size_candidates[i]
small_size = self.kernel_size_candidates[i + 1]
mask = torch.zeros_like(self.weight)
mask[:, :, :big_size[0], :big_size[1]] = 1 # if self.weight.shape = (out, in, 7, 7), big_size = (5, 5) and
mask[:, :, :small_size[0], :small_size[1]] = 0 # small_size = (3, 3), mask will look like:
self.kernel_masks.append(mask) # 0 0 0 0 0 0 0
mask = torch.zeros_like(self.weight) # 0 1 1 1 1 1 0
mask[:, :, :self.kernel_size_candidates[-1][0], :self.kernel_size_candidates[-1][1]] = 1 # 0 1 0 0 0 1 0
self.kernel_masks.append(mask) # 0 1 0 0 0 1 0
# 0 1 0 0 0 1 0
if self.out_channel_candidates is not None: # 0 1 1 1 1 1 0
# out_channel (or expansion) arch params. we do not consider skip-op here, so we # 0 0 0 0 0 0 0
# only generate ``len(self.out_channel_candidates) - 1`` thresholds
self.t_expansion = nn.Parameter(torch.rand(len(self.out_channel_candidates) - 1))
self.alpha['out_channels'] = self.t_expansion
self.channel_masks = []
for i in range(0, len(self.out_channel_candidates) - 1):
big_channel, small_channel = self.out_channel_candidates[i], self.out_channel_candidates[i + 1]
mask = torch.zeros_like(self.weight)
mask[:big_channel] = 1
mask[:small_channel] = 0
# if self.weight.shape = (32, in, W, H), big_channel = 16 and small_channel = 8, mask will look like:
# 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
self.channel_masks.append(mask)
mask = torch.zeros_like(self.weight)
mask[:self.out_channel_candidates[-1]] = 1
self.channel_masks.append(mask)
class DifferentiableBatchNorm2d(nn.BatchNorm2d):
def __init__(self, module, name):
self.label = name
args = module.trace_kwargs
if isinstance(args['num_features'], ValueChoice):
args['num_features'] = max(args['num_features'].candidates)
super().__init__(**args)
# no architecture parameter is needed for BatchNorm2d Layers
self.alpha = nn.Parameter(torch.tensor([]))
def export(self):
"""
No need to export ``BatchNorm2d``. Refer to the ``Conv2d`` layer that has the ``ValueChoice`` as ``out_channels``.
"""
return -1
from nni.nas.oneshot.pytorch.supermodule._singlepathnas import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Utilities to process the value choice compositions,
in the way that is most convenient to one-shot algorithms."""
# pylint: disable=wildcard-import,unused-wildcard-import
from __future__ import annotations
import itertools
from typing import Any, TypeVar, List, cast, Mapping, Sequence, Optional, Iterable
import numpy as np
import torch
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch.api import ChoiceOf, ValueChoiceX
Choice = Any
T = TypeVar('T')
__all__ = [
'dedup_inner_choices',
'evaluate_value_choice_with_dict',
'traverse_all_options',
'weighted_sum',
'evaluate_constant',
]
def dedup_inner_choices(value_choices: list[ValueChoiceX]) -> dict[str, ParameterSpec]:
"""Find all leaf nodes in ``value_choices``,
save them in the format of ``{label: parameter_spec}``.
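For example (a sketch), if two compositions both contain a leaf choice labelled ``'ks'``,
only one ``ParameterSpec`` keyed ``'ks'`` is returned; conflicting candidates under the
same label raise a ``ValueError``.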
"""
result = {}
for value_choice in value_choices:
for choice in value_choice.inner_choices():
param_spec = ParameterSpec(choice.label, 'choice', choice.candidates, (choice.label, ), True, size=len(choice.candidates))
if choice.label in result:
if param_spec != result[choice.label]:
raise ValueError('Value choice conflict: same label with different candidates: '
f'{param_spec} vs. {result[choice.label]}')
else:
result[choice.label] = param_spec
return result
def evaluate_value_choice_with_dict(value_choice: ChoiceOf[T], chosen: dict[str, Choice]) -> T:
"""To evaluate a composition of value-choice with a dict,
with format of ``{label: chosen_value}``.
The implementation is two-pass. We first get a list of values,
then feed the values into ``value_choice.evaluate``.
This can be potentially optimized in terms of speed.
Examples
--------
>>> chosen = {"exp_ratio": 3}
>>> evaluate_value_choice_with_dict(value_choice_in, chosen)
48
>>> evaluate_value_choice_with_dict(value_choice_out, chosen)
96
"""
choice_inner_values = []
for choice in value_choice.inner_choices():
if choice.label not in chosen:
raise KeyError(f'{value_choice} depends on a value with key {choice.label}, but not found in {chosen}')
choice_inner_values.append(chosen[choice.label])
return value_choice.evaluate(choice_inner_values)
def traverse_all_options(
value_choice: ChoiceOf[T],
weights: dict[str, list[float]] | dict[str, np.ndarray] | dict[str, torch.Tensor] | None = None
) -> list[tuple[T, float]] | list[T]:
"""Traverse all possible computation outcome of a value choice.
If ``weights`` is not None, it will also compute the probability of each possible outcome.
Parameters
----------
value_choice : ValueChoiceX
The value choice to traverse.
weights : Optional[dict[str, list[float]]], default = None
If there's a prior on leaf nodes, and we intend to know the (joint) prior on results,
weights can be provided. Keys are labels; values are lists of floats indicating probabilities.
Normally, they should sum up to 1, but we will not check that in this function.
Returns
-------
list[Union[tuple[Any, float], Any]]
Results will be sorted and duplicates will be eliminated.
If weights is provided, the return value will be a list of tuple, with option and its weight.
Otherwise, it will be a list of options.
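Examples
--------
A sketch, assuming ``vc`` is a composition like ``ValueChoice([2, 3], label='a') + 1``::

    >>> traverse_all_options(vc)
    [3, 4]
    >>> traverse_all_options(vc, weights={'a': [0.3, 0.7]})
    [(3, 0.3), (4, 0.7)]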
"""
# get a dict of {label: list of tuple of choice and weight}
leafs: dict[str, list[tuple[T, float]]] = {}
for label, param_spec in dedup_inner_choices([value_choice]).items():
if weights is not None:
if label not in weights:
raise KeyError(f'{value_choice} depends on a weight with key {label}, but not found in {weights}')
if len(weights[label]) != param_spec.size:
raise KeyError(f'Expect weights with {label} to be of length {param_spec.size}, but {len(weights[label])} found')
leafs[label] = list(zip(param_spec.values, cast(List[float], weights[label])))
else:
# create a dummy weight of zero, in case that weights are not provided.
leafs[label] = list(zip(param_spec.values, itertools.repeat(0., param_spec.size)))
# result is a dict from an option to its weight
result: dict[T, float | None] = {}
labels, values = list(leafs.keys()), list(leafs.values())
if not labels:
raise ValueError(f'Expected at least one leaf value choice in {value_choice}, but none was found')
# get all combinations
for prod_value in itertools.product(*values):
# For example,
# prod_value = ((3, 0.1), ("cat", 0.3), ({"in": 5}, 0.5))
# the first dim is chosen value, second dim is probability
# chosen = {"ks": 3, "animal": "cat", "linear_args": {"in": 5}}
# chosen_weight = np.prod([0.1, 0.3, 0.5])
chosen = {label: value[0] for label, value in zip(labels, prod_value)}
eval_res = evaluate_value_choice_with_dict(value_choice, chosen)
if weights is None:
result[eval_res] = None
else:
# we can't use reduce or inplace product here,
# because weight can sometimes be tensors
chosen_weight = prod_value[0][1]
for value in prod_value[1:]:
if chosen_weight is None:
chosen_weight = value[1]
else:
chosen_weight = chosen_weight * value[1]
if eval_res in result:
result[eval_res] = result[eval_res] + chosen_weight
else:
result[eval_res] = chosen_weight
if weights is None:
return sorted(result.keys()) # type: ignore
else:
return sorted(result.items()) # type: ignore
def evaluate_constant(expr: Any) -> Any:
"""Evaluate a value choice expression to a constant. Raise ValueError if it's not a constant."""
all_options = traverse_all_options(expr)
if len(all_options) > 1:
raise ValueError(f'{expr} is not evaluated to a constant. All possible values are: {all_options}')
res = all_options[0]
return res
def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence[Optional[float]], None)) -> T:
"""Return a weighted sum of items.
Items can be list of tensors, numpy arrays, or nested lists / dicts.
If ``weights`` is None, this is simply an unweighted sum.
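Examples (illustrative sketches)::

    >>> weighted_sum([torch.ones(2), torch.zeros(2)], [0.3, 0.7])
    tensor([0.3000, 0.3000])
    >>> weighted_sum([{'a': 1.0}, {'a': 3.0}], [0.5, 0.5])
    {'a': 2.0}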
"""
if weights is None:
weights = [None] * len(items)
assert len(items) == len(weights) > 0
elem = items[0]
unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
if isinstance(elem, str):
# Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
raise TypeError(unsupported_msg)
try:
if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
if weights[0] is None:
res = elem
else:
res = elem * weights[0]
for it, weight in zip(items[1:], weights[1:]):
if type(it) != type(elem):
raise TypeError(f'Expect type {type(elem)} but found {type(it)}. Can not be summed')
if weight is None:
res = res + it # type: ignore
else:
res = res + it * weight # type: ignore
return cast(T, res)
if isinstance(elem, Mapping):
for item in items:
if not isinstance(item, Mapping):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if set(item) != set(elem):
raise KeyError(f'Expect keys {list(elem)} but found {list(item)}')
return cast(T, {
key: weighted_sum(cast(List[dict], [cast(Mapping, d)[key] for d in items]), weights) for key in elem
})
if isinstance(elem, Sequence):
for item in items:
if not isinstance(item, Sequence):
raise TypeError(f'Expect type {type(elem)} but found {type(item)}')
if len(item) != len(elem):
raise ValueError(f'Expect length {len(elem)} but found {len(item)}')
transposed = cast(Iterable[list], zip(*items)) # type: ignore
return cast(T, [weighted_sum(column, weights) for column in transposed])
except (TypeError, ValueError, RuntimeError, KeyError):
raise ValueError(
'Error when summing items. Value format / shape does not match. See full traceback for details.' +
''.join([
f'\n {idx}: {_summarize_elem_format(it)}' for idx, it in enumerate(items)
])
)
# Dealing with all unexpected types.
raise TypeError(unsupported_msg)
def _summarize_elem_format(elem: Any) -> Any:
# Get a summary of one elem
# Helps generate human-readable error messages
class _repr_object:
# empty object is only repr
def __init__(self, representation):
self.representation = representation
def __repr__(self):
return self.representation
if isinstance(elem, torch.Tensor):
return _repr_object('torch.Tensor(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, np.ndarray):
return _repr_object('np.array(' + ', '.join(map(str, elem.shape)) + ')')
if isinstance(elem, Mapping):
return {key: _summarize_elem_format(value) for key, value in elem.items()}
if isinstance(elem, Sequence):
return [_summarize_elem_format(value) for value in elem]
# fallback to original, for cases like float, int, ...
return elem
from nni.nas.oneshot.pytorch.supermodule._valuechoice_utils import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
from typing import Any
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
__all__ = ['BaseSuperNetModule']
class BaseSuperNetModule(nn.Module):
"""
Mutated module in super-net.
Usually, the feed-forward of the module itself is undefined.
It has to be resampled with ``resample()`` so that a specific path is selected.
(Sometimes, this is not required. For example, differentiable super-net.)
A super-net module usually corresponds to one sample, but there are two exceptions:
* A module can have multiple parameter specs. For example, a convolution-2d can sample kernel size and channels at the same time.
* Multiple modules can share one parameter spec. For example, multiple layer choices with the same label.
For value choice compositions, the parameter specs are bound to the underlying (original) value choices,
rather than their compositions.
"""
def resample(self, memo: dict[str, Any]) -> dict[str, Any]:
"""
Resample the super-net module.
Parameters
----------
memo : dict[str, Any]
Used to ensure the consistency of samples with the same label.
Returns
-------
dict
Sampled result. If nothing new is sampled, it should return an empty dict.
"""
raise NotImplementedError()
def export(self, memo: dict[str, Any]) -> dict[str, Any]:
"""
Export the final architecture within this module.
It should have the same keys as ``search_space_spec()``.
Parameters
----------
memo : dict[str, Any]
Use memo to avoid the same label getting exported multiple times.
"""
raise NotImplementedError()
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""
Space specification (sample points).
Mapping from spec name to ParameterSpec. The names in choices should be in the same format as ``export``.
For example: ::
{"layer1": ParameterSpec(values=["conv", "pool"])}
"""
raise NotImplementedError()
@classmethod
def mutate(cls, module: nn.Module, name: str, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> \
'BaseSuperNetModule' | bool | tuple['BaseSuperNetModule', bool]:
"""This is a mutation hook that creates a :class:`BaseSuperNetModule`.
The method should be implemented in each specific super-net module,
because they usually have specific rules about what kind of modules to operate on.
Parameters
----------
module : nn.Module
The module to be mutated (replaced).
name : str
Name of this module, with full prefix. For example, ``module1.block1.conv``.
memo : dict
Memo to enable sharing parameters among mutated modules. It should be read and written by
mutate functions themselves.
mutate_kwargs : dict
Algo-related hyper-parameters, and some auxiliary information.
Returns
-------
Union[BaseSuperNetModule, bool, tuple[BaseSuperNetModule, bool]]
The mutation result, along with an optional boolean flag indicating whether to suppress follow-up mutation hooks.
See :class:`BaseOneShotLightningModule <nni.retiarii.oneshot.pytorch.base_lightning.BaseOneShotLightningModule>` for details.
"""
raise NotImplementedError()
from nni.nas.oneshot.pytorch.supermodule.base import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import functools
import logging
import warnings
from typing import Any, Dict, Sequence, List, Tuple, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ChoiceOf, Repeat
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import preprocess_cell_inputs
from .base import BaseSuperNetModule
from .operation import MixedOperation, MixedOperationSamplingPolicy
from .sampling import PathSamplingCell
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, weighted_sum
_logger = logging.getLogger(__name__)
__all__ = [
'DifferentiableMixedLayer', 'DifferentiableMixedInput',
'DifferentiableMixedRepeat', 'DifferentiableMixedCell',
'MixedOpDifferentiablePolicy',
]
class GumbelSoftmax(nn.Softmax):
"""Wrapper of ``F.gumbel_softmax``. dim = -1 by default."""
dim: int
def __init__(self, dim: int = -1) -> None:
super().__init__(dim)
self.tau = 1
self.hard = False
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
return F.gumbel_softmax(inputs, tau=self.tau, hard=self.hard, dim=self.dim)
class DifferentiableMixedLayer(BaseSuperNetModule):
"""
Mixed layer, in which fprop is decided by a weighted sum of several layers.
Proposed in `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__.
The weight ``alpha`` is usually learnable, and optimized on validation dataset.
A differentiable sampling layer requires all operators to return the same shape for the same input,
as all outputs will be summed with weights to get the final output.
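In other words, for input ``x``, the output is ``sum_i softmax(alpha)[i] * op_i(x)``
(with :func:`weighted_sum` generalizing the sum to non-tensor outputs).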
Parameters
----------
paths : list[tuple[str, nn.Module]]
Layers to choose from. Each is a tuple of name, and its module.
alpha : Tensor
Tensor that stores the "learnable" weights.
softmax : nn.Module
Customizable softmax function. Usually ``nn.Softmax(-1)``.
label : str
Name of the choice.
Attributes
----------
op_names : list[str]
Operator names.
label : str
Name of the choice.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
paths: list[tuple[str, nn.Module]],
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.op_names = []
if len(alpha) != len(paths):
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({len(paths)}).')
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.label = label
self._arch_alpha = alpha
self._softmax = softmax
def resample(self, memo):
"""Do nothing. Differentiable layer doesn't need resample."""
return {}
def export(self, memo):
"""Choose the operator with the maximum logit."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
size = len(module)
if module.label in memo:
alpha = memo[module.label]
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # this can be reinitialized later
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, *args, **kwargs):
"""The forward of mixed layer accepts same arguments as its sub-layer."""
all_op_results = [getattr(self, op)(*args, **kwargs) for op in self.op_names]
return self.reduction(all_op_results, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
"""Named parameters excluding architecture parameters."""
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name == par_name for par_name in self._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
class DifferentiableMixedInput(BaseSuperNetModule):
"""
Mixed input. Forward returns a weighted sum of candidates.
Implementation is very similar to :class:`DifferentiableMixedLayer`.
Parameters
----------
n_candidates : int
Expected number of input candidates.
n_chosen : int or None
Expected number of inputs finally chosen.
alpha : Tensor
Tensor that stores the "learnable" weights.
softmax : nn.Module
Customizable softmax function. Usually ``nn.Softmax(-1)``.
label : str
Name of the choice.
Attributes
----------
label : str
Name of the choice.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
n_candidates: int,
n_chosen: int | None,
alpha: torch.Tensor,
softmax: nn.Module,
label: str):
super().__init__()
self.n_candidates = n_candidates
if len(alpha) != n_candidates:
raise ValueError(f'The size of alpha ({len(alpha)}) must match number of candidates ({n_candidates}).')
if n_chosen is None:
warnings.warn('Differentiable architecture search does not support choosing multiple inputs. Assuming one.',
RuntimeWarning)
n_chosen = 1  # fall back to one input; otherwise the assignment below would reset ``self.n_chosen`` to None
self.n_chosen = n_chosen
self.label = label
self._softmax = softmax
self._arch_alpha = alpha
def resample(self, memo):
"""Do nothing. Differentiable layer doesn't need resample."""
return {}
def export(self, memo):
"""Choose the operator with the top ``n_chosen`` logits."""
if self.label in memo:
return {} # nothing new to export
chosen = sorted(torch.argsort(-self._arch_alpha).cpu().numpy().tolist()[:self.n_chosen])
if len(chosen) == 1:
chosen = chosen[0]
return {self.label: chosen}
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, InputChoice):
if module.reduction not in ['sum', 'mean']:
raise ValueError('Only input choice of sum/mean reduction is supported.')
size = module.n_candidates
if module.label in memo:
alpha = memo[module.label]
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # this can be reinitialized later
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
def reduction(self, items: list[Any], weights: list[float]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, inputs):
"""Forward takes a list of input candidates."""
return self.reduction(inputs, self._softmax(self._arch_alpha))
def parameters(self, *args, **kwargs):
"""Parameters excluding architecture parameters."""
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
"""Named parameters excluding architecture parameters."""
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name == par_name for par_name in self._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
"""Implementes the differentiable sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Thus the ``_arch_alpha`` here is a parameter dict, and ``named_parameters``
filters out multiple parameters with ``_arch_alpha`` as its prefix.
When this class is asked for ``forward_argument``, it returns a distribution,
i.e., a dict from int to float based on its weights.
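For example (numbers made up), for a ``kernel_size`` choice over ``(3, 5, 7)``,
``forward_argument`` may return ``{3: 0.2, 5: 0.3, 7: 0.5}``.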
All the parameters (``_arch_alpha``, ``parameters()``, ``_softmax``) are
saved as attributes of ``operation``, rather than ``self``,
because this class itself is not a ``nn.Module``, and saved parameters here
won't be optimized.
"""
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampling arguments. This should have the same keys with `operation.mutable_arguments`
operation._arch_alpha = nn.ParameterDict()
for name, spec in operation.search_space_spec().items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
operation._arch_alpha[name] = alpha
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
operation.named_parameters = functools.partial(self.named_parameters, module=operation)
operation._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@staticmethod
def parameters(module, *args, **kwargs):
for _, p in module.named_parameters(*args, **kwargs):
yield p
@staticmethod
def named_parameters(module, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super(module.__class__, module).named_parameters(*args, **kwargs): # pylint: disable=bad-super-call
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Differentiable. Do nothing in resample."""
return {}
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is argmax for each leaf value choice."""
result = {}
for name, spec in operation.search_space_spec().items():
if name in memo:
continue
chosen_index = int(torch.argmax(cast(dict, operation._arch_alpha)[name]).item())
result[name] = spec.values[chosen_index]
return result
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
if name in operation.mutable_arguments:
weights: dict[str, torch.Tensor] = {
label: cast(nn.Module, operation._softmax)(alpha) for label, alpha in cast(dict, operation._arch_alpha).items()
}
return dict(traverse_all_options(operation.mutable_arguments[name], weights=weights))
return operation.init_arguments[name]
class DifferentiableMixedRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a differentiable supernet.
The result is a weighted sum of possible prefixes, sliced by possible depths.
If the output is not a single tensor, it will be summed at every independent dimension.
See :func:`weighted_sum` for details.
_arch_parameter_names: list[str] = ['_arch_alpha']
def __init__(self,
blocks: list[nn.Module],
depth: ChoiceOf[int],
softmax: nn.Module,
memo: dict[str, Any]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._softmax = softmax
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._arch_alpha = nn.ParameterDict()
for name, spec in self._space_spec.items():
if name in memo:
alpha = memo[name]
if len(alpha) != spec.size:
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
self._arch_alpha[name] = alpha
def resample(self, memo):
"""Do nothing."""
return {}
def export(self, memo):
"""Choose argmax for each leaf value choice."""
result = {}
for name, spec in self._space_spec.items():
if name in memo:
continue
chosen_index = int(torch.argmax(self._arch_alpha[name]).item())
result[name] = spec.values[chosen_index]
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(cast(List[nn.Module], module.blocks), module.depth_choice, softmax, memo)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
def reduction(self, items: list[Any], weights: list[float], depths: list[int]) -> Any:
"""Override this for customized reduction."""
# Use weighted_sum to handle complex cases where sequential output is not a single tensor
return weighted_sum(items, weights)
def forward(self, x):
weights: dict[str, torch.Tensor] = {
label: self._softmax(alpha) for label, alpha in self._arch_alpha.items()
}
depth_weights = dict(cast(List[Tuple[int, float]], traverse_all_options(self.depth, weights=weights)))
res: list[torch.Tensor] = []
weight_list: list[float] = []
depths: list[int] = []
for i, block in enumerate(self.blocks, start=1): # start=1 because depths are 1, 2, 3, 4...
x = block(x)
if i in depth_weights:
weight_list.append(depth_weights[i])
res.append(x)
depths.append(i)
return self.reduction(res, weight_list, depths)
class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context.
An architecture parameter is created on each edge of the full-connected graph.
"""
# TODO: It inherits :class:`PathSamplingCell` to reduce some duplicated code.
# Possibly need another refactor here.
def __init__(
self, op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor, concat_dim,
memo, mutate_kwargs, label
):
super().__init__(
op_factory, num_nodes, num_ops_per_node,
num_predecessors, preprocessor, postprocessor,
concat_dim, memo, mutate_kwargs, label
)
self._arch_alpha = nn.ParameterDict()
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for j in range(i):
edge_label = f'{label}/{i}_{j}'
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
def resample(self, memo):
"""Differentiable doesn't need to resample."""
return {}
def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
))
all_weights.sort(reverse=True)
# We first prefer inputs from different input_index.
# If we have got no other choices, we start to accept duplicates.
# Therefore we gather first occurrences of distinct input_index to the front.
first_occurrence_index: list[int] = [
all_weights.index( # The index of
next(filter(lambda t: t[1] == j, all_weights)) # First occurrence of j
)
for j in range(i) # For j < i
]
first_occurrence_index.sort() # Keep them ordered too.
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
_, j, op_name = all_weights[k % len(all_weights)]
exported[f'{self.label}/op_{i}_{k}'] = op_name
exported[f'{self.label}/input_{i}_{k}'] = j
return exported
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: list[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: list[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for j in range(i): # for every previous tensors
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
current_state.append(edge_sum)
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
def parameters(self, *args, **kwargs):
for _, p in self.named_parameters(*args, **kwargs):
yield p
def named_parameters(self, *args, **kwargs):
arch = kwargs.pop('arch', False)
for name, p in super().named_parameters(*args, **kwargs):
if any(name.startswith(par_name) for par_name in MixedOpDifferentiablePolicy._arch_parameter_names):
if arch:
yield name, p
else:
if not arch:
yield name, p
from nni.nas.oneshot.pytorch.supermodule.differentiable import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Operations that support weight sharing at a fine-grained level,
which is commonly known as super-kernel (as in channel search), or weight entanglement.
"""
# pylint: disable=wildcard-import,unused-wildcard-import
from __future__ import annotations
import inspect
import itertools
import warnings
from typing import Any, Type, TypeVar, cast, Union, Tuple, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
import nni.retiarii.nn.pytorch as retiarii_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from .base import BaseSuperNetModule
from ._valuechoice_utils import traverse_all_options, dedup_inner_choices, evaluate_constant
from ._operation_utils import Slicable as _S, MaybeWeighted as _W, int_or_int_dict, scalar_or_scalar_dict
T = TypeVar('T')
__all__ = [
'MixedOperationSamplingPolicy',
'MixedOperation',
'MixedLinear',
'MixedConv2d',
'MixedBatchNorm2d',
'MixedLayerNorm',
'MixedMultiHeadAttention',
'NATIVE_MIXED_OPERATIONS',
]
_diff_not_compatible_error = 'To be compatible with differentiable one-shot strategy, {} in {} must not be ValueChoice.'
class MixedOperationSamplingPolicy:
"""
Algo-related part for mixed Operation.
:class:`MixedOperation` delegates its resample and export to this policy (or its subclass),
so that one Operation can be easily combined with different kinds of sampling.
One sampling policy instance corresponds to one mixed operation.
"""
def __init__(self, operation: 'MixedOperation', memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
"""At init, the sampling policy can prepare basic parameters,
and store them in operation if they need back propagation.
This init is called in :meth:`BaseSuperNetModule.mutate`, after the mixed operation is created.
So similar to :meth:`BaseSuperNetModule.mutate`,
memo should also be managed (read and written) by the policy itself.
"""
pass
def resample(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
"""The handler of :meth:`MixedOperation.resample`."""
raise NotImplementedError()
def export(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
"""The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError()
def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
"""Computing the argument with ``name`` used in operation's forward.
Usually a value, or a distribution of value.
"""
raise NotImplementedError()
class MixedOperation(BaseSuperNetModule):
"""This is the base class for all mixed operations.
It's what you should inherit to support a new operation with ValueChoice.
It contains commonly used utilities that will ease the effort of writing customized mixed operations,
i.e., operations with ValueChoice in its arguments.
To customize, please write your own mixed operation, and add the hook into ``mutation_hooks`` parameter when using the strategy.
By design, for a mixed operation to work in a specific algorithm,
at least two classes are needed.
1. One class needs to inherit this class, to control operation-related behavior,
such as how to initialize the operation such that the sampled operation can be its sub-operation.
2. The other one needs to inherit :class:`MixedOperationSamplingPolicy`,
which controls algo-related behavior, such as sampling.
The two classes are linked with ``sampling_policy`` attribute in :class:`MixedOperation`,
whose type is set via ``mixed_op_sampling`` in ``mutate_kwargs`` when
:meth:`MixedOperation.mutate` is called.
With this design, one mixed-operation (e.g., MixedConv2d) can work in multiple algorithms
(e.g., both DARTS and ENAS), saving the engineering effort to rewrite all operations for
each specific algo.
A subclass should also define a ``bound_type``, to control the matching type in mutate,
and an ``argument_list``, to control which arguments can be dynamically used in ``forward``.
This list will also be used in mutate for sanity checks.
(See :class:`MixedLinear` below for a minimal concrete instance.)
"""
bound_type: Type[nn.Module] # defined in subclass
argument_list: list[str] # defined in subclass
sampling_policy: MixedOperationSamplingPolicy
def super_init_argument(self, name: str, value_choice: ValueChoiceX) -> Any:
"""Get the initialization argument when constructing super-kernel, i.e., calling ``super().__init__()``.
This is often related to specific operator, rather than algo.
For example::
def super_init_argument(self, name, value_choice):
return max(value_choice.candidates)
"""
raise NotImplementedError()
def __post_init__(self) -> None:
"""Can be used to validate, or to do extra processing after calling ``__init__``."""
pass
def forward_with_args(self, *args, **kwargs):
"""To control real fprop. The accepted arguments are ``argument_list``,
appended by forward arguments in the ``bound_type``."""
raise NotImplementedError()
def __init__(self, module_kwargs: dict[str, Any]) -> None:
# Concerned arguments
self.mutable_arguments: dict[str, ValueChoiceX] = {}
# Useful when retrieving arguments without ValueChoice
self.init_arguments: dict[str, Any] = {**module_kwargs}
self._fill_missing_init_arguments()
# get init default
super_init_kwargs = {}
for key, value in module_kwargs.items():
if isinstance(value, ValueChoiceX):
if key not in self.argument_list:
raise TypeError(f'Unsupported value choice on argument of {self.bound_type}: {key}')
super_init_kwargs[key] = self.super_init_argument(key, value)
self.mutable_arguments[key] = value
else:
super_init_kwargs[key] = value
# get all inner leaf value choices
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices(list(self.mutable_arguments.values()))
super().__init__(**super_init_kwargs)
self.__post_init__()
def resample(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
return self.sampling_policy.resample(self, memo)
def export(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
return self.sampling_policy.export(self, memo)
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
"""Find value choice in module's arguments and replace the whole module"""
has_valuechoice = False
if isinstance(module, cls.bound_type) and is_traceable(module):
for arg in itertools.chain(cast(list, module.trace_args), cast(dict, module.trace_kwargs).values()):
if isinstance(arg, ValueChoiceX):
has_valuechoice = True
if has_valuechoice:
if module.trace_args:
raise ValueError('ValueChoice on class arguments cannot appear together with ``trace_args``. '
'Please enable ``kw_only`` on nni.trace.')
# save type and kwargs
mixed_op = cls(cast(dict, module.trace_kwargs))
if 'mixed_op_sampling' not in mutate_kwargs:
raise ValueError("Need a sampling policy for mixed op, but it's not found in `mutate_kwargs`.")
policy_cls: Type[MixedOperationSamplingPolicy] = mutate_kwargs['mixed_op_sampling']
# initialize policy class
# this is put in mutate because we need to access memo
mixed_op.sampling_policy = policy_cls(mixed_op, memo, mutate_kwargs)
return mixed_op
def forward_argument(self, name: str) -> Any:
"""Get the argument used in forward.
This is often related to the algorithm. We redirect this to the sampling policy.
"""
return self.sampling_policy.forward_argument(self, name)
def forward(self, *args, **kwargs):
"""First get sampled arguments, then forward with the sampled arguments (by calling ``forward_with_args``)."""
sampled_args = [self.forward_argument(name) for name in self.argument_list]
return self.forward_with_args(*sampled_args, *args, **kwargs)
def _fill_missing_init_arguments(self) -> None:
"""Set the unspecified init arguments in ``self.init_arguments``.
For example, in the case of Conv2d, when user didn't specify argument ``stride``,
this method adds ``stride = 1`` in ``self.init_arguments``.
This is implemented by inspecting the init signature of ``bound_type``.
Arguments in complex cases like ``__new__`` or in a super-class are not supported.
"""
def unwrap(cls):
if not hasattr(cls, '__wrapped__'):
return cls
return unwrap(cls.__wrapped__)
for param in inspect.signature(unwrap(self.bound_type).__init__).parameters.values():
if param.default is not param.empty and param.name not in self.init_arguments:
self.init_arguments[param.name] = param.default
class MixedLinear(MixedOperation, nn.Linear):
"""Mixed linear operation.
Supported arguments are:
- ``in_features``
- ``out_features``
Prefix of weight and bias will be sliced.
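For example (a sketch), if ``out_features`` is a choice over ``(10, 20)``, the super-kernel
weight is created with shape ``(20, in_features)``, and sampling ``out_features=10``
uses ``weight[:10]``.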
"""
bound_type = retiarii_nn.Linear
argument_list = ['in_features', 'out_features']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self,
in_features: int_or_int_dict,
out_features: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor:
in_features_ = _W(in_features)
out_features_ = _W(out_features)
weight = _S(self.weight)[:out_features_]
weight = _S(weight)[:, :in_features_]
if self.bias is None:
bias = self.bias
else:
bias = _S(self.bias)[:out_features_]
return F.linear(inputs, weight, bias)
_int_or_tuple = Union[int, Tuple[int, int]]
class MixedConv2d(MixedOperation, nn.Conv2d):
"""Mixed conv2d op.
Supported arguments are:
- ``in_channels``
- ``out_channels``
- ``groups``
- ``stride`` (only supported in path sampling)
- ``kernel_size``
- ``padding``
- ``dilation`` (only supported in path sampling)
``padding`` will be the "max" padding in differentiable mode.
Mutable ``groups`` is NOT supported in most cases of differentiable mode.
However, we do support one special case when the group number is proportional to ``in_channels`` and ``out_channels``.
This is often the case of depth-wise convolutions.
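For example (a sketch), ``in_channels = out_channels = groups = ValueChoice([4, 8], label='c')``
keeps the channels-per-group ratio constant (a depth-wise convolution), which is supported.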
For channels, prefix will be sliced.
For kernels, we take the small kernel from the center and round it to floor (left top). For example ::
max_kernel = 5*5, sampled_kernel = 3*3, then we take [1: 4]
max_kernel = 5*5, sampled_kernel = 2*2, then we take [1: 3]
□ □ □ □ □ □ □ □ □ □
□ ■ ■ ■ □ □ ■ ■ □ □
□ ■ ■ ■ □ □ ■ ■ □ □
□ ■ ■ ■ □ □ □ □ □ □
□ □ □ □ □ □ □ □ □ □
"""
bound_type = retiarii_nn.Conv2d
argument_list = [
'in_channels', 'out_channels', 'kernel_size', 'stride', 'padding', 'dilation', 'groups'
]
@staticmethod
def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
if not isinstance(value, tuple):
return (value, value)
return value
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
if name not in ['in_channels', 'out_channels', 'groups', 'stride', 'kernel_size', 'padding', 'dilation']:
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
if name in ['kernel_size', 'padding']:  # note ``in``, not ``==``: the equality comparison was always False
all_sizes = set(traverse_all_options(value_choice))
if any(isinstance(sz, tuple) for sz in all_sizes):
# maximum kernel should be calculated on every dimension
return (
max(self._to_tuple(sz)[0] for sz in all_sizes),
max(self._to_tuple(sz)[1] for sz in all_sizes)
)
else:
return max(all_sizes)
elif name == 'groups':
if 'in_channels' in self.mutable_arguments:
# If the ratio is constant, we don't need to try the maximum groups.
try:
constant = evaluate_constant(self.mutable_arguments['in_channels'] / value_choice)
return max(cast(List[float], traverse_all_options(value_choice))) // int(constant)
except ValueError:
warnings.warn(
'Both input channels and groups are ValueChoice in a convolution, and their relative ratio is not a constant. '
'This can be problematic for most one-shot algorithms. Please check whether this is your intention.',
RuntimeWarning
)
# fall back to minimum groups, which yields the maximum weight (kernel) size
return min(traverse_all_options(value_choice))
else:
return max(traverse_all_options(value_choice))
def forward_with_args(self,
in_channels: int_or_int_dict,
out_channels: int_or_int_dict,
kernel_size: scalar_or_scalar_dict[_int_or_tuple],
stride: _int_or_tuple,
padding: scalar_or_scalar_dict[_int_or_tuple],
dilation: int,
groups: int_or_int_dict,
inputs: torch.Tensor) -> torch.Tensor:
if any(isinstance(arg, dict) for arg in [stride, dilation]):
raise ValueError(_diff_not_compatible_error.format('stride, dilation', 'Conv2d'))
in_channels_ = _W(in_channels)
out_channels_ = _W(out_channels)
# slice prefix
# For groups > 1, we use groups to slice input weights
weight = _S(self.weight)[:out_channels_]
if not isinstance(groups, dict):
weight = _S(weight)[:, :in_channels_ // groups]
else:
assert 'groups' in self.mutable_arguments
err_message = 'For differentiable one-shot strategy, when groups is a ValueChoice, ' \
'in_channels and out_channels should also be ValueChoices. ' \
'Also, the ratios of in_channels divided by groups, and of out_channels divided by groups, ' \
'should be constants.'
if 'in_channels' not in self.mutable_arguments or 'out_channels' not in self.mutable_arguments:
raise ValueError(err_message)
try:
in_channels_per_group = evaluate_constant(self.mutable_arguments['in_channels'] / self.mutable_arguments['groups'])
except ValueError:
raise ValueError(err_message)
if in_channels_per_group != int(in_channels_per_group):
raise ValueError(f'Input channels per group is found to be a non-integer: {in_channels_per_group}')
if inputs.size(1) % in_channels_per_group != 0:
raise RuntimeError(
f'Input channels must be divisible by in_channels_per_group, but the input shape is {inputs.size()}, '
f'while in_channels_per_group = {in_channels_per_group}'
)
# Compute sliced weights and groups (as an integer)
weight = _S(weight)[:, :int(in_channels_per_group)]
groups = inputs.size(1) // int(in_channels_per_group)
# slice center
if isinstance(kernel_size, dict):
# If kernel size is a dict, ignore choices in padding.
if isinstance(self.padding, str):
raise ValueError(f'Using "{self.padding}" as padding is not supported.')
padding = self.padding # max padding, must be a tuple
kernel_a, kernel_b = self._to_tuple(kernel_size)
kernel_a_, kernel_b_ = _W(kernel_a), _W(kernel_b)
max_kernel_a, max_kernel_b = self.kernel_size # self.kernel_size must be a tuple
kernel_a_left, kernel_b_top = (max_kernel_a - kernel_a_) // 2, (max_kernel_b - kernel_b_) // 2
weight = _S(weight)[:, :, kernel_a_left:kernel_a_left + kernel_a_, kernel_b_top:kernel_b_top + kernel_b_]
bias = _S(self.bias)[:out_channels_] if self.bias is not None else None
# The rest parameters only need to be converted to tuple
stride_ = self._to_tuple(stride)
dilation_ = self._to_tuple(dilation)
if self.padding_mode != 'zeros':
return F.conv2d(F.pad(inputs, self._reversed_padding_repeated_twice, mode=self.padding_mode),
weight, bias, stride_, (0, 0), dilation_, groups)
return F.conv2d(inputs, weight, bias, stride_, cast('int | tuple', padding), dilation_, groups)
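# Illustration only: center-slicing a sampled 3x3 kernel out of a stored 5x5 kernel,
# as forward_with_args does above; the start index is floored, biasing towards the
# top-left when the two centers do not align.
def _center_kernel_sketch():
    import torch
    max_k, k = 5, 3
    weight = torch.randn(4, 2, max_k, max_k)    # (out, in, kH, kW) at maximum size
    start = (max_k - k) // 2                    # = 1, so rows/cols [1:4] are kept
    return weight[:, :, start:start + k, start:start + k]  # -> shape (4, 2, 3, 3)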
class MixedBatchNorm2d(MixedOperation, nn.BatchNorm2d):
"""
Mixed BatchNorm2d operation.
Supported arguments are:
- ``num_features``
- ``eps`` (only supported in path sampling)
- ``momentum`` (only supported in path sampling)
For path-sampling, prefixes of ``weight``, ``bias``, ``running_mean`` and ``running_var``
are sliced. For weighted cases, the maximum ``num_features`` is used directly.
``momentum`` is required to be a float.
PyTorch BatchNorm allows ``momentum`` to be ``None``, which is not supported here.
"""
bound_type = retiarii_nn.BatchNorm2d
argument_list = ['num_features', 'eps', 'momentum']
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def forward_with_args(self,
num_features: int_or_int_dict,
eps: float,
momentum: float,
inputs: torch.Tensor) -> torch.Tensor:
if any(isinstance(arg, dict) for arg in [eps, momentum]):
raise ValueError(_diff_not_compatible_error.format('eps and momentum', 'BatchNorm2d'))
if isinstance(num_features, dict):
num_features = self.num_features
weight, bias = self.weight, self.bias
running_mean, running_var = self.running_mean, self.running_var
if num_features < self.num_features:
weight = weight[:num_features]
bias = bias[:num_features]
if running_mean is not None:
running_mean = running_mean[:num_features]
if running_var is not None:
running_var = running_var[:num_features]
if self.training:
bn_training = True
else:
bn_training = (running_mean is None) and (running_var is None)
return F.batch_norm(
inputs,
# If buffers are not to be tracked, ensure that they won't be updated
running_mean if not self.training or self.track_running_stats else None,
running_var if not self.training or self.track_running_stats else None,
weight,
bias,
bn_training,
momentum, # originally exponential_average_factor in pytorch code
eps,
)
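# Illustration only: the prefix slicing of BatchNorm parameters and buffers described
# above, written with plain indexing since the sampled num_features is a concrete int.
def _bn_slice_sketch():
    import torch
    import torch.nn.functional as F
    bn = torch.nn.BatchNorm2d(8)                 # maximum num_features
    nf = 5                                       # hypothetical sampled num_features
    x = torch.randn(2, nf, 4, 4)
    return F.batch_norm(x, bn.running_mean[:nf], bn.running_var[:nf],
                        bn.weight[:nf], bn.bias[:nf], training=False,
                        momentum=0.1, eps=bn.eps)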
class MixedLayerNorm(MixedOperation, nn.LayerNorm):
"""
Mixed LayerNorm operation.
Supported arguments are:
- ``normalized_shape``
- ``eps`` (only supported in path sampling)
For path-sampling, prefixes of ``weight`` and ``bias`` are sliced.
For weighted cases, the maximum ``normalized_shape`` is used directly.
``eps`` is required to be a float.
"""
bound_type = retiarii_nn.LayerNorm
argument_list = ['normalized_shape', 'eps']
@staticmethod
def _to_tuple(value: scalar_or_scalar_dict[Any]) -> tuple[Any, Any]:
if not isinstance(value, tuple):
return (value, value)
return value
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
if name not in ['normalized_shape']:
raise NotImplementedError(f'Unsupported value choice on argument: {name}')
all_sizes = set(traverse_all_options(value_choice))
if any(isinstance(sz, (tuple, list)) for sz in all_sizes):
# transpose
all_sizes = list(zip(*all_sizes))
# maximum dim should be calculated on every dimension
return tuple(max(self._to_tuple(sz)) for sz in all_sizes)
else:
return max(all_sizes)
def forward_with_args(self,
normalized_shape,
eps: float,
inputs: torch.Tensor) -> torch.Tensor:
if any(isinstance(arg, dict) for arg in [eps]):
raise ValueError(_diff_not_compatible_error.format('eps', 'LayerNorm'))
if isinstance(normalized_shape, dict):
normalized_shape = self.normalized_shape
# normalize both shapes to tuples
if isinstance(normalized_shape, int):
normalized_shape = (normalized_shape, )
self_normalized_shape = self.normalized_shape
if isinstance(self_normalized_shape, int):
self_normalized_shape = (self_normalized_shape, )
# slice every normalized dimension down to the smaller of the sampled and maximum size
indices = [slice(0, min(i, j)) for i, j in zip(normalized_shape, self_normalized_shape)]
# plain slicing suffices here (the indices are concrete), so the _S wrapper is not needed
weight = self.weight[indices] if self.weight is not None else None
bias = self.bias[indices] if self.bias is not None else None
return F.layer_norm(
inputs,
normalized_shape,
weight,
bias,
eps
)
class MixedMultiHeadAttention(MixedOperation, nn.MultiheadAttention):
"""
Mixed multi-head attention.
Supported arguments are:
- ``embed_dim``
- ``num_heads`` (only supported in path sampling)
- ``kdim``
- ``vdim``
- ``dropout`` (only supported in path sampling)
At init, it constructs the largest possible Q, K, V dimension.
At forward, it slices the prefix to weight matrices according to the sampled value.
For ``in_proj_bias`` and ``in_proj_weight``, three parts will be sliced and concatenated together:
``[0, embed_dim)``, ``[max_embed_dim, max_embed_dim + embed_dim)``,
``[max_embed_dim * 2, max_embed_dim * 2 + embed_dim)``.
Warnings
--------
All candidates of ``embed_dim`` should be divisible by all candidates of ``num_heads``.
"""
bound_type = retiarii_nn.MultiheadAttention
argument_list = ['embed_dim', 'num_heads', 'kdim', 'vdim', 'dropout']
def __post_init__(self):
# sometimes the super-class believes q, k and v have the same embed_dim,
# but actually they do not, because we can have dynamic (mutable) kdim/vdim.
_qkv_same_embed_dim = True
for dimension in ['kdim', 'vdim']:
if self.init_arguments[dimension] is None:
# must follow embed_dim in this case
continue
if getattr(self, dimension) == self.embed_dim and \
(dimension in self.mutable_arguments or 'embed_dim' in self.mutable_arguments):
_qkv_same_embed_dim = False
if self._qkv_same_embed_dim and not _qkv_same_embed_dim:
self._qkv_same_embed_dim = _qkv_same_embed_dim
# adding back missing parameters
# factory_kwargs could be empty for legacy pytorch versions
factory_kwargs = {}
if 'device' in self.init_arguments:
factory_kwargs['device'] = self.init_arguments['device']
if 'dtype' in self.init_arguments:
factory_kwargs['dtype'] = self.init_arguments['dtype']
self.q_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.embed_dim), **factory_kwargs))
self.k_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.kdim), **factory_kwargs))
self.v_proj_weight = nn.Parameter(torch.empty((self.embed_dim, self.vdim), **factory_kwargs))
self.register_parameter('in_proj_weight', None)
# reset parameters
nn.init.xavier_uniform_(self.q_proj_weight)
nn.init.xavier_uniform_(self.k_proj_weight)
nn.init.xavier_uniform_(self.v_proj_weight)
def super_init_argument(self, name: str, value_choice: ValueChoiceX):
return max(traverse_all_options(value_choice))
def _to_proj_slice(self, embed_dim: _W) -> list[slice]:
# slice three parts, corresponding to q, k, v respectively
return [
slice(embed_dim),
slice(self.embed_dim, self.embed_dim + embed_dim),
slice(self.embed_dim * 2, self.embed_dim * 2 + embed_dim)
]
def forward_with_args(
self,
embed_dim: int_or_int_dict, num_heads: int,
kdim: int_or_int_dict | None, vdim: int_or_int_dict | None,
dropout: float,
query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
key_padding_mask: torch.Tensor | None = None,
need_weights: bool = True, attn_mask: torch.Tensor | None = None
) -> tuple[torch.Tensor, torch.Tensor | None]:
if any(isinstance(arg, dict) for arg in [num_heads, dropout]):
raise ValueError(_diff_not_compatible_error.format('num_heads and dropout', 'MultiHeadAttention'))
# by default, kdim and vdim can be None, in which case they fall back to embed_dim
if kdim is None:
kdim = embed_dim
if vdim is None:
vdim = embed_dim
qkv_same_embed_dim = kdim == embed_dim and vdim == embed_dim
if getattr(self, 'batch_first', False):
# for backward compatibility: v1.7 doesn't have batch_first
query, key, value = [x.transpose(1, 0) for x in (query, key, value)]
if isinstance(embed_dim, dict):
used_embed_dim = self.embed_dim
else:
used_embed_dim = embed_dim
embed_dim_ = _W(embed_dim)
# in projection weights & biases has q, k, v weights concatenated together
in_proj_bias: Tensor | None = None
in_proj_weight: Tensor | None = None
if self.in_proj_bias is not None:
in_proj_bias = _S(cast(Tensor, self.in_proj_bias))[self._to_proj_slice(embed_dim_)]
if self.in_proj_weight is not None:
in_proj_weight = _S(cast(Tensor, self.in_proj_weight))[self._to_proj_slice(embed_dim_), :embed_dim_]
bias_k = _S(cast(Tensor, self.bias_k))[:, :, :embed_dim_] if self.bias_k is not None else None
bias_v = _S(cast(Tensor, self.bias_v))[:, :, :embed_dim_] if self.bias_v is not None else None
out_proj_weight = _S(cast(Tensor, self.out_proj.weight))[:embed_dim_, :embed_dim_]
out_proj_bias = _S(cast(Tensor, self.out_proj.bias))[:embed_dim_] if self.out_proj.bias is not None else None
if not qkv_same_embed_dim:
q_proj = _S(cast(Tensor, self.q_proj_weight))[:embed_dim_, :embed_dim_]
k_proj = _S(cast(Tensor, self.k_proj_weight))[:embed_dim_]
k_proj = _S(k_proj)[:, :_W(kdim)]
v_proj = _S(cast(Tensor, self.v_proj_weight))[:embed_dim_]
v_proj = _S(v_proj)[:, :_W(vdim)]
# The rest is basically the same as pytorch
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads,
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask, use_separate_proj_weight=True,
q_proj_weight=q_proj, k_proj_weight=k_proj, v_proj_weight=v_proj)
else:
# Cast tensor here because of a bug in pytorch stub
attn_output, attn_output_weights = F.multi_head_attention_forward(
query, key, value, used_embed_dim, num_heads,
cast(Tensor, in_proj_weight), cast(Tensor, in_proj_bias),
bias_k, bias_v, self.add_zero_attn,
dropout, out_proj_weight, cast(Tensor, out_proj_bias),
training=self.training,
key_padding_mask=key_padding_mask, need_weights=need_weights,
attn_mask=attn_mask)
if getattr(self, 'batch_first', False): # backward compatibility
return attn_output.transpose(1, 0), attn_output_weights
else:
return attn_output, attn_output_weights
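# For intuition: with a maximum embed_dim of 8 and a sampled embed_dim of 4,
# _to_proj_slice keeps rows [0:4), [8:12) and [16:20) of in_proj_weight,
# i.e. the prefix of each of the q, k and v blocks respectively.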
NATIVE_MIXED_OPERATIONS: list[Type[MixedOperation]] = [
MixedLinear,
MixedConv2d,
MixedBatchNorm2d,
MixedLayerNorm,
MixedMultiHeadAttention,
]
# For the supported operations to be properly rendered in documentation
NATIVE_SUPPORTED_OP_NAMES: list[str] = [op.bound_type.__name__ for op in NATIVE_MIXED_OPERATIONS]
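# Under the registrations above, NATIVE_SUPPORTED_OP_NAMES should resolve to
# ['Linear', 'Conv2d', 'BatchNorm2d', 'LayerNorm', 'MultiheadAttention'],
# since ``bound_type.__name__`` preserves the wrapped class names.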
from nni.nas.oneshot.pytorch.supermodule.operation import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""Implementation of ProxylessNAS: a hyrbid approach between differentiable and sampling.
The support remains limited. Known limitations include:
# pylint: disable=wildcard-import,unused-wildcard-import
- No support for multiple arguments in forward.
- No support for mixed-operation (value choice).
- The code contains duplicates. Needs refactor.
"""
from __future__ import annotations
from typing import cast
import torch
import torch.nn as nn
from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
class _ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class ProxylessMixedLayer(DifferentiableMixedLayer):
"""Proxyless version of differentiable mixed layer.
It resamples a single path every time, rather than going through the softmax.
"""
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(paths, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
# like sampling-based methods, it has a ``_sampled``.
self._sampled: str | None = None
self._sample_idx: int | None = None
def forward(self, *args, **kwargs):
def run_function(ops, active_id, **kwargs):
def forward(_x):
return ops[active_id](_x, **kwargs)
return forward
def backward_function(ops, active_id, binary_gates, **kwargs):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(ops)):
if k != active_id:
out_k = ops[k](_x.data, **kwargs)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
x = args[0]
assert self._sampled is not None, 'Need to call resample() before running fprop.'
list_ops = [getattr(self, op) for op in self.op_names]
return _ArchGradientFunction.apply(
x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
)
def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo."""
if self.label in memo:
self._sampled = memo[self.label]
self._sample_idx = self.op_names.index(self._sampled)
else:
probs = self._softmax(self._arch_alpha)
self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
self._sampled = self.op_names[self._sample_idx]
# set binary gates
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[self._sample_idx] = 1.0
return {self.label: self._sampled}
def export(self, memo):
"""Chose the argmax if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
probs = self._softmax(self._arch_alpha)
for i in range(len(self._arch_alpha)):
for j in range(len(self._arch_alpha)):
self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
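# The double loop above implements the ProxylessNAS estimator
#     dL/dalpha_i = sum_j g_j * p_j * (delta_ij - p_i),   with p = softmax(alpha),
# where g_j are the binary-gate gradients. A vectorised sketch of the same math
# (illustration only, not used by the class):
def _alpha_grad_sketch(binary_grads, probs):
    # the sum simplifies to p_i * (g_i - <g, p>)
    return probs * (binary_grads - torch.dot(binary_grads, probs))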
class ProxylessMixedInput(DifferentiableMixedInput):
"""Proxyless version of differentiable input choice.
See :class:`ProxylessMixedLayer` for implementation details.
"""
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(n_candidates, n_chosen, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
self._sampled: int | None = None
def forward(self, inputs):
def run_function(active_sample):
return lambda x: x[active_sample]
def backward_function(binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(self.n_candidates):
out_k = _x[k].data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
inputs = torch.stack(inputs, 0)
assert self._sampled is not None, 'Need to call resample() before running fprop.'
return _ArchGradientFunction.apply(
inputs, self._binary_gates, run_function(self._sampled),
backward_function(self._binary_gates)
)
def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo."""
if self.label in memo:
self._sampled = memo[self.label]
else:
probs = self._softmax(self._arch_alpha)
sample = torch.multinomial(probs, 1)[0].item()
self._sampled = int(sample)
# set binary gates
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[cast(int, self._sampled)] = 1.0
return {self.label: self._sampled}
def export(self, memo):
"""Chose the argmax if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: torch.argmax(self._arch_alpha).item()}
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
probs = self._softmax(self._arch_alpha)
for i in range(self.n_candidates):
for j in range(self.n_candidates):
self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
from nni.nas.oneshot.pytorch.supermodule.proxyless import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
import copy
import random
from typing import Any, List, Dict, Sequence, cast
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, Repeat, ChoiceOf, Cell
from nni.retiarii.nn.pytorch.api import ValueChoiceX
from nni.retiarii.nn.pytorch.cell import CellOpFactory, create_cell_op_candidates, preprocess_cell_inputs
from .base import BaseSuperNetModule
from ._valuechoice_utils import evaluate_value_choice_with_dict, dedup_inner_choices, weighted_sum
from .operation import MixedOperationSamplingPolicy, MixedOperation
__all__ = [
'PathSamplingLayer', 'PathSamplingInput',
'PathSamplingRepeat', 'PathSamplingCell',
'MixedOpPathSamplingPolicy'
]
class PathSamplingLayer(BaseSuperNetModule):
"""
Mixed layer, in which fprop is decided by exactly one inner layer or sum of multiple (sampled) layers.
If multiple modules are selected, the result will be summed and returned.
Attributes
----------
_sampled : int or list of str
Sampled module indices.
label : str
Name of the choice.
"""
def __init__(self, paths: list[tuple[str, nn.Module]], label: str):
super().__init__()
self.op_names = []
for name, module in paths:
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self._sampled: list[str] | str | None = None # sampled can be either a list of indices or an index
self.label = label
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = random.choice(self.op_names)
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: random.choice(self.op_names)}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, LayerChoice):
return cls(list(module.named_children()), module.label)
def reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def forward(self, *args, **kwargs):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
# str(samp) is needed here because samp can sometimes be an integer, while attribute names are always str
res = [getattr(self, str(samp))(*args, **kwargs) for samp in sampled]
return self.reduction(res, sampled)
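# Usage sketch (hypothetical modules and label): choices sharing a label stay in sync
# because the same memo dict is threaded through resample().
#   layer = PathSamplingLayer([('conv', nn.Conv2d(3, 3, 3)), ('skip', nn.Identity())], 'op1')
#   memo = layer.resample({})          # e.g. {'op1': 'skip'}
#   out = layer(torch.randn(1, 3, 8, 8))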
class PathSamplingInput(BaseSuperNetModule):
"""
Mixed input. Take a list of tensor as input, select some of them and return the sum.
Attributes
----------
_sampled : int or list of int
Sampled input indices.
"""
def __init__(self, n_candidates: int, n_chosen: int, reduction_type: str, label: str):
super().__init__()
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction_type = reduction_type
self._sampled: list[int] | int | None = None
self.label = label
def _random_choose_n(self):
sampling = list(range(self.n_candidates))
random.shuffle(sampling)
sampling = sorted(sampling[:self.n_chosen])
if len(sampling) == 1:
return sampling[0]
else:
return sampling
def resample(self, memo):
"""Random choose one path / multiple paths if label is not found in memo.
If one path is selected, only one integer will be in ``self._sampled``.
If multiple paths are selected, a list will be in ``self._sampled``.
"""
if self.label in memo:
self._sampled = memo[self.label]
else:
self._sampled = self._random_choose_n()
return {self.label: self._sampled}
def export(self, memo):
"""Random choose one name if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self._random_choose_n()}
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
(self.label, ), True, size=self.n_candidates, chosen_size=self.n_chosen)
}
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, InputChoice):
if module.reduction not in ['sum', 'mean', 'concat']:
raise ValueError('Only input choice of sum/mean/concat reduction is supported.')
if module.n_chosen is None:
raise ValueError('n_chosen is None is not supported yet.')
return cls(module.n_candidates, module.n_chosen, module.reduction, module.label)
def reduction(self, items: list[Any], sampled: list[Any]) -> Any:
"""Override this to implement customized reduction."""
if len(items) == 1:
return items[0]
else:
if self.reduction_type == 'sum':
return sum(items)
elif self.reduction_type == 'mean':
return sum(items) / len(items)
elif self.reduction_type == 'concat':
return torch.cat(items, 1)
raise ValueError(f'Unsupported reduction type: {self.reduction_type}')
def forward(self, input_tensors):
if self._sampled is None:
raise RuntimeError('At least one path needs to be sampled before fprop.')
if len(input_tensors) != self.n_candidates:
raise ValueError(f'Expect {self.n_candidates} input tensors, found {len(input_tensors)}.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = [input_tensors[samp] for samp in sampled]
return self.reduction(res, sampled)
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
"""Implementes the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices".
We sample the leaf nodes, and composits them into the values on arguments.
"""
def __init__(self, operation: MixedOperation, memo: dict[str, Any], mutate_kwargs: dict[str, Any]) -> None:
# Sampled arguments. This should have the same keys as `operation.mutable_arguments`
self._sampled: dict[str, Any] | None = None
def resample(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Random sample for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(space_spec[label].values)
# compose the sampled leaves into argument values (kwargs)
# example: result = {"exp_ratio": 3}, self._sampled = {"in_channels": 48, "out_channels": 96}
self._sampled = {}
for key, value in operation.mutable_arguments.items():
self._sampled[key] = evaluate_value_choice_with_dict(value, result)
return result
def export(self, operation: MixedOperation, memo: dict[str, Any]) -> dict[str, Any]:
"""Export is also random for each leaf value choice."""
result = {}
space_spec = operation.search_space_spec()
for label in space_spec:
if label not in memo:
result[label] = random.choice(space_spec[label].values)
return result
def forward_argument(self, operation: MixedOperation, name: str) -> Any:
# NOTE: we don't support sampling a list here.
if self._sampled is None:
raise ValueError('Need to call resample() before running forward')
if name in operation.mutable_arguments:
return self._sampled[name]
return operation.init_arguments[name]
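# Sketch of the leaf-to-argument composition performed in resample() above: if the
# mutable argument were, say, in_channels = exp_ratio * 16 (a hypothetical value
# choice), a leaf sample {'exp_ratio': 3} evaluates to the argument value 48.
def _compose_sketch(leaf_samples):
    # stand-in for evaluate_value_choice_with_dict on the expression ``exp_ratio * 16``
    return {'in_channels': leaf_samples['exp_ratio'] * 16}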
class PathSamplingRepeat(BaseSuperNetModule):
"""
Implementation of Repeat in a path-sampling supernet.
Samples one / some of the prefixes of the repeated blocks.
Attributes
----------
_sampled : int or list of int
Sampled depth.
"""
def __init__(self, blocks: list[nn.Module], depth: ChoiceOf[int]):
super().__init__()
self.blocks = blocks
self.depth = depth
self._space_spec: dict[str, ParameterSpec] = dedup_inner_choices([depth])
self._sampled: list[int] | int | None = None
def resample(self, memo):
"""Since depth is based on ValueChoice, we only need to randomly sample every leaf value choices."""
result = {}
for label in self._space_spec:
if label in memo:
result[label] = memo[label]
else:
result[label] = random.choice(self._space_spec[label].values)
self._sampled = evaluate_value_choice_with_dict(self.depth, result)
return result
def export(self, memo):
"""Random choose one if every choice not in memo."""
result = {}
for label in self._space_spec:
if label not in memo:
result[label] = random.choice(self._space_spec[label].values)
return result
def search_space_spec(self):
return self._space_spec
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Repeat) and isinstance(module.depth_choice, ValueChoiceX):
# Only interesting when depth is mutable
return cls(cast(List[nn.Module], module.blocks), module.depth_choice)
def reduction(self, items: list[Any], sampled: list[Any]):
"""Override this to implement customized reduction."""
return weighted_sum(items)
def forward(self, x):
if self._sampled is None:
raise RuntimeError('At least one depth needs to be sampled before fprop.')
sampled = [self._sampled] if not isinstance(self._sampled, list) else self._sampled
res = []
for cur_depth, block in enumerate(self.blocks, start=1):
x = block(x)
if cur_depth in sampled:
res.append(x)
if not any(d > cur_depth for d in sampled):
break
return self.reduction(res, sampled)
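# Example of the early-exit loop above: with blocks [b1, b2, b3] and a sampled depth
# of 2, forward computes b2(b1(x)) and stops; with multiple sampled depths [1, 3] it
# would collect the depth-1 and depth-3 states and reduce them by weighted_sum.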
class PathSamplingCell(BaseSuperNetModule):
"""The implementation of super-net cell follows `DARTS <https://github.com/quark0/darts>`__.
When ``factory_used`` is true, it reconstructs the cell for every possible combination of operation and input index,
because for different input index, the cell factory could instantiate different operations (e.g., with different stride).
On export, we first find the best (operation, input) pairs, then select the best ``num_ops_per_node`` of them.
``loose_end`` is not supported yet, because it will cause more problems (e.g., shape mismatch).
We assume ``loose_end`` to be ``all`` regardless of its configuration.
A supernet cell can't slim its own weights to fit into a sub-network, which is also a known issue.
"""
def __init__(
self,
op_factory: list[CellOpFactory] | dict[str, CellOpFactory],
num_nodes: int,
num_ops_per_node: int,
num_predecessors: int,
preprocessor: Any,
postprocessor: Any,
concat_dim: int,
memo: dict, # although not used here, useful in subclass
mutate_kwargs: dict, # same as memo
label: str,
):
super().__init__()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
self.preprocessor = preprocessor
self.ops = nn.ModuleList()
self.postprocessor = postprocessor
self.concat_dim = concat_dim
self.op_names: list[str] = cast(List[str], None)
self.output_node_indices = list(range(self.num_predecessors, self.num_nodes + self.num_predecessors))
# Create a fully-connected graph.
# Each edge is a ModuleDict with op candidates.
# Can not reuse LayerChoice here, because the spec, resample, export all need to be customized.
# InputChoice is implicit in this graph.
for i in self.output_node_indices:
self.ops.append(nn.ModuleList())
for k in range(i + self.num_predecessors):
# The second argument in (i, 0, k) is always 0.
# One-shot strategy can't handle the cases where op spec is dependent on `op_index`.
ops, _ = create_cell_op_candidates(op_factory, i, 0, k)
self.op_names = list(ops.keys())
cast(nn.ModuleList, self.ops[-1]).append(nn.ModuleDict(ops))
self.label = label
self._sampled: dict[str, str | int] = {}
def search_space_spec(self) -> dict[str, ParameterSpec]:
# TODO: Recreating the space here.
# The spec should be moved to definition of Cell itself.
space_spec = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
for k in range(self.num_ops_per_node):
op_label = f'{self.label}/op_{i}_{k}'
input_label = f'{self.label}/input_{i}_{k}'
space_spec[op_label] = ParameterSpec(op_label, 'choice', self.op_names, (op_label,), True, size=len(self.op_names))
space_spec[input_label] = ParameterSpec(input_label, 'choice', list(range(i)), (input_label, ), True, size=i)
return space_spec
def resample(self, memo):
"""Random choose one path if label is not found in memo."""
self._sampled = {}
new_sampled = {}
for label, param_spec in self.search_space_spec().items():
if label in memo:
assert not isinstance(memo[label], list), 'Multi-path sampling is currently unsupported on cell.'
self._sampled[label] = memo[label]
else:
self._sampled[label] = new_sampled[label] = random.choice(param_spec.values)
return new_sampled
def export(self, memo):
"""Randomly choose one to export."""
return self.resample(memo)
def forward(self, *inputs: list[torch.Tensor] | torch.Tensor) -> tuple[torch.Tensor, ...] | torch.Tensor:
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for i, ops in enumerate(cast(Sequence[Sequence[Dict[str, nn.Module]]], self.ops), start=self.num_predecessors):
current_state = []
for k in range(self.num_ops_per_node):
# Select op list based on the input chosen
input_index = self._sampled[f'{self.label}/input_{i}_{k}']
op_candidates = ops[cast(int, input_index)]
# Select op from op list based on the op chosen
op_index = self._sampled[f'{self.label}/op_{i}_{k}']
op = op_candidates[cast(str, op_index)]
current_state.append(op(states[cast(int, input_index)]))
states.append(sum(current_state)) # type: ignore
# Always merge all
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
op_factory = module.op_candidates_factory
assert isinstance(op_factory, list) or isinstance(op_factory, dict), \
'Only support op_factory of type list or dict.'
elif module.merge_op == 'loose_end':
op_candidates_lc = module.ops[-1][-1] # type: ignore
assert isinstance(op_candidates_lc, LayerChoice)
op_factory = {  # create a factory; bind ``name`` at definition time to avoid the late-binding closure pitfall
name: lambda _, __, ___, _name=name: copy.deepcopy(op_candidates_lc[_name])
for name in op_candidates_lc.names
}
if op_factory is not None:
return cls(
op_factory,
module.num_nodes,
module.num_ops_per_node,
module.num_predecessors,
module.preprocessor,
module.postprocessor,
module.concat_dim,
memo,
mutate_kwargs,
module.label
)
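# The sampled dict for a cell uses per-edge labels of the form f'{label}/op_{i}_{k}'
# and f'{label}/input_{i}_{k}'. For a cell labelled 'cell', node 2, op slot 0, a sample
# such as {'cell/op_2_0': 'sep_conv_3x3', 'cell/input_2_0': 1} (hypothetical op name)
# applies that op to state 1 on the corresponding edge.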
from nni.nas.oneshot.pytorch.supermodule.sampling import *
@@ -12,7 +12,6 @@ import torch
from torch.utils.data import DataLoader, Dataset
import nni.retiarii.nn.pytorch as nn
-from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__)
@@ -163,7 +162,7 @@ def replace_layer_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
-return _replace_module_with_type(root_module, init_fn, (LayerChoice, nn.LayerChoice), modules)
+return _replace_module_with_type(root_module, init_fn, nn.LayerChoice, modules)
def replace_input_choice(root_module, init_fn, modules=None):
@@ -184,7 +183,7 @@ def replace_input_choice(root_module, init_fn, modules=None):
list[tuple[str, nn.Module]]
A list from layer choice keys (names) and replaced modules.
"""
-return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
+return _replace_module_with_type(root_module, init_fn, nn.InputChoice, modules)
class InterleavedTrainValDataLoader(DataLoader):
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Dict, List, Optional, cast)
# pylint: disable=wildcard-import,unused-wildcard-import
from . import debug_configs
__all__ = ['Operation', 'Cell']
def _convert_name(name: str) -> str:
"""
Convert the names using separator '.' to valid variable name in code
"""
return name.replace('.', '__')
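# e.g. _convert_name('backbone.cells.0') == 'backbone__cells__0'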
class Operation:
"""
Calculation logic of a graph node.
The constructor is private. Use `Operation.new()` to create an operation object.
`Operation` is a naive record.
Do not "mutate" its attributes or store information related to a specific node.
All complex logic should be implemented in `Node` class.
Attributes
----------
type
Operation type name (e.g. Conv2D).
If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output).
parameters
Arbitrary key-value parameters (e.g. kernel_size).
"""
io_names: List[str] = []
def __init__(self, type_name: str, parameters: Dict[str, Any] = {}, _internal: bool = False, attributes: Dict[str, Any] = {}):
assert _internal, '`Operation()` is private, use `Operation.new()` instead'
self.type: str = type_name
self.parameters: Dict[str, Any] = parameters
self.attributes: Dict[str, Any] = attributes
def to_init_code(self, field: str) -> str:
raise NotImplementedError()
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise NotImplementedError()
def _to_class_name(self) -> str:
raise NotImplementedError()
def __bool__(self) -> bool:
return True
@staticmethod
def new(type_name: str, parameters: Dict[str, Any] = cast(Dict[str, Any], None), cell_name: str = cast(str, None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)) -> 'Operation':
parameters = parameters or {}
attributes = attributes or {}
if type_name == '_cell':
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node
return Cell(cell_name, parameters)
else:
if debug_configs.framework.lower() in ('torch', 'pytorch'):
from .operation_def import torch_op_def # pylint: disable=unused-import
cls = PyTorchOperation._find_subclass(type_name)
elif debug_configs.framework.lower() in ('tf', 'tensorflow'):
from .operation_def import tf_op_def # pylint: disable=unused-import
cls = TensorFlowOperation._find_subclass(type_name)
else:
raise ValueError(f'Unsupported framework: {debug_configs.framework}')
return cls(type_name, parameters, _internal=True, attributes=attributes)
@classmethod
def _find_subclass(cls, subclass_name):
for subclass in cls.__subclasses__():
if subclass.__name__ == subclass_name:
return subclass
return cls
def __repr__(self):
type_name = type(self).__name__
args = [f'{key}={repr(value)}' for key, value in self.parameters.items()]
if type_name != self.type:
args = [f'type="{self.type}"'] + args
return f'{type_name}({", ".join(args)})'
def __eq__(self, other):
return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters
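# Typical construction, assuming debug_configs.framework is 'pytorch' and using a
# hypothetical module path:
#   op = Operation.new('__torch__.my_pkg.MyConv', {'kernel_size': 3})
#   op.to_init_code('conv')   # -> "self.conv = my_pkg.MyConv(kernel_size=3)"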
class PyTorchOperation(Operation):
@classmethod
def _find_subclass(cls, subclass_name):
if cls.to_class_name(subclass_name) is not None:
subclass_name = 'ModuleOperator'
if cls.is_functional(subclass_name):
subclass_name = 'FunctionalOperator'
for subclass in cls.__subclasses__():
if hasattr(subclass, '_ori_type_name') and \
subclass_name in cast(Any, subclass)._ori_type_name:
return subclass
for subclass in cls.__subclasses__():
if hasattr(subclass, '_artificial_op_name') and \
subclass_name in cast(Any, subclass)._artificial_op_name:
return subclass
return cls
@classmethod
def to_class_name(cls, type_name) -> Optional[str]:
if type_name.startswith('__torch__.'):
return type_name[len('__torch__.'):]
elif type_name.startswith('__mutated__.'):
return type_name[len('__mutated__.'):]
else:
return None
@classmethod
def is_functional(cls, type_name) -> bool:
return type_name.startswith('Function.')
def _to_class_name(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):]
else:
return None
def get_import_pkg(self) -> Optional[str]:
if self.type.startswith('__torch__.'):
return self.type[len('__torch__.'):].split('.')[0]
elif self.type.startswith('__mutated__.'):
return self.type[len('__mutated__.'):].split('.')[0]
else:
return None
def to_init_code(self, field: str) -> Optional[str]:
if self._to_class_name() is not None:
assert 'positional_args' not in self.parameters
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items())
return f'self.{field} = {self._to_class_name()}({kw_params})'
return None
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
"""
Parameters
----------
field : str
the name of member submodule
output : str
the output name (lvalue) of this line of code
inputs : List[str]
variables used in this line of code
inputs_value : List[Any]
some variables are actually constant, their real values are recorded in ```inputs_value```.
if not constant, we simply put None at the corresponding index
Returns
-------
str
generated code line
"""
if self.type == 'aten::slice':
raise RuntimeError('not supposed to have aten::slice operation')
else:
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}')
class TensorFlowOperation(Operation):
def _to_class_name(self) -> str:
return 'K.layers.' + self.type
class Cell(PyTorchOperation):
"""
TODO: this is pytorch cell
An operation reference to a subgraph.
Example code:
```
def __init__(...):
...
self.cell = CustomCell(...)
self.relu = K.layers.ReLU()
...
def forward(...):
...
x = self.cell(x)
...
```
In the above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`.
For comparison, `self.relu`'s operation is `Operation(type='ReLU')`.
TODO: parameters of subgraph (see `Node` class)
Attributes
----------
type
Always "_cell".
parameters
A dict with only one item; the key is "cell" and the value is cell's name.
framework
No real usage. Exists for compatibility with base class.
"""
def __init__(self, cell_name: str,
parameters: Dict[str, Any] = cast(Dict[str, Any], None),
attributes: Dict[str, Any] = cast(Dict[str, Any], None)):
self.type = '_cell'
self.cell_name = cell_name
self.parameters = parameters or {}
self.attributes = attributes or {}
def _to_class_name(self):
# TODO: ugly, think about how to refactor this part
return _convert_name(self.cell_name)
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class _IOPseudoOperation(Operation):
"""
This is the pseudo operation used by I/O nodes.
The benefit is that users no longer need to verify `Node.operation is not None`,
especially in static type checking.
"""
def __init__(self, type_name: str, io_names: List[str] = cast(List[str], None)):
assert type_name.startswith('_')
super(_IOPseudoOperation, self).__init__(type_name, {}, True)
self.io_names = io_names
def to_init_code(self, field: str) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"')
def __bool__(self) -> bool:
return False
from nni.nas.execution.common.graph_op import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from ..operation import TensorFlowOperation
# pylint: disable=wildcard-import,unused-wildcard-import
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal, attributes=None):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
from nni.nas.execution.tensorflow.op_def import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
# pylint: disable=wildcard-import,unused-wildcard-import
from typing import (Any, Dict, List)
import torch
import torch.nn.functional as nn_functional
from ..operation import PyTorchOperation
mem_format = [
'torch.contiguous_format', # 0
'torch.preserve_format', # 1
'torch.channels_last', # 2
]
# this snippet is copied from torch/onnx/symbolic_helper.py,
# the original definition is in c10/core/ScalarType.h
# This list maps each scalar type index to its corresponding PyTorch type string.
scalar_type_to_pytorch_type = [
'torch.uint8', # 0
'torch.int8', # 1
'torch.short', # 2
'torch.int', # 3
'torch.int64', # 4
'torch.half', # 5
'torch.float', # 6
'torch.double', # 7
'torch.complex32', # 8
'torch.complex64', # 9
'torch.complex128', # 10
'torch.bool', # 11
]
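# e.g. a JIT scalar type index of 6 denotes torch.float:
# scalar_type_to_pytorch_type[6] == 'torch.float'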
class NoOpIdentity(PyTorchOperation):
"""
This operator type is added by us.
"""
_ori_type_name = ['noop_identity']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {", ".join(inputs)}'
class ModuleOperator(PyTorchOperation):
_ori_type_name = ['ModuleOperator', 'shared']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = self.{field}({", ".join(inputs)})'
class FunctionalOperator(PyTorchOperation):
_ori_type_name = ['FunctionalOperator']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
func_name = self.type[len('Function.'):]
if not hasattr(nn_functional, func_name):
raise RuntimeError('For now, we only support calling independent functions from `torch.nn.functional`, '
f'{func_name} is not in it.')
return f'{output} = F.{func_name}({", ".join(inputs)})'
class PrimConstant(PyTorchOperation):
_ori_type_name = ['prim::Constant']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# TODO: refactor this part, maybe we can remove the code gen of prim::Constant
# TODO: deal with all the types
if self.parameters['type'] in ['None', 'NoneType']:
return f'{output} = None'
elif self.parameters['type'] in ('int', 'float', 'bool', 'int[]'): # 'Long()' ???
return f'{output} = {self.parameters["value"]}'
elif self.parameters['type'] == 'str':
str_val = self.parameters["value"]
return f'{output} = "{str_val}"'
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
elif self.parameters['type'] in ('dict', 'list', 'tuple'):
# TODO: prim::TupleIndex is not supported yet
return f'{output} = {repr(self.parameters["value"])}'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
class PrimListConstruct(PyTorchOperation):
_ori_type_name = ['prim::ListConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = [{", ".join(inputs)}]'
class PrimListUnpack(PyTorchOperation):
_ori_type_name = ['prim::ListUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class PrimTupleConstruct(PyTorchOperation):
_ori_type_name = ['prim::TupleConstruct']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = ({", ".join(inputs)})'
class PrimTupleUnpack(PyTorchOperation):
_ori_type_name = ['prim::TupleUnpack']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# have single output here, because the following code uses index to access the unpacked values
assert len(inputs) == 1
return f'{output} = {inputs[0]}'
class PrimGetAttr(PyTorchOperation):
_ori_type_name = ['prim::GetAttr']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.parameters['value'] is not None:
return f"{output} = {self.parameters['value']}"
else:
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
member_name = self.type.split('::')[-1]
return f'{output} = {inputs[0]}.{member_name}'
class AtenContiguous(PyTorchOperation):
_ori_type_name = ['aten::contiguous']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
# defined in pytorch/c10/core/MemoryFormat.h
assert inputs_value is not None and inputs_value[1] in [0, 1, 2]
return f'{output} = {inputs[0]}.contiguous(memory_format={mem_format[inputs_value[1]]})'
class AtenGetitem(PyTorchOperation):
_ori_type_name = ['aten::__getitem__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = {inputs[0]}[{inputs[1]}]'
class AtenAppend(PyTorchOperation):
_ori_type_name = ['aten::append']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'_, {output} = {inputs[0]}.append({inputs[1]}), {inputs[0]}'
class MergedSlice(PyTorchOperation):
_ori_type_name = ['MergedSlice']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if (len(inputs) - 1) % 4 == 0:
slices = []
dim = (len(inputs) - 1) // 4
for i in range(dim):
slices.append(f'{inputs[i*4+2]}:{inputs[i*4+3]}:{inputs[i*4+4]}')
slice_str = ','.join(slices)
return f'{output} = {inputs[0]}[{slice_str}]'
elif len(inputs) == 4:
# this case is for simple list
return f'{output} = {inputs[0]}[{inputs[1]}:{inputs[2]}:{inputs[3]}]'
else:
raise RuntimeError('Unsupported slice pattern')
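# Example for MergedSlice: with inputs ['x', d0, '0', '10', '2', d1, '1', '5', '1']
# (two sliced dims; the d* entries are dim markers that the loop skips), the generated
# line is: out = x[0:10:2,1:5:1]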
# the following Aten classes mean that these aten ops are not in torch.Tensor
class AtenBool(PyTorchOperation):
_ori_type_name = ['aten::Bool']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = bool({inputs[0]})'
class AtenNot(PyTorchOperation):
_ori_type_name = ['aten::__not__']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = not {inputs[0]}'
class AtenCat(PyTorchOperation):
_ori_type_name = ['aten::cat']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
assert len(inputs) == 2
return f'{output} = torch.cat({inputs[0]}, dim={inputs[1]})'
# ====================================
class AtenTensors(PyTorchOperation):
_ori_type_name = ['aten::full', 'aten::full_like', 'aten::empty_like',
'aten::ones_like', 'aten::zeros_like', 'aten::rand',
'aten::randn', 'aten::scalar_tensor', 'aten::new_full',
'aten::new_empty', 'aten::new_zeros', 'aten::arange',
'aten::tensor', 'aten::ones', 'aten::zeros', 'aten::as_tensor']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
schemas = torch._C._jit_get_schemas_for_operator(self.type)
# match number of inputs
overloaded_defs = [len(s.arguments) for s in schemas]
matched = overloaded_defs.index(len(inputs))
args_list = []
for idx, arg in enumerate(schemas[matched].arguments):
if arg.name == 'dtype':
arg_str = f'dtype={scalar_type_to_pytorch_type[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'layout':
if inputs_value[idx] is not None:
arg_str = 'layout=torch.strided'
print('Warning: only support `torch.strided` for now!!!')
else:
arg_str = ''
elif arg.name == 'device':
arg_str = f'device=torch.device({inputs[idx]})' if inputs_value[idx] is not None else ''
elif arg.name == 'memory_format':
arg_str = f'memory_format={mem_format[inputs_value[idx]]}' if inputs_value[idx] is not None else ''
elif arg.name == 'pin_memory':
# TODO: deal with this argument
continue
elif arg.name == 'requires_grad':
arg_str = f'requires_grad={inputs[idx]}' if inputs_value[idx] else ''
elif str(arg.type).startswith('Optional['):
arg_str = f'{arg.name}={inputs[idx]}'
else:
arg_str = f'{inputs[idx]}'
if arg_str != '':
args_list.append(arg_str)
op_name = self.type.split('::')[-1]
if hasattr(torch, op_name):
return f'{output} = torch.{op_name}({", ".join(args_list)})'
else:
return f'{output} = {inputs[0]}.{op_name}({", ".join(args_list[1:])})'
# ====================================
class AtenFloordiv(PyTorchOperation):
_ori_type_name = ['aten::floordiv']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} // {inputs[1]}'
class AtenMul(PyTorchOperation):
_ori_type_name = ['aten::mul']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]} * {inputs[1]}'
class AtenLen(PyTorchOperation):
_ori_type_name = ['aten::len']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = len({inputs[0]})'
class AtenIntImplicit(PyTorchOperation):
_ori_type_name = ['aten::IntImplicit', 'aten::Float', 'aten::Int', 'aten::ScalarImplicit']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
if self.type.endswith('Implicit'):
return f'{output} = {inputs[0]}'
elif self.type == 'aten::Int':
return f'{output} = int({inputs[0]})'
elif self.type == 'aten::Float':
return f'{output} = float({inputs[0]})'
raise TypeError(f'Unexpected type: {self.type}')
class AtenIndex(PyTorchOperation):
_ori_type_name = ['aten::index']
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
return f'{output} = {inputs[0]}[{inputs[1]}]'
ManuallyChooseDef = {
'aten::flatten': [('start_dim', 'int', '0'), ('end_dim', 'int', '-1')],
'aten::split': [('split_size', 'int', 'None'), ('dim', 'int', '0')],
# in v1.9 dtype is supported as input argument for view, but torch script does not support it
'aten::view': [('size', 'List[int]', 'None')],
# NOTE: dim supports different types: List[int], List[str], Optional[List[int]], now we only support the first two, refactor needed
# torch.std(input, dim, unbiased, keepdim=False, *, out=None) Tensor
# torch.std(input, unbiased) Tensor
'aten::std': [('dim', 'List[int]', 'None'), ('unbiased', 'bool', 'True'), ('keepdim', 'bool', 'False')]
}
TensorOpExceptions = {
'aten::sub': lambda output, inputs: f'{output} = {inputs[0]} - {inputs[1]}', # example: x.size(1) - 3
'aten::add': lambda output, inputs: f'{output} = {inputs[0]} + {inputs[1]}' # example: input.shape[0] + 5
}
TorchOpExclude = ['aten::Size', 'aten::as_tensor', 'aten::device',
'aten::manual_seed', 'aten::quantized_gru', 'aten::quantized_lstm',
'aten::save', 'aten::tensor', 'aten::wait'
]
def _hidden(name):
return name.startswith('_') and not name.startswith('__')
def _emit_args(args):
# filter out the `out` argument here
return [(arg.name, str(arg.type), str(arg.default_value)) for arg in args] # if arg.name != 'out'
def _get_tensor_ops():
def is_tensor_method(schema):
if len(schema.arguments) == 0:
return False
self = schema.arguments[0]
if self.name != 'self':
return False
if not self.type.isSubtypeOf(torch._C.TensorType.get()):
return False
return True
op_args = {}
# discover methods
for elem in dir(torch.Tensor):
if not _hidden(elem):
schemas = torch._C._jit_get_schemas_for_operator("aten::" + elem)
for schema in schemas:
if is_tensor_method(schema):
op_name = 'aten::' + elem
args = _emit_args(schema.arguments[1:])
if op_name in op_args:
op_args[op_name].append(args)
else:
op_args[op_name] = [args]
return op_args.keys(), op_args
def _get_torch_ops():
torch_op_args = {}
for mod in torch.jit._builtins._modules_containing_builtins: # type: ignore
name = mod.__name__
if name == 'torch._C._nn':
continue
# only process 'torch.XXX'
for elem in dir(mod):
builtin = torch.jit._builtins._find_builtin(getattr(mod, elem)) # type: ignore
if builtin is not None:
schemas = torch._C._jit_get_schemas_for_operator(builtin)
for schema in schemas:
# remove _tan but not __and__
if not _hidden(elem):
op_name = 'aten::' + elem
if len(schema.arguments) > 0 and schema.arguments[0].name == 'self':
continue
args = _emit_args(schema.arguments)
if op_name in torch_op_args:
torch_op_args[op_name].append(args)
else:
torch_op_args[op_name] = [args]
return torch_op_args.keys(), torch_op_args
def _get_torch_ops_exclude_tensor_ops():
tensor_op_names, _ = _get_tensor_ops()
torch_op_names, torch_ops = _get_torch_ops()
torch_exclude_ops = {}
for name in torch_op_names:
if name not in tensor_op_names:
if name not in TorchOpExclude:
# exclude the ops that are not in
# https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/native_functions.yaml
torch_exclude_ops[name] = torch_ops[name]
return torch_exclude_ops.keys(), torch_exclude_ops

class TensorOps(PyTorchOperation):
    """
    Corresponds to ``_get_tensor_ops`` in ``torch.jit.supported_ops``.
    """
    _ori_type_name, _op_args = _get_tensor_ops()

    comparison_ops = {'aten::eq': '==', 'aten::ne': '!=', 'aten::le': '<=', 'aten::ge': '>=', 'aten::lt': '<', 'aten::gt': '>'}

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            # True if every matched overload exposes exactly the same argument names
            concated_names = [','.join(arg[0] for arg in each) for each in matched]
            return len(set(concated_names)) == 1

        overloaded_defs = TensorOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            # plus 1 because the first argument (`self`) is skipped when generating tensor op defs
            if len(each) + 1 == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with argument types; manually chosen for now
            if has_same_arg_name(matched):
                # all matches share the same argument names, returning any one is okay
                return matched[0]
            elif _type in ManuallyChooseDef:
                return ManuallyChooseDef[_type]
            else:
                raise RuntimeError(f'tensor op type {_type} has more than one match: {matched}')
        else:
            if _type in TensorOpExceptions:
                return None
            raise RuntimeError(f'tensor op type {_type} has no match')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        # TODO: deal with conditional ops
        if self.type in TensorOps.comparison_ops:
            return f'{output} = ({inputs[0]} {TensorOps.comparison_ops[self.type]} {inputs[1]})'
        matched_args = TensorOps._get_matched_args(self.type, inputs)
        if matched_args is None:
            return TensorOpExceptions[self.type](output, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join(f'{name}={inputs[i + 1]}' for i, (name, t, default) in enumerate(matched_args))
        return f'{output} = {inputs[0]}.{op_name}({args_str})'
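
# Generation sketch (illustrative values): for a node of type 'aten::flatten' with
# inputs ['x', '1', '-1'], to_forward_code emits
#     out = x.flatten(start_dim=1, end_dim=-1)
# and for a comparison op such as 'aten::eq' with inputs ['x', 'y']:
#     out = (x == y)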

class TorchOps(PyTorchOperation):
    """
    Corresponds to ``_get_nn_functional_ops`` in ``torch.jit.supported_ops``.
    """
    _ori_type_name, _op_args = _get_torch_ops_exclude_tensor_ops()
    # 'aten::pixel_shuffle' is not collected above, register it manually
    _op_args['aten::pixel_shuffle'] = [[('input', 'Tensor', 'None'), ('upscale_factor', 'Optional[int]', 'None')]]
    _ori_type_name = _op_args.keys()

    @staticmethod
    def _get_matched_args(_type, inputs):
        def has_same_arg_name(matched):
            # True if every matched overload exposes exactly the same argument names
            concated_names = [','.join(arg[0] for arg in each) for each in matched]
            return len(set(concated_names)) == 1

        overloaded_defs = TorchOps._op_args[_type]
        matched = []
        for each in overloaded_defs:
            if len(each) == len(inputs):
                matched.append(each)
        if len(matched) == 1:
            return matched[0]
        elif len(matched) > 1:
            # TODO: match with argument types; manually chosen for now
            if has_same_arg_name(matched):
                # all matches share the same argument names, returning any one is okay
                return matched[0]
            else:
                raise RuntimeError(f'torch op type {_type} has more than one match: {matched}')
        else:
            raise RuntimeError(f'torch op type {_type} has no match')

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        matched_args = TorchOps._get_matched_args(self.type, inputs)
        op_name = self.type.split('::')[-1]
        args_str = ', '.join(f'{name}={inputs[i]}' if t.startswith('Optional[') else f'{inputs[i]}'
                             for i, (name, t, default) in enumerate(matched_args))
        return f'{output} = torch.{op_name}({args_str})'
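
# Generation sketch (illustrative values): for a node of type 'aten::cat' with inputs
# ['[x, y]', '1'] and the matched definition [('tensors', 'List[Tensor]', 'None'),
# ('dim', 'int', '0')], to_forward_code emits (keyword form only for Optional[...] args):
#     out = torch.cat([x, y], 1)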

class AtenAvgpool2d(PyTorchOperation):
    # NOTE: it is not included in the aten ops above for an unknown reason
    _ori_type_name = ['aten::avg_pool2d']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = F.avg_pool2d({", ".join(inputs)})'


class ToDevice(PyTorchOperation):
    _artificial_op_name = "ToDevice"

    def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False,
                 attributes: Dict[str, Any] = {}):
        self.type = "ToDevice"
        self.device = parameters['device']
        self.overridden_device_repr = None
        self.src = parameters['src']
        self.dst = parameters['dst']

    def override_device_repr(self, device_repr):
        # A CUDA GPUDevice may remap physical GPU IDs to CUDA IDs, so the correct repr can differ from
        # GPUDevice.device_repr(). This method is called in pytorch.graph_to_pytorch_model to substitute
        # the correct CUDA ID: e.g., when a job uses physical GPUs 1 and 2, the CUDA IDs should be
        # "cuda:0" and "cuda:1", while self.device.device_repr() would return "cuda:1" and "cuda:2".
        self.overridden_device_repr = device_repr

    def __repr__(self):
        if self.overridden_device_repr is None:
            return f'to("{self.device.device_repr()}")'
        else:
            return f'to("{self.overridden_device_repr}")'

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        if self.overridden_device_repr is None:
            forward_code = f'{output} = {inputs[0]}.to("{self.device.device_repr()}")'
        else:
            forward_code = f'{output} = {inputs[0]}.to("{self.overridden_device_repr}")'
        return forward_code
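
# Usage sketch for override_device_repr (device values are illustrative): a ToDevice op
# whose GPUDevice maps to physical GPU 1 renders as `to("cuda:1")` by default; after
# calling op.override_device_repr("cuda:0"), the generated forward line becomes
#     out = x.to("cuda:0")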

class AtenDet(PyTorchOperation):
    # for torch 1.9
    # NOTE: it is not included in the aten ops above, perhaps because torch.det is an alias of torch.linalg.det
    _ori_type_name = ['aten::linalg_det']

    def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any]) -> str:
        return f'{output} = torch.det({inputs[0]})'

from nni.nas.execution.pytorch.op_def import *