Unverified commit f77db747 authored by Yuge Zhang, committed by GitHub

Enhancement of one-shot NAS (v2.9) (#5049)

parent 125ec21f
@@ -70,7 +70,7 @@ class NasBench201Cell(nn.Module):
                 inp = in_features if j == 0 else out_features
                 op_choices = OrderedDict([(key, cls(inp, out_features))
                                           for key, cls in op_candidates.items()])
-                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
+                node_ops.append(LayerChoice(op_choices, label=f'{self._label}/{j}_{tid}'))
             self.layers.append(node_ops)

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
...
@@ -179,7 +179,7 @@ class NasBench201(nn.Module):
                 cell = ResNetBasicblock(C_prev, C_curr, 2)
             else:
                 ops: Dict[str, Callable[[int, int], nn.Module]] = {
-                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                    prim: self._make_op_factory(prim) for prim in PRIMITIVES
                 }
                 cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
             self.cells.append(cell)
@@ -192,6 +192,9 @@ class NasBench201(nn.Module):
         self.global_pooling = nn.AdaptiveAvgPool2d(1)
         self.classifier = nn.Linear(C_prev, self.num_labels)

+    def _make_op_factory(self, prim):
+        return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
+
     def forward(self, inputs):
         feature = self.stem(inputs)
         for cell in self.cells:
...
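A likely motivation for replacing the inline lambda with `_make_op_factory` is Python's late binding of loop variables: every lambda created in the old dict comprehension would resolve `prim` to the last primitive at call time, while the factory binds `prim` when the factory is called. A minimal standalone sketch of the pitfall (illustrative only, not NNI code):

```python
# Minimal sketch of the late-binding pitfall that a per-item factory avoids.
# The names below are illustrative, not part of NNI.

def broken_factories():
    # All lambdas share the comprehension variable ``name``.
    return {name: lambda: name for name in ['a', 'b', 'c']}

def fixed_factories():
    def make(name):
        return lambda: name  # ``name`` is bound when make() is called
    return {name: make(name) for name in ['a', 'b', 'c']}

print(broken_factories()['a']())  # prints 'c' -- every lambda sees the final loop value
print(fixed_factories()['a']())   # prints 'a'
```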
@@ -5,23 +5,21 @@ from __future__ import annotations

 import warnings
 from itertools import chain
-from typing import Callable, Any, Dict, Union, Tuple, List, cast
+from typing import Callable, Any, Dict, Union, Tuple, cast

 import pytorch_lightning as pl
 import torch.optim as optim
 import torch.nn as nn
 from torch.optim import Optimizer
-from torch.optim.lr_scheduler import _LRScheduler

 import nni.nas.nn.pytorch as nas_nn
 from nni.common.hpo_utils import ParameterSpec
 from nni.common.serializer import is_traceable
 from nni.nas.nn.pytorch.choice import ValueChoiceX
-from nni.typehint import Literal

 from .supermodule.base import BaseSuperNetModule

 __all__ = [
+    'MANUAL_OPTIMIZATION_NOTE',
     'MutationHook',
     'BaseSuperNetModule',
     'BaseOneShotLightningModule',
@@ -30,6 +28,22 @@ __all__ = [
 ]

+MANUAL_OPTIMIZATION_NOTE = """
+    .. warning::
+
+        The strategy, under the hood, creates a Lightning module that wraps the Lightning module defined in evaluator,
+        and enables `Manual optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`_,
+        although we assume **the inner evaluator has enabled automatic optimization**.
+        We call the optimizers and schedulers configured in evaluator, following the definition in Lightning at best effort,
+        but we make no guarantee that the behaviors are exactly same as automatic optimization.
+        We call :meth:`~BaseSuperNetModule.advance_optimization` and :meth:`~BaseSuperNetModule.advance_lr_schedulers`
+        to invoke the optimizers and schedulers configured in evaluators.
+        Moreover, some advanced features like gradient clipping will not be supported.
+        If you encounter any issues, please contact us by `creating an issue <https://github.com/microsoft/nni/issues>`_.
+
+"""

 MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
@@ -122,7 +136,7 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
         nas_nn.LayerChoice,
         nas_nn.InputChoice,
         nas_nn.Repeat,
-        # nas_nn.NasBench101Cell,   # FIXME: nasbench101 is moved to hub, can't check any more.
+        # nas_nn.NasBench101Cell,
         # nas_nn.ValueChoice,       # could be false positive
         # nas_nn.Cell,              # later
         # nas_nn.NasBench201Cell,   # forward = supernet
@@ -156,8 +170,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
         Extra mutation hooks to support customized mutation on primitives other than built-ins.

         Mutation hooks are callable that inputs an Module and returns a
-        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
-        They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
+        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
+        They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
         For each submodule, the hook list are invoked subsequently,
         the later hooks can see the result from previous hooks.
         The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
@@ -177,21 +191,21 @@ class BaseOneShotLightningModule(pl.LightningModule):

         The returned arguments can be also one of the three kinds:

-        1. tuple of: :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
+        1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
         2. boolean,
-        3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
+        3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.

         The boolean value is ``suppress`` indicates whether the following hooks should be called.
         When it's true, it suppresses the subsequent hooks, and they will never be invoked.
         Without boolean value specified, it's assumed to be false.
         If a none value appears on the place of
-        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
         it means the hook suggests to
         keep the module unchanged, and nothing will happen.

-        An example of mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
+        An example of mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
         However it's recommended to implement mutation hooks by deriving
-        :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
+        :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
         and add its classmethod ``mutate`` to this list.
         """
@@ -295,236 +309,232 @@ class BaseOneShotLightningModule(pl.LightningModule):
             result.update(module.export(memo=result))
         return result
-    def forward(self, x):
-        return self.model(x)

-    def training_step(self, batch, batch_idx):
-        """This is the implementation of what happens in training loops of one-shot algos.
-        It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
-        """
-        return self.model.training_step(batch, batch_idx)

-    def configure_optimizers(self):
-        """
-        Combine architecture optimizers and user's model optimizers.
-        You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
-        For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
-        and it only returns 1 optimizer.
-        But for extendibility, codes for other return value types are also implemented.
-        """
-        # pylint: disable=assignment-from-none
-        arc_optimizers = self.configure_architecture_optimizers()
-        if arc_optimizers is None:
-            return self.model.configure_optimizers()
-        if isinstance(arc_optimizers, optim.Optimizer):
-            arc_optimizers = [arc_optimizers]
-        self.arc_optim_count = len(arc_optimizers)

-        # FIXME: this part uses non-official lightning API.
-        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
-        # ``len(optimizers) == len(frequency)``, and gradient backword is handled manually.
-        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
-        try:
-            # above v1.6
-            from pytorch_lightning.core.optimizer import (  # pylint: disable=import-error
-                _configure_optimizers,  # type: ignore
-                _configure_schedulers_automatic_opt,  # type: ignore
-                _configure_schedulers_manual_opt  # type: ignore
-            )
-            w_optimizers, lr_schedulers, self.frequencies, monitor = \
-                _configure_optimizers(self.model.configure_optimizers())  # type: ignore
-            lr_schedulers = (
-                _configure_schedulers_automatic_opt(lr_schedulers, monitor)
-                if self.automatic_optimization
-                else _configure_schedulers_manual_opt(lr_schedulers)
-            )
-        except ImportError:
-            # under v1.5
-            w_optimizers, lr_schedulers, self.frequencies, monitor = \
-                self.trainer._configure_optimizers(self.model.configure_optimizers())  # type: ignore
-            lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)  # type: ignore
-        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):  # type: ignore
-            raise Exception(
-                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
-            )

-        # variables used to handle optimizer frequency
-        self.cur_optimizer_step = 0
-        self.cur_optimizer_index = 0

-        return arc_optimizers + w_optimizers, lr_schedulers

-    def on_train_start(self):
-        return self.model.on_train_start()

-    def on_train_end(self):
-        return self.model.on_train_end()

-    def on_fit_start(self):
-        # redirect the access to trainer/log to this module
-        # but note that we might be missing other attributes,
-        # which could potentially be a problem
-        self.model.trainer = self.trainer  # type: ignore
-        self.model.log = self.log
-        return self.model.on_fit_start()

-    def on_fit_end(self):
-        return self.model.on_fit_end()

-    def on_train_batch_start(self, batch, batch_idx, *args, **kwargs):
-        return self.model.on_train_batch_start(batch, batch_idx, *args, **kwargs)

-    def on_train_batch_end(self, outputs, batch, batch_idx, *args, **kwargs):
-        return self.model.on_train_batch_end(outputs, batch, batch_idx, *args, **kwargs)

-    # Deprecated hooks in pytorch-lightning
-    def on_epoch_start(self):
-        return self.model.on_epoch_start()

-    def on_epoch_end(self):
-        return self.model.on_epoch_end()

-    def on_train_epoch_start(self):
-        return self.model.on_train_epoch_start()

-    def on_train_epoch_end(self):
-        return self.model.on_train_epoch_end()

-    def on_before_backward(self, loss):
-        return self.model.on_before_backward(loss)

-    def on_after_backward(self):
-        return self.model.on_after_backward()

-    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
-        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

-    def configure_architecture_optimizers(self):
+    def export_probs(self) -> dict[str, Any]:
+        """
+        Export the probability of every choice in the search space got chosen.

+        .. note:: If such method of some modules is not implemented, they will be simply ignored.

+        Returns
+        -------
+        dict
+            In most cases, keys are names of ``nas_modules`` suffixed with ``/`` and choice name.
+            Values are the probability / logits depending on the implementation.
+        """
+        result = {}
+        for module in self.nas_modules:
+            try:
+                result.update(module.export_probs(memo=result))
+            except NotImplementedError:
+                warnings.warn(
+                    'Some super-modules you have used did not implement export_probs. You might find some logs are missing.',
+                    UserWarning
+                )
+        return result

+    def forward(self, x):
+        return self.model(x)

+    def configure_optimizers(self) -> Any:
+        """
+        Transparently configure optimizers for the inner model,
+        unless one-shot algorithm has its own optimizer (via :meth:`configure_architecture_optimizers`),
+        in which case, the optimizer will be appended to the list.

+        The return value is still one of the 6 types defined in PyTorch-Lightning.
+        """
+        arch_optimizers = self.configure_architecture_optimizers() or []
+        if not arch_optimizers:  # no architecture optimizer available
+            return self.model.configure_optimizers()

+        if isinstance(arch_optimizers, optim.Optimizer):
+            arch_optimizers = [arch_optimizers]
+        # Set the flag to True so that they can differ from other optimizers
+        for optimizer in arch_optimizers:
+            optimizer.is_arch_optimizer = True  # type: ignore

+        optim_conf: Any = self.model.configure_optimizers()

+        # 0. optimizer is none
+        if optim_conf is None:
+            return arch_optimizers
+        # 1. single optimizer
+        if isinstance(optim_conf, Optimizer):
+            return [optim_conf] + arch_optimizers
+        # 2. two lists, optimizer + lr schedulers
+        if (
+            isinstance(optim_conf, (list, tuple))
+            and len(optim_conf) == 2
+            and isinstance(optim_conf[0], list)
+            and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
+        ):
+            return list(optim_conf[0]) + arch_optimizers, optim_conf[1]
+        # 3. single dictionary
+        if isinstance(optim_conf, dict):
+            return [optim_conf] + [{'optimizer': optimizer} for optimizer in arch_optimizers]
+        # 4. multiple dictionaries
+        if isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
+            return list(optim_conf) + [{'optimizer': optimizer} for optimizer in arch_optimizers]
+        # 5. single list or tuple, multiple optimizer
+        if isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
+            return list(optim_conf) + arch_optimizers
+        # unknown configuration
+        warnings.warn('Unknown optimizer configuration. Architecture optimizers will be ignored. Strategy might fail.', UserWarning)

+        return optim_conf

+    def setup(self, stage=None):
+        # redirect the access to trainer/log to this module
+        # but note that we might be missing other attributes,
+        # which could potentially be a problem
+        self.model.trainer = self.trainer  # type: ignore
+        self.model.log = self.log

+        # Reset the optimizer progress (only once at the very beginning)
+        self._optimizer_progress = 0

+        return self.model.setup(stage)

+    def teardown(self, stage=None):
+        return self.model.teardown(stage)

+    def configure_architecture_optimizers(self) -> list[optim.Optimizer] | optim.Optimizer | None:
         """
         Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
         if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.

         Returns
-        ----------
-        arc_optimizers : list[Optimizer], Optimizer
+        -------
             Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
         """
         return None
-    def call_lr_schedulers(self, batch_index):
-        """
-        Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
-        by this class, you can use this function to make schedulers behave as they were automatically handled by the lightning trainer.

-        Parameters
-        ----------
-        batch_idx : int
-            batch index
-        """
-        def apply(lr_scheduler):
-            # single scheduler is called every epoch
-            if isinstance(lr_scheduler, _LRScheduler):
-                if self.trainer.is_last_batch:
-                    lr_scheduler.step()
-            # lr_scheduler_config is called as configured
-            elif isinstance(lr_scheduler, dict):
-                interval = lr_scheduler['interval']
-                frequency = lr_scheduler['frequency']
-                if (
-                    interval == 'step' and
-                    batch_index % frequency == 0
-                ) or \
-                (
-                    interval == 'epoch' and
-                    self.trainer.is_last_batch and
-                    (self.trainer.current_epoch + 1) % frequency == 0
-                ):
-                    lr_scheduler['scheduler'].step()

-        lr_schedulers = self.lr_schedulers()

-        if isinstance(lr_schedulers, list):
-            for lr_scheduler in lr_schedulers:
-                apply(lr_scheduler)
-        else:
-            apply(lr_schedulers)

-    def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
-        """
-        Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
-        class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.

-        Parameters
-        ----------
-        method : str
-            Method to call. Only ``step`` and ``zero_grad`` are supported now.
-        """
-        def apply_method(optimizer, method):
-            if method == 'step':
-                optimizer.step()
-            elif method == 'zero_grad':
-                optimizer.zero_grad()

-        optimizers = self.weight_optimizers()
-        if optimizers is None:
-            return

-        assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'

-        if len(self.frequencies) > 0:
-            self.cur_optimizer_step += 1
-            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
-                self.cur_optimizer_step = 0
-                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
-                    if self.cur_optimizer_index + 1 < len(optimizers) \
-                    else 0
-            apply_method(optimizers[self.cur_optimizer_index], method)
-        else:
-            for optimizer in optimizers:
-                apply_method(optimizer, method)
+    def advance_optimization(
+        self,
+        loss: Any,
+        batch_idx: int,
+        gradient_clip_val: int | float | None = None,
+        gradient_clip_algorithm: str | None = None
+    ):
+        """
+        Run the optimizer defined in evaluators, when manual optimization is turned on.

+        Call this method when the model should be optimized.
+        To keep it as neat as possible, we only implement the basic ``zero_grad``, ``backward``, ``grad_clip``, and ``step`` here.
+        Many hooks and pre/post-processing are omitted.
+        Inherit this method if you need more advanced behavior.

+        The full optimizer step could be found
+        `here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/optimization/optimizer_loop.py>`__.
+        We only implement part of the optimizer loop here.

+        Parameters
+        ----------
+        batch_idx: int
+            The current batch index.
+        """
+        if self.automatic_optimization:
+            raise ValueError('This method should not be used when automatic optimization is turned on.')

+        if self.trainer.optimizer_frequencies:
+            warnings.warn('optimizer_frequencies is not supported in NAS. It will be ignored.', UserWarning)

+        # Filter out optimizers for architecture parameters
+        optimizers = [opt for opt in self.trainer.optimizers if not getattr(opt, 'is_arch_optimizer', False)]

+        opt_idx = self._optimizer_progress % len(optimizers)
+        optimizer = optimizers[opt_idx]

+        # There should be many before/after hooks called here, but they are omitted in this implementation.
+        # 1. zero gradient
+        self.model.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
+        # 2. backward
+        self.manual_backward(loss)
+        # 3. grad clip
+        self.model.configure_gradient_clipping(optimizer, opt_idx, gradient_clip_val, gradient_clip_algorithm)
+        # 4. optimizer step
+        self.model.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)

+        self._optimizer_progress += 1

+    def advance_lr_schedulers(self, batch_idx: int):
+        """
+        Advance the learning rates, when manual optimization is turned on.

+        The full implementation is
+        `here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/epoch/training_epoch_loop.py>`__.
+        We only include a partial implementation here.

+        Advanced features like Reduce-lr-on-plateau are not supported.
+        """
+        if self.automatic_optimization:
+            raise ValueError('This method should not be used when automatic optimization is turned on.')

+        self._advance_lr_schedulers_impl(batch_idx, 'step')
+        if self.trainer.is_last_batch:
+            self._advance_lr_schedulers_impl(batch_idx, 'epoch')

+    def _advance_lr_schedulers_impl(self, batch_idx: int, interval: str):
+        current_idx = batch_idx if interval == 'step' else self.trainer.current_epoch
+        current_idx += 1  # account for both batch and epoch starts from 0

+        try:
+            # lightning >= 1.6
+            for config in self.trainer.lr_scheduler_configs:
+                scheduler, opt_idx = config.scheduler, config.opt_idx
+                if config.reduce_on_plateau:
+                    warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
+                if config.interval == interval and current_idx % config.frequency == 0:
+                    self.model.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None)
+        except AttributeError:
+            # lightning < 1.6
+            for lr_scheduler in self.trainer.lr_schedulers:
+                if lr_scheduler['reduce_on_plateau']:
+                    warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
+                if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency']:
+                    lr_scheduler['scheduler'].step()

     def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
         """
-        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.

-        Returns
-        ----------
-        opts : list[Optimizer], Optimizer, None
-            Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there is no
-            architecture optimizers.
+        Get the optimizers configured in :meth:`configure_architecture_optimizers`.
         """
-        opts = self.optimizers()
-        if isinstance(opts, list):
-            # pylint: disable=unsubscriptable-object
-            arc_opts = opts[:self.arc_optim_count]
-            if len(arc_opts) == 1:
-                return cast(Optimizer, arc_opts[0])
-            return cast(List[Optimizer], arc_opts)
-        # If there is only 1 optimizer and it is the architecture optimizer
-        if self.arc_optim_count == 1:
-            return cast(Union[List[Optimizer], Optimizer], opts)
-        return None

-    def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
-        """
-        Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.

-        Returns
-        ----------
-        opts : list[Optimizer], Optimizer, None
-            Optimizers defined by user's model. This will be None if there is no user optimizers.
-        """
-        # Since use_pl_optimizer is set true (by default) here.
-        # opts always return a list
-        opts = self.optimizers()
-        if isinstance(opts, list):
-            # pylint: disable=unsubscriptable-object
-            return cast(List[Optimizer], opts[self.arc_optim_count:])
-        # FIXME: this case is actually not correctly handled
-        # If there is only 1 optimizer and no architecture optimizer
-        if self.arc_optim_count == 0:
-            return cast(Union[List[Optimizer], Optimizer], opts)
-        return None
+        optimizers = [opt for opt in self.trainer.optimizers if getattr(opt, 'is_arch_optimizer', False)]
+        if not optimizers:
+            return None
+        if len(optimizers) == 1:
+            return optimizers[0]
+        return optimizers

+    # The following methods redirects the callbacks to inner module.
+    # It's not the complete list though.
+    # More methods can be added if needed.

+    def on_train_start(self):
+        return self.model.on_train_start()

+    def on_train_end(self):
+        return self.model.on_train_end()

+    def on_fit_start(self):
+        return self.model.on_fit_start()

+    def on_fit_end(self):
+        return self.model.on_fit_end()

+    def on_train_batch_start(self, batch, batch_idx, *args, **kwargs):
+        return self.model.on_train_batch_start(batch, batch_idx, *args, **kwargs)

+    def on_train_batch_end(self, outputs, batch, batch_idx, *args, **kwargs):
+        return self.model.on_train_batch_end(outputs, batch, batch_idx, *args, **kwargs)

+    def on_train_epoch_start(self):
+        return self.model.on_train_epoch_start()

+    def on_train_epoch_end(self):
+        return self.model.on_train_epoch_end()

+    def on_before_backward(self, loss):
+        return self.model.on_before_backward(loss)

+    def on_after_backward(self):
+        return self.model.on_after_backward()
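Taken together, a strategy built on this base class drives the evaluator's optimizers by hand. A minimal sketch of what a ``training_step`` in a hypothetical subclass might look like, using only the helpers introduced in this commit (editor-added illustration, not part of the diff; the subclass name is made up):

```python
# Illustrative subclass of BaseOneShotLightningModule.
# The base wrapper already runs with manual optimization (see MANUAL_OPTIMIZATION_NOTE).
class MySamplingModule(BaseOneShotLightningModule):

    def training_step(self, batch, batch_idx):
        self.resample()  # pick an architecture for this step

        loss_and_metrics = self.model.training_step(batch, batch_idx)
        loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics

        # zero_grad -> backward -> (optional) clip -> step, on the evaluator's optimizers
        self.advance_optimization(loss, batch_idx)
        # step-interval schedulers now, epoch-interval schedulers on the last batch
        self.advance_lr_schedulers(batch_idx)

        # optionally log the current architecture distribution
        self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
        return loss_and_metrics
```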
@@ -9,7 +9,7 @@ import pytorch_lightning as pl
 import torch
 import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import BaseOneShotLightningModule, MANUAL_OPTIMIZATION_NOTE, MutationHook, no_default_hook
 from .supermodule.differentiable import (
     DifferentiableMixedLayer, DifferentiableMixedInput,
     MixedOpDifferentiablePolicy, GumbelSoftmax,
@@ -28,6 +28,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
    DARTS repeats iterations, where each iteration consists of 2 training phases.
    The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
    The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
+   In both phases, ``training_step`` of the Lightning evaluator will be used.

    The current implementation corresponds to DARTS (1st order) in paper.
    Second order (unrolled 2nd-order derivatives) is not supported yet.
@@ -49,15 +50,20 @@ class DartsLightningModule(BaseOneShotLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _darts_note.format(
@@ -85,8 +91,10 @@ class DartsLightningModule(BaseOneShotLightningModule):
    def __init__(self, inner_module: pl.LightningModule,
                 mutation_hooks: list[MutationHook] | None = None,
-                arc_learning_rate: float = 3.0E-4):
+                arc_learning_rate: float = 3.0E-4,
+                gradient_clip_val: float | None = None):
        self.arc_learning_rate = arc_learning_rate
+       self.gradient_clip_val = gradient_clip_val
        super().__init__(inner_module, mutation_hooks=mutation_hooks)
    def training_step(self, batch, batch_idx):
@@ -108,33 +116,32 @@ class DartsLightningModule(BaseOneShotLightningModule):
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
-       self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self.resample()
-       self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
-       w_step_loss = loss_and_metrics['loss'] \
-           if isinstance(loss_and_metrics, dict) else loss_and_metrics
-       self.manual_backward(w_step_loss)
-       self.call_weight_optimizers('step')
+       w_step_loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics
+       self.advance_optimization(w_step_loss, batch_idx, self.gradient_clip_val)

-       self.call_lr_schedulers(batch_idx)
+       # Update learning rates
+       self.advance_lr_schedulers(batch_idx)

-       return loss_and_metrics
+       self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})

-   def finalize_grad(self):
-       # Note: This hook is currently kept for Proxyless NAS.
-       pass
+       return loss_and_metrics
    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
            ctrl_params += list(m.parameters(arch=True))  # type: ignore
-       ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
-                                     weight_decay=1.0E-3)
+       # Follow the hyper-parameters used in
+       # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/architect.py#L17
+       params = list(set(ctrl_params))
+       if not params:
+           raise ValueError('No architecture parameters found. Nothing to search.')
+       ctrl_optim = torch.optim.Adam(params, 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        return ctrl_optim
@@ -153,13 +160,20 @@ class ProxylessLightningModule(DartsLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
-   """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
+   """.format(
+       base_params=BaseOneShotLightningModule._mutation_hooks_note,
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
+   )

    __doc__ = _proxyless_note.format(
        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
@@ -176,10 +190,6 @@ class ProxylessLightningModule(DartsLightningModule):
        # FIXME: no support for mixed operation currently
        return hooks

-   def finalize_grad(self):
-       for m in self.nas_modules:
-           m.finalize_grad()  # type: ignore

class GumbelDartsLightningModule(DartsLightningModule):
    _gumbel_darts_note = """
@@ -207,6 +217,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
@@ -216,13 +228,17 @@ class GumbelDartsLightningModule(DartsLightningModule):
    use_temp_anneal : bool
        If true, a linear annealing will be applied to ``gumbel_temperature``.
        Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
+       Default is false.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+   gradient_clip_val : float
+       Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )
    def mutate_kwargs(self):
@@ -235,22 +251,25 @@ class GumbelDartsLightningModule(DartsLightningModule):
    def __init__(self, inner_module,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
+                gradient_clip_val: float | None = None,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,
                 min_temp: float = .33):
-       super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
+       super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate, gradient_clip_val=gradient_clip_val)
        self.temp = gumbel_temperature
        self.init_temp = gumbel_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

-   def on_train_epoch_end(self):
+   def on_train_epoch_start(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)

+       self.log('gumbel_temperature', self.temp)
+
        for module in self.nas_modules:
-           if hasattr(module, '_softmax'):
-               module._softmax.temp = self.temp  # type: ignore
-       return self.model.on_train_epoch_end()
+           if hasattr(module, '_softmax') and isinstance(module, GumbelSoftmax):
+               module._softmax.tau = self.temp  # type: ignore
+       return self.model.on_train_epoch_start()
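The linear annealing above interpolates from the initial temperature down to ``min_temp`` over the training run. A quick standalone check of the first few epochs with the default values (editor-added, not part of the commit):

```python
# Quick check of the linear Gumbel-temperature annealing used above.
init_temp, min_temp, max_epochs = 1.0, 0.33, 10

for epoch in range(3):
    temp = (1 - epoch / max_epochs) * (init_temp - min_temp) + min_temp
    temp = max(temp, min_temp)
    print(epoch, round(temp, 3))  # 0 -> 1.0, 1 -> 0.933, 2 -> 0.866
```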
@@ -94,11 +94,11 @@ class ReinforceController(nn.Module):
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

-   def resample(self):
+   def resample(self, return_prob=False):
        self._initialize()
        result = dict()
        for field in self.fields:
-           result[field.name] = self._sample_single(field)
+           result[field.name] = self._sample_single(field, return_prob=return_prob)
        return result

    def _initialize(self):
@@ -116,7 +116,7 @@ class ReinforceController(nn.Module):
    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

-   def _sample_single(self, field):
+   def _sample_single(self, field, return_prob):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
@@ -124,10 +124,12 @@ class ReinforceController(nn.Module):
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
+           sampled_dist = F.softmax(logit, dim=-1)
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
+           sampled_dist = torch.sigmoid(logit)
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
@@ -147,4 +149,7 @@ class ReinforceController(nn.Module):
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]

+       if return_prob:
+           return sampled_dist.flatten().detach().cpu().numpy().tolist()
        return sampled
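The new ``return_prob`` path surfaces the distribution that the controller otherwise only samples from. A standalone sketch of the relationship between the two return values for a single ``choose_one`` field (editor-added illustration, not NNI code):

```python
# Probability list vs. sampled index for one field with three candidate choices.
import torch
import torch.nn.functional as F

logit = torch.tensor([1.0, 2.0, 0.5])           # controller logits for one field
probs = F.softmax(logit, dim=-1)                # what resample(return_prob=True) surfaces
sample = torch.multinomial(probs, 1).view(-1)   # what resample() normally returns

print(probs.tolist())   # e.g. [0.28, 0.57, 0.15] (approximately)
print(sample.item())    # index of the sampled choice
```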
@@ -5,14 +5,14 @@
from __future__ import annotations

import warnings
-from typing import Any
+from typing import Any, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
    PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
@@ -37,6 +37,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

+   This strategy assumes inner evaluator has set
+   `automatic optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`__ to true.
+
    Parameters
    ----------
    {{module_params}}
@@ -73,9 +76,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
        'mixed_op_sampling': MixedOpPathSamplingPolicy
    }

-   def training_step(self, batch, batch_idx):
+   def training_step(self, *args, **kwargs):
        self.resample()
-       return self.model.training_step(batch, batch_idx)
+       return self.model.training_step(*args, **kwargs)
    def export(self) -> dict[str, Any]:
        """
@@ -115,6 +118,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
    {{module_notes}}

+   {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
@@ -133,6 +138,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
        before updating the weights of RL controller.
    ctrl_grad_clip : float
        Gradient clipping value of controller.
+   log_prob_every_n_step : int
+       Log the probability of choices every N steps. Useful for visualization and debugging.
    reward_metric_name : str or None
        The name of the metric which is treated as reward.
        This will be not effective when there's only one metric returned from evaluator.
@@ -141,11 +148,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
        Otherwise it raises an exception indicating multiple metrics are found.
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+       supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+       optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _enas_note.format(
-       module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
+       module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.pytorch.dataloader.ConcatLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )
@@ -162,6 +170,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: float = 20,
                 ctrl_grad_clip: float = 0,
+                log_prob_every_n_step: int = 10,
                 reward_metric_name: str | None = None,
                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)
@@ -181,33 +190,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
+       self.log_prob_every_n_step = log_prob_every_n_step
        self.reward_metric_name = reward_metric_name

    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)
    def training_step(self, batch_packed, batch_idx):
+       # The received batch is a tuple of (data, "train" | "val")
        batch, mode = batch_packed

        if mode == 'train':
            # train model params
            with torch.no_grad():
                self.resample()
-           self.call_weight_optimizers('zero_grad')
            step_output = self.model.training_step(batch, batch_idx)
-           w_step_loss = step_output['loss'] \
-               if isinstance(step_output, dict) else step_output
-           self.manual_backward(w_step_loss)
-           self.call_weight_optimizers('step')
+           w_step_loss = step_output['loss'] if isinstance(step_output, dict) else step_output
+           self.advance_optimization(w_step_loss, batch_idx)
        else:
            # train ENAS agent
-           arc_opt = self.architecture_optimizers()
-           if not isinstance(arc_opt, optim.Optimizer):
-               raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
-           arc_opt.zero_grad()
-           self.resample()
+           # Run a sample to retrieve the reward
+           self.resample()
            step_output = self.model.validation_step(batch, batch_idx)

            # use the default metric of self.model as reward function
@@ -218,11 +223,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if metric_name not in self.trainer.callback_metrics:
                raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                               f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
-                              f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
-                              'and remember to set on_step=True.')
+                              f'Please try to set ``reward_metric_name`` to be one of the keys listed above. '
+                              f'If it is not working use self.log to report metrics with the specified key ``{metric_name}`` '
+                              'in validation_step, and remember to set on_step=True.')
            metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()

+           # Compute the loss and run back propagation
            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
@@ -236,11 +243,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.ctrl_grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
+
+               # Update the controller and zero out its gradients
+               arc_opt = cast(optim.Optimizer, self.architecture_optimizers())
                arc_opt.step()
                arc_opt.zero_grad()

+       self.advance_lr_schedulers(batch_idx)
+
+       if (batch_idx + 1) % self.log_prob_every_n_step == 0:
+           with torch.no_grad():
+               self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
+
        return step_output

+   def on_train_epoch_start(self):
+       # Always zero out the gradients of ENAS controller at the beginning of epochs.
+       arc_opt = self.architecture_optimizers()
+       if not isinstance(arc_opt, optim.Optimizer):
+           raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
+       arc_opt.zero_grad()
+
+       return self.model.on_train_epoch_start()
    def resample(self):
        """Resample the architecture with ENAS controller."""
        sample = self.controller.resample()
@@ -249,6 +274,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
            module.resample(memo=result)
        return result

+   def export_probs(self):
+       """Export the probability from ENAS controller directly."""
+       sample = self.controller.resample(return_prob=True)
+       result = self._interpret_controller_probability_result(sample)
+       for module in self.nas_modules:
+           module.resample(memo=result)
+       return result
+
    def export(self):
        """Run one more inference of ENAS controller."""
        self.controller.eval()
@@ -261,3 +294,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
        for key in list(sample.keys()):
            sample[key] = space_spec[key].values[sample[key]]
        return sample
+
+   def _interpret_controller_probability_result(self, sample: dict[str, list[float]]) -> dict[str, Any]:
+       """Convert ``{label: [prob1, prob2, prob3]}`` to ``{label/choice: prob}``."""
+       space_spec = self.search_space_spec()
+       result = {}
+       for key in list(sample.keys()):
+           if len(space_spec[key].values) != len(sample[key]):
+               raise ValueError(f'Expect {space_spec[key].values} to be of the same length as {sample[key]}')
+           for value, weight in zip(space_spec[key].values, sample[key]):
+               result[f'{key}/{value}'] = weight
+       return result
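The conversion documented in ``_interpret_controller_probability_result`` flattens per-label probability lists into per-choice keys. A standalone before/after sketch (editor-added; the label and operation names are made-up placeholders):

```python
# Illustrative input/output of the {label: [probs]} -> {label/choice: prob} conversion.
probs = {'cell/0_1': [0.2, 0.5, 0.3]}
choices = {'cell/0_1': ['conv3x3', 'conv1x1', 'skip']}

flat = {
    f'{label}/{value}': weight
    for label, weights in probs.items()
    for value, weight in zip(choices[label], weights)
}
print(flat)  # {'cell/0_1/conv3x3': 0.2, 'cell/0_1/conv1x1': 0.5, 'cell/0_1/skip': 0.3}
```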
@@ -168,11 +168,11 @@ def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence
    assert len(items) == len(weights) > 0
    elem = items[0]
-   unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
+   unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
-       raise TypeError(unsupported_msg)
+       raise TypeError(unsupported_msg.format(type(elem), elem))

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
...
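A plausible reason for switching from the f-string to ``str.format`` is that the f-string rendered ``elem`` (possibly a large tensor or array) into a string on every call, even when no error was raised, while the new template is only formatted on the error path. A small illustrative sketch of the deferred formatting (editor-added, not NNI code):

```python
# Deferred error-message formatting: the template is cheap to build,
# and the (potentially expensive) rendering happens only when raising.
import numpy as np

elem = np.zeros((1000, 1000))
template = 'Unsupported element type in weighted sum: {}. Value is: {}'

try:
    raise TypeError(template.format(type(elem), elem))  # formatted only here
except TypeError as exc:
    print(str(exc)[:60])
```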
@@ -56,6 +56,17 @@ class BaseSuperNetModule(nn.Module):
        """
        raise NotImplementedError()

+   def export_probs(self, memo: dict[str, Any]) -> dict[str, Any]:
+       """
+       Export the probability / logits of every choice got chosen.
+
+       Parameters
+       ----------
+       memo : dict[str, Any]
+           Use memo to avoid the same label gets exported multiple times.
+       """
+       raise NotImplementedError()
+
    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """
        Space specification (sample points).
...
@@ -104,6 +104,13 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

+   def export_probs(self, memo):
+       if any(k.startswith(self.label + '/') for k in memo):
+           return {}  # nothing new
+       weights = self._softmax(self._arch_alpha).cpu().tolist()
+       ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
+       return ret
+
    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}
@@ -117,7 +124,8 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+           memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(list(module.named_children()), alpha, softmax, module.label)
@@ -208,6 +216,13 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            chosen = chosen[0]
        return {self.label: chosen}

+   def export_probs(self, memo):
+       if any(k.startswith(self.label + '/') for k in memo):
+           return {}  # nothing new
+       weights = self._softmax(self._arch_alpha).cpu().tolist()
+       ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
+       return ret
+
    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
@@ -225,7 +240,8 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+           alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+           memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
@@ -284,6 +300,7 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
                raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+               memo[name] = alpha
            operation._arch_alpha[name] = alpha

        operation.parameters = functools.partial(self.parameters, module=operation)  # bind self
@@ -321,6 +338,16 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
            result[name] = spec.values[chosen_index]
        return result

+   def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
+       """Export the weight for every leaf value choice."""
+       ret = {}
+       for name, spec in operation.search_space_spec().items():
+           if any(k.startswith(name + '/') for k in memo):
+               continue
+           weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist()  # type: ignore
+           ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+       return ret
+
    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
        if name in operation.mutable_arguments:
            weights: dict[str, torch.Tensor] = {
...@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): ...@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}') raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else: else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3) alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
memo[name] = alpha
self._arch_alpha[name] = alpha self._arch_alpha[name] = alpha
def resample(self, memo): def resample(self, memo):
...@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): ...@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
result[name] = spec.values[chosen_index] result[name] = spec.values[chosen_index]
return result return result
def export_probs(self, memo):
"""Export the weight for every leaf value choice."""
ret = {}
for name, spec in self.search_space_spec().items():
if any(k.startswith(name + '/') for k in memo):
continue
weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
return ret
def search_space_spec(self): def search_space_spec(self):
return self._space_spec return self._space_spec
...@@ -427,6 +465,8 @@ class DifferentiableMixedRepeat(BaseSuperNetModule): ...@@ -427,6 +465,8 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
class DifferentiableMixedCell(PathSamplingCell): class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context. """Implementation of Cell under differentiable context.
Similar to PathSamplingCell, this cell only handles cells of specific kinds (e.g., with loose end).
An architecture parameter is created on each edge of the full-connected graph. An architecture parameter is created on each edge of the full-connected graph.
""" """
...@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell): ...@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell):
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j] op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo: if edge_label in memo:
alpha = memo[edge_label] alpha = memo[edge_label]
if len(alpha) != len(op) + 1:
if len(alpha) != len(op): if len(alpha) != len(op):
raise ValueError( raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: ' f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}' f'{len(alpha)} vs. {len(op)}'
) )
warnings.warn(
f'Architecture parameter size {len(alpha)} is not the same as expected: {len(op) + 1}. '
'This is likely due to the label being shared by a LayerChoice inside the cell and outside.',
UserWarning
)
else: else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3) # +1 to emulate the input choice.
alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
memo[edge_label] = alpha
self._arch_alpha[edge_label] = alpha self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1)) self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
...@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell): ...@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell):
"""Differentiable doesn't need to resample.""" """Differentiable doesn't need to resample."""
return {} return {}
def export_probs(self, memo):
"""When export probability, we follow the structure in arch alpha."""
ret = {}
for name, parameter in self._arch_alpha.items():
if any(k.startswith(name + '/') for k in memo):
continue
weights = self._softmax(parameter).cpu().tolist()
ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
return ret
def export(self, memo): def export(self, memo):
"""Tricky export. """Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135 Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
We don't exclude operations like ``none`` here, because excluding them would effectively define a different search space.
""" """
exported = {} exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors): for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# If label already exists, no need to re-export.
if all(f'{self.label}/op_{i}_{k}' in memo and f'{self.label}/input_{i}_{k}' in memo for k in range(self.num_ops_per_node)):
continue
# Tuple of (weight, input_index, op_name) # Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = [] all_weights: list[tuple[float, int, str]] = []
for j in range(i): for j in range(i):
for k, name in enumerate(self.op_names): for k, name in enumerate(self.op_names):
# The extra (last) weight in alpha, which emulates "input not chosen", is never appended here, so it is automatically skipped in export.
all_weights.append(( all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()), float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name, j, name,
...@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell): ...@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell):
all_weights = [all_weights[k] for k in first_occurrence_index] + \ all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index] [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights) _logger.info('Sorted weights in differentiable cell export (%s cell, node %d): %s', self.label, i, all_weights)
for k in range(self.num_ops_per_node): for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large. # all_weights could be too short in case ``num_ops_per_node`` is too large.
...@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell): ...@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell):
for j in range(i): # for every previous tensors for j in range(i): # for every previous tensors
op_results = torch.stack([op(states[j]) for op in ops[j].values()]) op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) # (-1, 1, 1, 1, 1, ...)
op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
if len(op_weights) == len(op_results) + 1:
# concatenate with a zero operation, indicating this path is not chosen at all.
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0) edge_sum = torch.sum(op_results * op_weights.view(*alpha_shape), 0)
current_state.append(edge_sum) current_state.append(edge_sum)
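To illustrate the padding above: when the architecture parameter carries one extra "path not chosen" weight, the stacked op outputs are padded with a zero tensor so that both have the same first dimension before the weighted sum. A standalone sketch with hypothetical sizes (two candidate ops, plain torch):

import torch

op_results = torch.randn(2, 4)                 # stacked outputs of 2 candidate ops
weights = torch.softmax(torch.randn(3), -1)    # len(op) + 1 architecture weights
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)  # pad the "no path" output
edge_sum = torch.sum(op_results * weights.view(-1, 1), 0)   # weighted sum over the 3 paths, shape (4,)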
......
...@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy: ...@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy:
"""The handler of :meth:`MixedOperation.export`.""" """The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError() raise NotImplementedError()
def export_probs(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
"""The handler of :meth:`MixedOperation.export_probs`."""
raise NotImplementedError()
def forward_argument(self, operation: 'MixedOperation', name: str) -> Any: def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
"""Computing the argument with ``name`` used in operation's forward. """Computing the argument with ``name`` used in operation's forward.
Usually a value, or a distribution of value. Usually a value, or a distribution of value.
...@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule): ...@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule):
"""Delegates to :meth:`MixedOperationSamplingPolicy.resample`.""" """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
return self.sampling_policy.resample(self, memo) return self.sampling_policy.resample(self, memo)
def export_probs(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.export_probs`."""
return self.sampling_policy.export_probs(self, memo)
def export(self, memo): def export(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.export`.""" """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
return self.sampling_policy.export(self, memo) return self.sampling_policy.export(self, memo)
......
...@@ -11,7 +11,7 @@ The support remains limited. Known limitations include: ...@@ -11,7 +11,7 @@ The support remains limited. Known limitations include:
from __future__ import annotations from __future__ import annotations
from typing import cast from typing import Any, Tuple, Union, cast
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput ...@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput'] __all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
class _ArchGradientFunction(torch.autograd.Function): def _detach_tensor(tensor: Any) -> Any:
@staticmethod """Recursively detach all the tensors."""
def forward(ctx, x, binary_gates, run_func, backward_func): if isinstance(tensor, (list, tuple)):
ctx.run_func = run_func return tuple(_detach_tensor(t) for t in tensor)
ctx.backward_func = backward_func elif isinstance(tensor, dict):
return {k: _detach_tensor(v) for k, v in tensor.items()}
elif isinstance(tensor, torch.Tensor):
return tensor.detach()
else:
return tensor
detached_x = x.detach() def _iter_tensors(tensor: Any) -> Any:
detached_x.requires_grad = x.requires_grad """Recursively iterate over all the tensors.
with torch.enable_grad():
output = run_func(detached_x) This is kept for complex outputs (like dicts / lists).
ctx.save_for_backward(detached_x, output) However, complex outputs are not supported by PyTorch backward hooks yet.
return output.data """
if isinstance(tensor, torch.Tensor):
yield tensor
elif isinstance(tensor, (list, tuple)):
for t in tensor:
yield from _iter_tensors(t)
elif isinstance(tensor, dict):
for t in tensor.values():
yield from _iter_tensors(t)
def _pack_as_tuple(tensor: Any) -> tuple:
"""Return a tuple of tensor with only one element if tensor it's not a tuple."""
if isinstance(tensor, (tuple, list)):
return tuple(tensor)
return (tensor,)
def element_product_sum(tensor1: tuple[torch.Tensor, ...], tensor2: tuple[torch.Tensor, ...]) -> torch.Tensor:
"""Compute the sum of all the element-wise product."""
assert len(tensor1) == len(tensor2), 'The number of tensors must be the same.'
# Skip None entries (e.g., gradients that were not computed)
ret = [torch.sum(t1 * t2) for t1, t2 in zip(tensor1, tensor2) if t1 is not None and t2 is not None]
if not ret:
return torch.tensor(0)
if len(ret) == 1:
return ret[0]
return cast(torch.Tensor, sum(ret))
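A quick illustration of these helpers (assuming they are importable from the same module as ``_iter_tensors``, as the tests below do):

import torch
from nni.nas.oneshot.pytorch.supermodule.proxyless import _pack_as_tuple, element_product_sum

t1 = _pack_as_tuple(torch.full((2, 2), 3.0))   # a single tensor becomes a 1-tuple
t2 = _pack_as_tuple(torch.ones(2, 2))
assert element_product_sum(t1, t2).item() == 12.0   # 3 * 1 summed over 4 elements
assert element_product_sum((None,), (torch.ones(2),)).item() == 0   # None entries are skipped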
class ProxylessContext:
def __init__(self, arch_alpha: torch.Tensor, softmax: nn.Module) -> None:
self.arch_alpha = arch_alpha
self.softmax = softmax
# When a layer is called multiple times, the inputs and outputs are saved in order.
# In backward propagation, we assume that they are used in the reversed order.
self.layer_input: list[Any] = []
self.layer_output: list[Any] = []
self.layer_sample_idx: list[int] = []
def clear_context(self) -> None:
self.layer_input = []
self.layer_output = []
self.layer_sample_idx = []
def save_forward_context(self, layer_input: Any, layer_output: Any, layer_sample_idx: int):
self.layer_input.append(_detach_tensor(layer_input))
self.layer_output.append(_detach_tensor(layer_output))
self.layer_sample_idx.append(layer_sample_idx)
def backward_hook(self, module: nn.Module,
grad_input: Union[Tuple[torch.Tensor, ...], torch.Tensor],
grad_output: Union[Tuple[torch.Tensor, ...], torch.Tensor]) -> None:
# binary_grads is the gradient of the binary gates.
# The binary gates form a one-hot tensor: 1 at the sampled index, 0 everywhere else.
# By the chain rule, its gradient is the sum of grad_output times the layer output (of the corresponding path).
binary_grads = torch.zeros_like(self.arch_alpha)
# Retrieve the layer input/output in reverse order.
if not self.layer_input:
raise ValueError('Unexpected backward call. The saved context is empty.')
layer_input = self.layer_input.pop()
layer_output = self.layer_output.pop()
layer_sample_idx = self.layer_sample_idx.pop()
@staticmethod with torch.no_grad():
def backward(ctx, grad_output): # Compute binary grads.
detached_x, output = ctx.saved_tensors for k in range(len(binary_grads)):
if k != layer_sample_idx:
args, kwargs = layer_input
out_k = module.forward_path(k, *args, **kwargs) # type: ignore
else:
out_k = layer_output
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True) # FIXME: One limitation here is that out_k can't be complex objects like dict.
# compute gradients w.r.t. binary_gates # I think it's also a limitation of backward hook.
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data) binary_grads[k] = element_product_sum(
_pack_as_tuple(out_k), # In case out_k is a single tensor
_pack_as_tuple(grad_output)
)
return grad_x[0], binary_grads, None, None # Compute the gradient of the arch_alpha, based on binary_grads.
if self.arch_alpha.grad is None:
self.arch_alpha.grad = torch.zeros_like(self.arch_alpha)
probs = self.softmax(self.arch_alpha)
for i in range(len(self.arch_alpha)):
for j in range(len(self.arch_alpha)):
# Arch alpha's gradients are accumulated for all backwards through this layer.
self.arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
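The nested loop above implements the ProxylessNAS-style gradient estimate d L / d alpha_i = sum_j binary_grads[j] * p_j * ([i == j] - p_i), with p = softmax(alpha). A small standalone check of that identity against its vectorized form (toy numbers, plain torch):

import torch

alpha = torch.tensor([0.5, -0.2, 0.1])
binary_grads = torch.tensor([0.3, -0.1, 0.4])   # <out_k, grad_output> per path
p = torch.softmax(alpha, -1)
grad_alpha = torch.zeros_like(alpha)
for i in range(len(alpha)):
    for j in range(len(alpha)):
        grad_alpha[i] += binary_grads[j] * p[j] * (float(i == j) - p[i])
# equivalent closed form: p * (binary_grads - (binary_grads * p).sum())
assert torch.allclose(grad_alpha, p * (binary_grads - (binary_grads * p).sum()), atol=1e-6)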
class ProxylessMixedLayer(DifferentiableMixedLayer): class ProxylessMixedLayer(DifferentiableMixedLayer):
...@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer): ...@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
It resamples a single-path every time, rather than go through the softmax. It resamples a single-path every time, rather than go through the softmax.
""" """
_arch_parameter_names = ['_arch_alpha', '_binary_gates'] _arch_parameter_names = ['_arch_alpha']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(paths, alpha, softmax, label) super().__init__(paths, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3) # Binary gates would be created here, but they are not, because they are never used in the forward pass.
# self._binary_gates = nn.Parameter(torch.zeros(len(paths)))
# like sampling-based methods, it has a ``_sampled``. # like sampling-based methods, it has a ``_sampled``.
self._sampled: str | None = None self._sampled: str | None = None
self._sample_idx: int | None = None self._sample_idx: int | None = None
def forward(self, *args, **kwargs): # arch_alpha could be shared by multiple layers,
def run_function(ops, active_id, **kwargs): # but binary_gates is owned by the current layer.
def forward(_x): self.ctx = ProxylessContext(alpha, softmax)
return ops[active_id](_x, **kwargs) self.register_full_backward_hook(self.ctx.backward_hook)
return forward
def backward_function(ops, active_id, binary_gates, **kwargs):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(ops)):
if k != active_id:
out_k = ops[k](_x.data, **kwargs)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
x = args[0]
assert self._sampled is not None, 'Need to call resample() before running fprop.' def forward(self, *args, **kwargs):
list_ops = [getattr(self, op) for op in self.op_names] """Forward pass of one single path."""
if self._sample_idx is None:
raise RuntimeError('resample() needs to be called before fprop.')
output = self.forward_path(self._sample_idx, *args, **kwargs)
self.ctx.save_forward_context((args, kwargs), output, self._sample_idx)
return output
return _ArchGradientFunction.apply( def forward_path(self, index, *args, **kwargs):
x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs), return getattr(self, self.op_names[index])(*args, **kwargs)
backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
)
def resample(self, memo): def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo.""" """Sample one path based on alpha if label is not found in memo."""
...@@ -101,66 +174,37 @@ class ProxylessMixedLayer(DifferentiableMixedLayer): ...@@ -101,66 +174,37 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
self._sample_idx = int(torch.multinomial(probs, 1)[0].item()) self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
self._sampled = self.op_names[self._sample_idx] self._sampled = self.op_names[self._sample_idx]
# set binary gates self.ctx.clear_context()
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[self._sample_idx] = 1.0
return {self.label: self._sampled} return {self.label: self._sampled}
def export(self, memo):
"""Chose the argmax if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
probs = self._softmax(self._arch_alpha)
for i in range(len(self._arch_alpha)):
for j in range(len(self._arch_alpha)):
self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessMixedInput(DifferentiableMixedInput): class ProxylessMixedInput(DifferentiableMixedInput):
"""Proxyless version of differentiable input choice. """Proxyless version of differentiable input choice.
See :class:`ProxylessLayerChoice` for implementation details. See :class:`ProxylessMixedLayer` for implementation details.
""" """
_arch_parameter_names = ['_arch_alpha', '_binary_gates'] _arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str): def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(n_candidates, n_chosen, alpha, softmax, label) super().__init__(n_candidates, n_chosen, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
# We only support choosing a particular one here.
# Nevertheless, we rank the scores and export the top-ranked ones in export.
self._sampled: int | None = None self._sampled: int | None = None
self.ctx = ProxylessContext(alpha, softmax)
self.register_full_backward_hook(self.ctx.backward_hook)
def forward(self, inputs): def forward(self, inputs):
def run_function(active_sample): """Choose one single input."""
return lambda x: x[active_sample] if self._sampled is None:
raise RuntimeError('resample() needs to be called before fprop.')
output = self.forward_path(self._sampled, inputs)
self.ctx.save_forward_context(((inputs,), {}), output, self._sampled)
return output
def backward_function(binary_gates): def forward_path(self, index, inputs):
def backward(_x, _output, grad_output): return inputs[index]
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(self.n_candidates):
out_k = _x[k].data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
inputs = torch.stack(inputs, 0)
assert self._sampled is not None, 'Need to call resample() before running fprop.'
return _ArchGradientFunction.apply(
inputs, self._binary_gates, run_function(self._sampled),
backward_function(self._binary_gates)
)
def resample(self, memo): def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo.""" """Sample one path based on alpha if label is not found in memo."""
...@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput): ...@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput):
sample = torch.multinomial(probs, 1)[0].item() sample = torch.multinomial(probs, 1)[0].item()
self._sampled = int(sample) self._sampled = int(sample)
# set binary gates self.ctx.clear_context()
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[cast(int, self._sampled)] = 1.0
return {self.label: self._sampled} return {self.label: self._sampled}
def export(self, memo):
"""Chose the argmax if label isn't found in memo."""
if self.label in memo:
return {} # nothing new to export
return {self.label: torch.argmax(self._arch_alpha).item()}
def finalize_grad(self):
binary_grads = self._binary_gates.grad
assert binary_grads is not None
with torch.no_grad():
if self._arch_alpha.grad is None:
self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
probs = self._softmax(self._arch_alpha)
for i in range(self.n_candidates):
for j in range(self.n_candidates):
self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
...@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule):
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
"""Implementes the path sampling in mixed operation. """Implements the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments. One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices". Each value choice can be further decomposed into "leaf value choices".
...@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule): ...@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule):
@classmethod @classmethod
def mutate(cls, module, name, memo, mutate_kwargs): def mutate(cls, module, name, memo, mutate_kwargs):
"""
Mutate only handles cells of specific configurations (e.g., with loose end).
Fall back to the default mutation if the cell is not handled here.
"""
if isinstance(module, Cell): if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None: if module.op_candidates_factory is not None:
......
...@@ -5,6 +5,7 @@ import pytorch_lightning as pl ...@@ -5,6 +5,7 @@ import pytorch_lightning as pl
import pytest import pytest
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import Dataset, RandomSampler from torch.utils.data import Dataset, RandomSampler
import nni import nni
...@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit ...@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.oneshot.pytorch import DartsLightningModule
from nni.retiarii.strategy import BaseStrategy from nni.retiarii.strategy import BaseStrategy
from pytorch_lightning import LightningModule, Trainer
from .test_oneshot_utils import RandomDataset
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs') pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
...@@ -338,17 +343,49 @@ def test_gumbel_darts(): ...@@ -338,17 +343,49 @@ def test_gumbel_darts():
_test_strategy(strategy.GumbelDARTS()) _test_strategy(strategy.GumbelDARTS())
if __name__ == '__main__': def test_optimizer_lr_scheduler():
parser = argparse.ArgumentParser() learning_rates = []
parser.add_argument('--exp', type=str, default='all', metavar='E',
help='experiment to run, default = all')
args = parser.parse_args()
if args.exp == 'all': class CustomLightningModule(LightningModule):
test_darts() def __init__(self):
test_proxyless() super().__init__()
test_enas() self.layer1 = nn.Linear(32, 2)
test_random() self.layer2 = LayerChoice([nn.Linear(2, 2), nn.Linear(2, 2, bias=False)])
test_gumbel_darts()
else: def forward(self, x):
globals()[f'test_{args.exp}']() return self.layer2(self.layer1(x))
def configure_optimizers(self):
opt1 = torch.optim.SGD(self.layer1.parameters(), lr=0.1)
opt2 = torch.optim.Adam(self.layer2.parameters(), lr=0.2)
return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('train_loss', loss)
return {'loss': loss}
def on_train_epoch_start(self) -> None:
learning_rates.append(self.optimizers()[0].param_groups[0]['lr'])
def validation_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('test_loss', loss)
train_data = RandomDataset(32, 32)
valid_data = RandomDataset(32, 16)
model = CustomLightningModule()
darts_module = DartsLightningModule(model, gradient_clip_val=5)
trainer = Trainer(max_epochs=10)
trainer.fit(
darts_module,
dict(train=DataLoader(train_data, batch_size=8), val=DataLoader(valid_data, batch_size=8))
)
assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
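For context on this assertion: with StepLR(step_size=2, gamma=0.1) and an initial learning rate of 0.1, the values recorded by on_train_epoch_start over 10 epochs are expected to be:

expected = [0.1 * (0.1 ** (epoch // 2)) for epoch in range(10)]
# -> [0.1, 0.1, 0.01, 0.01, 0.001, 0.001, 1e-04, 1e-04, 1e-05, 1e-05]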
import torch
import torch.nn as nn
from nni.nas.hub.pytorch.nasbench201 import OPS_WITH_STRIDE
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput, _iter_tensors
def test_proxyless_bp():
op = ProxylessMixedLayer(
[(name, value(3, 3, 1)) for name, value in OPS_WITH_STRIDE.items()],
nn.Parameter(torch.randn(len(OPS_WITH_STRIDE))),
nn.Softmax(-1), 'proxyless'
)
optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
for _ in range(10):
x = torch.randn(1, 3, 9, 9).requires_grad_()
op.resample({})
y = op(x).sum()
optimizer.zero_grad()
y.backward()
assert op._arch_alpha.grad.abs().sum().item() != 0
def test_proxyless_input():
inp = ProxylessMixedInput(6, 2, nn.Parameter(torch.zeros(6)), nn.Softmax(-1), 'proxyless')
optimizer = torch.optim.SGD(inp.parameters(arch=True), 0.1)
for _ in range(10):
x = [torch.randn(1, 3, 9, 9).requires_grad_() for _ in range(6)]
inp.resample({})
y = inp(x).sum()
optimizer.zero_grad()
y.backward()
def test_iter_tensors():
a = (torch.zeros(3, 1), {'a': torch.zeros(5, 1), 'b': torch.zeros(6, 1)}, [torch.zeros(7, 1)])
ret = []
for x in _iter_tensors(a):
ret.append(x.shape[0])
assert ret == [3, 5, 6, 7]
class MultiInputLayer(nn.Module):
def __init__(self, d):
super().__init__()
self.d = d
def forward(self, q, k, v=None, mask=None):
return q + self.d, 2 * k - 2 * self.d, v, mask
def test_proxyless_multi_input():
op = ProxylessMixedLayer(
[
('a', MultiInputLayer(1)),
('b', MultiInputLayer(3))
],
nn.Parameter(torch.randn(2)),
nn.Softmax(-1), 'proxyless'
)
optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
for retry in range(10):
q = torch.randn(1, 3, 9, 9).requires_grad_()
k = torch.randn(1, 3, 9, 8).requires_grad_()
v = None if retry < 5 else torch.randn(1, 3, 9, 7).requires_grad_()
mask = None if retry % 5 < 2 else torch.randn(1, 3, 9, 6).requires_grad_()
op.resample({})
y = op(q, k, v, mask=mask)
y = y[0].sum() + y[1].sum()
optimizer.zero_grad()
y.backward()
assert op._arch_alpha.grad.abs().sum().item() != 0, op._arch_alpha.grad
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention from nni.retiarii.nn.pytorch import ValueChoice, LayerChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import ( from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax, MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
...@@ -144,6 +144,16 @@ def test_differentiable_valuechoice(): ...@@ -144,6 +144,16 @@ def test_differentiable_valuechoice():
assert set(conv.export({}).keys()) == {'123', '456'} assert set(conv.export({}).keys()) == {'123', '456'}
def test_differentiable_layerchoice_dedup():
layerchoice1 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
layerchoice2 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
memo = {}
DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
assert len(memo) == 1 and 'a' in memo
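The point of writing alpha back into ``memo`` is that two choices with the same label end up sharing one architecture parameter. A hedged sketch extending the test above (assuming ``mutate`` returns the constructed supernet module, as the return statement earlier in this diff shows):

from nni.retiarii.nn.pytorch import LayerChoice, Conv2d
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedLayer

memo = {}
m1 = DifferentiableMixedLayer.mutate(LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a'), 'x', memo, {})
m2 = DifferentiableMixedLayer.mutate(LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a'), 'x', memo, {})
assert m1._arch_alpha is m2._arch_alpha   # the same Parameter object is shared through memo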
def _mixed_operation_sampling_sanity_check(operation, memo, *input): def _mixed_operation_sampling_sanity_check(operation, memo, *input):
for native_op in NATIVE_MIXED_OPERATIONS: for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation): if native_op.bound_type == type(operation):
...@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input): ...@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy}) mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
break break
return mutate_op(*input) mutate_op(*input)
mutate_op.export({})
mutate_op.export_probs({})
def test_mixed_linear(): def test_mixed_linear():
...@@ -319,6 +331,9 @@ def test_differentiable_layer_input(): ...@@ -319,6 +331,9 @@ def test_differentiable_layer_input():
op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee') op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op(torch.randn(4, 2)).size(-1) == 3 assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b'] assert op.export({})['eee'] in ['a', 'b']
probs = op.export_probs({})
assert len(probs) == 2
assert abs(probs['eee/a'] + probs['eee/b'] - 1) < 1e-4
assert len(list(op.parameters())) == 3 assert len(list(op.parameters())) == 3
with pytest.raises(ValueError): with pytest.raises(ValueError):
...@@ -328,6 +343,8 @@ def test_differentiable_layer_input(): ...@@ -328,6 +343,8 @@ def test_differentiable_layer_input():
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2 assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2 assert len(input.export({})['ddd']) == 2
assert len(input.export_probs({})) == 5
assert 'ddd/3' in input.export_probs({})
def test_proxyless_layer_input(): def test_proxyless_layer_input():
...@@ -341,7 +358,8 @@ def test_proxyless_layer_input(): ...@@ -341,7 +358,8 @@ def test_proxyless_layer_input():
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input.resample({})['ddd'] in list(range(5)) assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2]) assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5)) exported = input.export({})['ddd']
assert len(exported) == 2 and all(e in list(range(5)) for e in exported)
def test_pathsampling_repeat(): def test_pathsampling_repeat():
...@@ -373,6 +391,7 @@ def test_differentiable_repeat(): ...@@ -373,6 +391,7 @@ def test_differentiable_repeat():
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16]) assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({}) sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1] assert 'ccc' in sample and sample['ccc'] in [0, 1]
assert sorted(op.export_probs({}).keys()) == ['ccc/0', 'ccc/1']
class TupleModule(nn.Module): class TupleModule(nn.Module):
def __init__(self, num): def __init__(self, num):
...@@ -452,11 +471,16 @@ def test_differentiable_cell(): ...@@ -452,11 +471,16 @@ def test_differentiable_cell():
result.update(module.export(memo=result)) result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2 assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result_prob = {}
for module in nas_modules:
result_prob.update(module.export_probs(memo=result_prob))
ctrl_params = [] ctrl_params = []
for m in nas_modules: for m in nas_modules:
ctrl_params += list(m.parameters(arch=True)) ctrl_params += list(m.parameters(arch=True))
if cell_cls in [CellLooseEnd, CellOpFactory]: if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2 assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert len(result_prob) == len(ctrl_params) * 2 # len(op_names) == 2
assert isinstance(model.cell, DifferentiableMixedCell) assert isinstance(model.cell, DifferentiableMixedCell)
else: else:
assert not isinstance(model.cell, DifferentiableMixedCell) assert not isinstance(model.cell, DifferentiableMixedCell)
......