Unverified Commit f77db747 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

Enhancement of one-shot NAS (v2.9) (#5049)

parent 125ec21f
......@@ -70,7 +70,7 @@ class NasBench201Cell(nn.Module):
inp = in_features if j == 0 else out_features
op_choices = OrderedDict([(key, cls(inp, out_features))
for key, cls in op_candidates.items()])
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
node_ops.append(LayerChoice(op_choices, label=f'{self._label}/{j}_{tid}'))
self.layers.append(node_ops)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
......
......@@ -179,7 +179,7 @@ class NasBench201(nn.Module):
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
ops: Dict[str, Callable[[int, int], nn.Module]] = {
prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
prim: self._make_op_factory(prim) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
......@@ -192,6 +192,9 @@ class NasBench201(nn.Module):
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)
def _make_op_factory(self, prim):
    """Return an op factory with the primitive name *prim* bound eagerly.

    Building the factory inside a dedicated method (instead of a lambda
    created in a loop) sidesteps Python's late-binding closure pitfall:
    each returned callable captures its own ``prim``.
    """
    def make_op(C_in, C_out):
        return OPS_WITH_STRIDE[prim](C_in, C_out, 1)
    return make_op
def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
......
......@@ -5,23 +5,21 @@ from __future__ import annotations
import warnings
from itertools import chain
from typing import Callable, Any, Dict, Union, Tuple, List, cast
from typing import Callable, Any, Dict, Union, Tuple, cast
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler
import nni.nas.nn.pytorch as nas_nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import is_traceable
from nni.nas.nn.pytorch.choice import ValueChoiceX
from nni.typehint import Literal
from .supermodule.base import BaseSuperNetModule
__all__ = [
'MANUAL_OPTIMIZATION_NOTE',
'MutationHook',
'BaseSuperNetModule',
'BaseOneShotLightningModule',
......@@ -30,6 +28,22 @@ __all__ = [
]
MANUAL_OPTIMIZATION_NOTE = """
.. warning::
The strategy, under the hood, creates a Lightning module that wraps the Lightning module defined in evaluator,
and enables `Manual optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`_,
although we assume **the inner evaluator has enabled automatic optimization**.
We call the optimizers and schedulers configured in evaluator, following the definition in Lightning at best effort,
but we make no guarantee that the behaviors are exactly same as automatic optimization.
We call :meth:`~BaseSuperNetModule.advance_optimization` and :meth:`~BaseSuperNetModule.advance_lr_schedulers`
to invoke the optimizers and schedulers configured in evaluators.
Moreover, some advanced features like gradient clipping will not be supported.
If you encounter any issues, please contact us by `creating an issue <https://github.com/microsoft/nni/issues>`_.
"""
MutationHook = Callable[[nn.Module, str, Dict[str, Any], Dict[str, Any]], Union[nn.Module, bool, Tuple[nn.Module, bool]]]
......@@ -122,7 +136,7 @@ def no_default_hook(module: nn.Module, name: str, memo: dict[str, Any], mutate_k
nas_nn.LayerChoice,
nas_nn.InputChoice,
nas_nn.Repeat,
# nas_nn.NasBench101Cell, # FIXME: nasbench101 is moved to hub, can't check any more.
# nas_nn.NasBench101Cell,
# nas_nn.ValueChoice, # could be false positive
# nas_nn.Cell, # later
# nas_nn.NasBench201Cell, # forward = supernet
......@@ -156,8 +170,8 @@ class BaseOneShotLightningModule(pl.LightningModule):
Extra mutation hooks to support customized mutation on primitives other than built-ins.
Mutation hooks are callable that inputs an Module and returns a
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.nas.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`.
They are invoked in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.traverse_and_mutate_submodules`, on each submodules.
For each submodule, the hook list are invoked subsequently,
the later hooks can see the result from previous hooks.
The modules that are processed by ``mutation_hooks`` will be replaced by the returned module,
......@@ -177,21 +191,21 @@ class BaseOneShotLightningModule(pl.LightningModule):
The returned arguments can be also one of the three kinds:
1. tuple of: :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
1. tuple of: :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None, and boolean,
2. boolean,
3. :class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
3. :class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule` or None.
The boolean value ``suppress`` indicates whether the following hooks should be called.
When it's true, it suppresses the subsequent hooks, and they will never be invoked.
Without boolean value specified, it's assumed to be false.
If a none value appears on the place of
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
it means the hook suggests to
keep the module unchanged, and nothing will happen.
An example of mutation hook is given in :func:`~nni.nas.oneshot.pytorch.base_lightning.no_default_hook`.
An example of mutation hook is given in :func:`~nni.retiarii.oneshot.pytorch.base_lightning.no_default_hook`.
However it's recommended to implement mutation hooks by deriving
:class:`~nni.nas.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
:class:`~nni.retiarii.oneshot.pytorch.supermodule.base.BaseSuperNetModule`,
and add its classmethod ``mutate`` to this list.
"""
......@@ -295,236 +309,232 @@ class BaseOneShotLightningModule(pl.LightningModule):
result.update(module.export(memo=result))
return result
def forward(self, x):
return self.model(x)
def training_step(self, batch, batch_idx):
"""This is the implementation of what happens in training loops of one-shot algos.
It usually calls ``self.model.training_step`` which implements the real training recipe of the users' model.
def export_probs(self) -> dict[str, Any]:
"""
return self.model.training_step(batch, batch_idx)
Export the probability of every choice in the search space got chosen.
def configure_optimizers(self):
"""
Combine architecture optimizers and user's model optimizers.
You can overwrite :meth:`configure_architecture_optimizers` if architecture optimizers are needed in your NAS algorithm.
.. note:: If such method of some modules is not implemented, they will be simply ignored.
For now :attr:`model` is tested against evaluators in :mod:`nni.nas.evaluator.pytorch.lightning`
and it only returns 1 optimizer.
But for extendibility, codes for other return value types are also implemented.
Returns
-------
dict
In most cases, keys are names of ``nas_modules`` suffixed with ``/`` and choice name.
Values are the probability / logits depending on the implementation.
"""
# pylint: disable=assignment-from-none
arc_optimizers = self.configure_architecture_optimizers()
if arc_optimizers is None:
return self.model.configure_optimizers()
if isinstance(arc_optimizers, optim.Optimizer):
arc_optimizers = [arc_optimizers]
self.arc_optim_count = len(arc_optimizers)
result = {}
for module in self.nas_modules:
try:
result.update(module.export_probs(memo=result))
except NotImplementedError:
warnings.warn(
'Some super-modules you have used did not implement export_probs. You might find some logs are missing.',
UserWarning
)
return result
# FIXME: this part uses non-official lightning API.
# The return values ``frequency`` and ``monitor`` are ignored because lightning requires
# ``len(optimizers) == len(frequency)``, and gradient backward is handled manually.
# For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
try:
# above v1.6
from pytorch_lightning.core.optimizer import ( # pylint: disable=import-error
_configure_optimizers, # type: ignore
_configure_schedulers_automatic_opt, # type: ignore
_configure_schedulers_manual_opt # type: ignore
)
w_optimizers, lr_schedulers, self.frequencies, monitor = \
_configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = (
_configure_schedulers_automatic_opt(lr_schedulers, monitor)
if self.automatic_optimization
else _configure_schedulers_manual_opt(lr_schedulers)
)
except ImportError:
# under v1.5
w_optimizers, lr_schedulers, self.frequencies, monitor = \
self.trainer._configure_optimizers(self.model.configure_optimizers()) # type: ignore
lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization) # type: ignore
if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers): # type: ignore
raise Exception(
"Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
)
# variables used to handle optimizer frequency
self.cur_optimizer_step = 0
self.cur_optimizer_index = 0
return arc_optimizers + w_optimizers, lr_schedulers
def forward(self, x):
    # Delegate the forward pass to the wrapped (mutated) inner model.
    return self.model(x)
def on_train_start(self):
return self.model.on_train_start()
def configure_optimizers(self) -> Any:
"""
Transparently configure optimizers for the inner model,
unless one-shot algorithm has its own optimizer (via :meth:`configure_architecture_optimizers`),
in which case, the optimizer will be appended to the list.
def on_train_end(self):
return self.model.on_train_end()
The return value is still one of the 6 types defined in PyTorch-Lightning.
"""
arch_optimizers = self.configure_architecture_optimizers() or []
if not arch_optimizers: # no architecture optimizer available
return self.model.configure_optimizers()
def on_fit_start(self):
if isinstance(arch_optimizers, optim.Optimizer):
arch_optimizers = [arch_optimizers]
# Set the flag to True so that they can differ from other optimizers
for optimizer in arch_optimizers:
optimizer.is_arch_optimizer = True # type: ignore
optim_conf: Any = self.model.configure_optimizers()
# 0. optimizer is none
if optim_conf is None:
return arch_optimizers
# 1. single optimizer
if isinstance(optim_conf, Optimizer):
return [optim_conf] + arch_optimizers
# 2. two lists, optimizer + lr schedulers
if (
isinstance(optim_conf, (list, tuple))
and len(optim_conf) == 2
and isinstance(optim_conf[0], list)
and all(isinstance(opt, Optimizer) for opt in optim_conf[0])
):
return list(optim_conf[0]) + arch_optimizers, optim_conf[1]
# 3. single dictionary
if isinstance(optim_conf, dict):
return [optim_conf] + [{'optimizer': optimizer} for optimizer in arch_optimizers]
# 4. multiple dictionaries
if isinstance(optim_conf, (list, tuple)) and all(isinstance(d, dict) for d in optim_conf):
return list(optim_conf) + [{'optimizer': optimizer} for optimizer in arch_optimizers]
# 5. single list or tuple, multiple optimizer
if isinstance(optim_conf, (list, tuple)) and all(isinstance(opt, Optimizer) for opt in optim_conf):
return list(optim_conf) + arch_optimizers
# unknown configuration
warnings.warn('Unknown optimizer configuration. Architecture optimizers will be ignored. Strategy might fail.', UserWarning)
return optim_conf
def setup(self, stage=None):
# redirect the access to trainer/log to this module
# but note that we might be missing other attributes,
# which could potentially be a problem
self.model.trainer = self.trainer # type: ignore
self.model.log = self.log
return self.model.on_fit_start()
def on_fit_end(self):
return self.model.on_fit_end()
def on_train_batch_start(self, batch, batch_idx, *args, **kwargs):
return self.model.on_train_batch_start(batch, batch_idx, *args, **kwargs)
def on_train_batch_end(self, outputs, batch, batch_idx, *args, **kwargs):
return self.model.on_train_batch_end(outputs, batch, batch_idx, *args, **kwargs)
# Deprecated hooks in pytorch-lightning
def on_epoch_start(self):
return self.model.on_epoch_start()
def on_epoch_end(self):
return self.model.on_epoch_end()
# Reset the optimizer progress (only once at the very beginning)
self._optimizer_progress = 0
def on_train_epoch_start(self):
return self.model.on_train_epoch_start()
return self.model.setup(stage)
def on_train_epoch_end(self):
return self.model.on_train_epoch_end()
def on_before_backward(self, loss):
return self.model.on_before_backward(loss)
def teardown(self, stage=None):
return self.model.teardown(stage)
def on_after_backward(self):
return self.model.on_after_backward()
def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)
def configure_architecture_optimizers(self):
def configure_architecture_optimizers(self) -> list[optim.Optimizer] | optim.Optimizer | None:
"""
Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers here
if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.
Returns
----------
arc_optimizers : list[Optimizer], Optimizer
Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
-------
Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
"""
return None
def call_lr_schedulers(self, batch_index):
def advance_optimization(
self,
loss: Any,
batch_idx: int,
gradient_clip_val: int | float | None = None,
gradient_clip_algorithm: str | None = None
):
"""
Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is turned off
by this class, you can use this function to make schedulers behave as they were automatically handled by the lightning trainer.
Run the optimizer defined in evaluators, when manual optimization is turned on.
Call this method when the model should be optimized.
To keep it as neat as possible, we only implement the basic ``zero_grad``, ``backward``, ``grad_clip``, and ``step`` here.
Many hooks and pre/post-processing are omitted.
Inherit this method if you need more advanced behavior.
The full optimizer step could be found
`here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/optimization/optimizer_loop.py>`__.
We only implement part of the optimizer loop here.
Parameters
----------
batch_idx : int
batch index
batch_idx: int
The current batch index.
"""
def apply(lr_scheduler):
# single scheduler is called every epoch
if isinstance(lr_scheduler, _LRScheduler):
if self.trainer.is_last_batch:
lr_scheduler.step()
# lr_scheduler_config is called as configured
elif isinstance(lr_scheduler, dict):
interval = lr_scheduler['interval']
frequency = lr_scheduler['frequency']
if (
interval == 'step' and
batch_index % frequency == 0
) or \
(
interval == 'epoch' and
self.trainer.is_last_batch and
(self.trainer.current_epoch + 1) % frequency == 0
):
lr_scheduler['scheduler'].step()
if self.automatic_optimization:
raise ValueError('This method should not be used when automatic optimization is turned on.')
if self.trainer.optimizer_frequencies:
warnings.warn('optimizer_frequencies is not supported in NAS. It will be ignored.', UserWarning)
lr_schedulers = self.lr_schedulers()
# Filter out optimizers for architecture parameters
optimizers = [opt for opt in self.trainer.optimizers if not getattr(opt, 'is_arch_optimizer', False)]
if isinstance(lr_schedulers, list):
for lr_scheduler in lr_schedulers:
apply(lr_scheduler)
else:
apply(lr_schedulers)
opt_idx = self._optimizer_progress % len(optimizers)
optimizer = optimizers[opt_idx]
def call_weight_optimizers(self, method: Literal['step', 'zero_grad']):
# There should be many before/after hooks called here, but they are omitted in this implementation.
# 1. zero gradient
self.model.optimizer_zero_grad(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
# 2. backward
self.manual_backward(loss)
# 3. grad clip
self.model.configure_gradient_clipping(optimizer, opt_idx, gradient_clip_val, gradient_clip_algorithm)
# 4. optimizer step
self.model.optimizer_step(self.trainer.current_epoch, batch_idx, optimizer, opt_idx)
self._optimizer_progress += 1
def advance_lr_schedulers(self, batch_idx: int):
"""
Function that imitates lightning trainer's behavior of calling user's optimizers. Since auto_optimization is turned off by this
class, you can use this function to make user optimizers behave as they were automatically handled by the lightning trainer.
Advance the learning rates, when manual optimization is turned on.
Parameters
----------
method : str
Method to call. Only ``step`` and ``zero_grad`` are supported now.
The full implementation is
`here <https://github.com/Lightning-AI/lightning/blob/0e531283/src/pytorch_lightning/loops/epoch/training_epoch_loop.py>`__.
We only include a partial implementation here.
Advanced features like Reduce-lr-on-plateau are not supported.
"""
def apply_method(optimizer, method):
if method == 'step':
optimizer.step()
elif method == 'zero_grad':
optimizer.zero_grad()
optimizers = self.weight_optimizers()
if optimizers is None:
return
assert isinstance(optimizers, list), 'Did you forget to set use_pl_optimizers to true?'
if len(self.frequencies) > 0:
self.cur_optimizer_step += 1
if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
self.cur_optimizer_step = 0
self.cur_optimizer_index = self.cur_optimizer_index + 1 \
if self.cur_optimizer_index + 1 < len(optimizers) \
else 0
apply_method(optimizers[self.cur_optimizer_index], method)
else:
for optimizer in optimizers:
apply_method(optimizer, method)
if self.automatic_optimization:
raise ValueError('This method should not be used when automatic optimization is turned on.')
self._advance_lr_schedulers_impl(batch_idx, 'step')
if self.trainer.is_last_batch:
self._advance_lr_schedulers_impl(batch_idx, 'epoch')
def _advance_lr_schedulers_impl(self, batch_idx: int, interval: str):
current_idx = batch_idx if interval == 'step' else self.trainer.current_epoch
current_idx += 1 # account for both batch and epoch starts from 0
try:
# lightning >= 1.6
for config in self.trainer.lr_scheduler_configs:
scheduler, opt_idx = config.scheduler, config.opt_idx
if config.reduce_on_plateau:
warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
if config.interval == interval and current_idx % config.frequency == 0:
self.model.lr_scheduler_step(cast(Any, scheduler), cast(int, opt_idx), None)
except AttributeError:
# lightning < 1.6
for lr_scheduler in self.trainer.lr_schedulers:
if lr_scheduler['reduce_on_plateau']:
warnings.warn('Reduce-lr-on-plateau is not supported in NAS. It will be ignored.', UserWarning)
if lr_scheduler['interval'] == interval and current_idx % lr_scheduler['frequency']:
lr_scheduler['scheduler'].step()
def architecture_optimizers(self) -> list[Optimizer] | Optimizer | None:
"""
Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in :meth:`training_step`.
Returns
----------
opts : list[Optimizer], Optimizer, None
Architecture optimizers defined in :meth:`configure_architecture_optimizers`. This will be None if there is no
architecture optimizers.
Get the optimizers configured in :meth:`configure_architecture_optimizers`.
"""
opts = self.optimizers()
if isinstance(opts, list):
# pylint: disable=unsubscriptable-object
arc_opts = opts[:self.arc_optim_count]
if len(arc_opts) == 1:
return cast(Optimizer, arc_opts[0])
return cast(List[Optimizer], arc_opts)
# If there is only 1 optimizer and it is the architecture optimizer
if self.arc_optim_count == 1:
return cast(Union[List[Optimizer], Optimizer], opts)
return None
optimizers = [opt for opt in self.trainer.optimizers if getattr(opt, 'is_arch_optimizer', False)]
if not optimizers:
return None
if len(optimizers) == 1:
return optimizers[0]
return optimizers
def weight_optimizers(self) -> list[Optimizer] | Optimizer | None:
"""
Get user optimizers from all optimizers. Use this to get user optimizers in :meth:`training_step`.
# The following methods redirects the callbacks to inner module.
# It's not the complete list though.
# More methods can be added if needed.
Returns
----------
opts : list[Optimizer], Optimizer, None
Optimizers defined by user's model. This will be None if there is no user optimizers.
"""
# Since use_pl_optimizer is set true (by default) here.
# opts always return a list
opts = self.optimizers()
if isinstance(opts, list):
# pylint: disable=unsubscriptable-object
return cast(List[Optimizer], opts[self.arc_optim_count:])
# FIXME: this case is actually not correctly handled
# If there is only 1 optimizer and no architecture optimizer
if self.arc_optim_count == 0:
return cast(Union[List[Optimizer], Optimizer], opts)
return None
def on_train_start(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_train_start()
def on_train_end(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_train_end()
def on_fit_start(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_fit_start()
def on_fit_end(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_fit_end()
def on_train_batch_start(self, batch, batch_idx, *args, **kwargs):
    # Redirect the lightning hook to the inner evaluator module,
    # passing through any extra positional/keyword arguments.
    return self.model.on_train_batch_start(batch, batch_idx, *args, **kwargs)
def on_train_batch_end(self, outputs, batch, batch_idx, *args, **kwargs):
    # Redirect the lightning hook to the inner evaluator module,
    # passing through any extra positional/keyword arguments.
    return self.model.on_train_batch_end(outputs, batch, batch_idx, *args, **kwargs)
def on_train_epoch_start(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_train_epoch_start()
def on_train_epoch_end(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_train_epoch_end()
def on_before_backward(self, loss):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_before_backward(loss)
def on_after_backward(self):
    # Redirect the lightning hook to the inner evaluator module.
    return self.model.on_after_backward()
......@@ -9,7 +9,7 @@ import pytorch_lightning as pl
import torch
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .base_lightning import BaseOneShotLightningModule, MANUAL_OPTIMIZATION_NOTE, MutationHook, no_default_hook
from .supermodule.differentiable import (
DifferentiableMixedLayer, DifferentiableMixedInput,
MixedOpDifferentiablePolicy, GumbelSoftmax,
......@@ -28,6 +28,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
DARTS repeats iterations, where each iteration consists of 2 training phases.
The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
In both phases, ``training_step`` of the Lightning evaluator will be used.
The current implementation corresponds to DARTS (1st order) in paper.
Second order (unrolled 2nd-order derivatives) is not supported yet.
......@@ -49,15 +50,20 @@ class DartsLightningModule(BaseOneShotLightningModule):
{{module_notes}}
{optimization_note}
Parameters
----------
{{module_params}}
{base_params}
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
gradient_clip_val : float
Clip gradients before optimizing models at each step. Default: None
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
optimization_note=MANUAL_OPTIMIZATION_NOTE
)
__doc__ = _darts_note.format(
......@@ -85,8 +91,10 @@ class DartsLightningModule(BaseOneShotLightningModule):
def __init__(self, inner_module: pl.LightningModule,
mutation_hooks: list[MutationHook] | None = None,
arc_learning_rate: float = 3.0E-4):
arc_learning_rate: float = 3.0E-4,
gradient_clip_val: float | None = None):
self.arc_learning_rate = arc_learning_rate
self.gradient_clip_val = gradient_clip_val
super().__init__(inner_module, mutation_hooks=mutation_hooks)
def training_step(self, batch, batch_idx):
......@@ -108,33 +116,32 @@ class DartsLightningModule(BaseOneShotLightningModule):
if isinstance(arc_step_loss, dict):
arc_step_loss = arc_step_loss['loss']
self.manual_backward(arc_step_loss)
self.finalize_grad()
arc_optim.step()
# phase 2: model step
self.resample()
self.call_weight_optimizers('zero_grad')
loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
w_step_loss = loss_and_metrics['loss'] \
if isinstance(loss_and_metrics, dict) else loss_and_metrics
self.manual_backward(w_step_loss)
self.call_weight_optimizers('step')
w_step_loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics
self.advance_optimization(w_step_loss, batch_idx, self.gradient_clip_val)
self.call_lr_schedulers(batch_idx)
# Update learning rates
self.advance_lr_schedulers(batch_idx)
return loss_and_metrics
self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
def finalize_grad(self):
# Note: This hook is currently kept for Proxyless NAS.
pass
return loss_and_metrics
def configure_architecture_optimizers(self):
# The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
ctrl_params = []
for m in self.nas_modules:
ctrl_params += list(m.parameters(arch=True)) # type: ignore
ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
weight_decay=1.0E-3)
# Follow the hyper-parameters used in
# https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/architect.py#L17
params = list(set(ctrl_params))
if not params:
raise ValueError('No architecture parameters found. Nothing to search.')
ctrl_optim = torch.optim.Adam(params, 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
return ctrl_optim
......@@ -153,13 +160,20 @@ class ProxylessLightningModule(DartsLightningModule):
{{module_notes}}
{optimization_note}
Parameters
----------
{{module_params}}
{base_params}
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
""".format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
gradient_clip_val : float
Clip gradients before optimizing models at each step. Default: None
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
optimization_note=MANUAL_OPTIMIZATION_NOTE
)
__doc__ = _proxyless_note.format(
module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
......@@ -176,10 +190,6 @@ class ProxylessLightningModule(DartsLightningModule):
# FIXME: no support for mixed operation currently
return hooks
def finalize_grad(self):
for m in self.nas_modules:
m.finalize_grad() # type: ignore
class GumbelDartsLightningModule(DartsLightningModule):
_gumbel_darts_note = """
......@@ -207,6 +217,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
{{module_notes}}
{optimization_note}
Parameters
----------
{{module_params}}
......@@ -216,13 +228,17 @@ class GumbelDartsLightningModule(DartsLightningModule):
use_temp_anneal : bool
If true, a linear annealing will be applied to ``gumbel_temperature``.
Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
Default is false.
min_temp : float
The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
arc_learning_rate : float
Learning rate for architecture optimizer. Default: 3.0e-4
gradient_clip_val : float
Clip gradients before optimizing models at each step. Default: None
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
optimization_note=MANUAL_OPTIMIZATION_NOTE
)
def mutate_kwargs(self):
......@@ -235,22 +251,25 @@ class GumbelDartsLightningModule(DartsLightningModule):
def __init__(self, inner_module,
mutation_hooks: list[MutationHook] | None = None,
arc_learning_rate: float = 3.0e-4,
gradient_clip_val: float | None = None,
gumbel_temperature: float = 1.,
use_temp_anneal: bool = False,
min_temp: float = .33):
super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate, gradient_clip_val=gradient_clip_val)
self.temp = gumbel_temperature
self.init_temp = gumbel_temperature
self.use_temp_anneal = use_temp_anneal
self.min_temp = min_temp
def on_train_epoch_end(self):
def on_train_epoch_start(self):
if self.use_temp_anneal:
self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
self.temp = max(self.temp, self.min_temp)
self.log('gumbel_temperature', self.temp)
for module in self.nas_modules:
if hasattr(module, '_softmax'):
module._softmax.temp = self.temp # type: ignore
if hasattr(module, '_softmax') and isinstance(module, GumbelSoftmax):
module._softmax.tau = self.temp # type: ignore
return self.model.on_train_epoch_end()
return self.model.on_train_epoch_start()
......@@ -94,11 +94,11 @@ class ReinforceController(nn.Module):
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
def resample(self, return_prob=False):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
result[field.name] = self._sample_single(field, return_prob=return_prob)
return result
def _initialize(self):
......@@ -116,7 +116,7 @@ class ReinforceController(nn.Module):
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
def _sample_single(self, field, return_prob):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
......@@ -124,10 +124,12 @@ class ReinforceController(nn.Module):
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled_dist = F.softmax(logit, dim=-1)
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
sampled_dist = torch.sigmoid(logit)
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
......@@ -147,4 +149,7 @@ class ReinforceController(nn.Module):
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
if return_prob:
return sampled_dist.flatten().detach().cpu().numpy().tolist()
return sampled
......@@ -5,14 +5,14 @@
from __future__ import annotations
import warnings
from typing import Any
from typing import Any, cast
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim
from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
......@@ -37,6 +37,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
* :class:`nni.retiarii.nn.pytorch.Cell`.
* :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.
This strategy assumes inner evaluator has set
`automatic optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`__ to true.
Parameters
----------
{{module_params}}
......@@ -73,9 +76,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
'mixed_op_sampling': MixedOpPathSamplingPolicy
}
def training_step(self, batch, batch_idx):
def training_step(self, *args, **kwargs):
self.resample()
return self.model.training_step(batch, batch_idx)
return self.model.training_step(*args, **kwargs)
def export(self) -> dict[str, Any]:
"""
......@@ -115,6 +118,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
{{module_notes}}
{optimization_note}
Parameters
----------
{{module_params}}
......@@ -133,6 +138,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
before updating the weights of RL controller.
ctrl_grad_clip : float
Gradient clipping value of controller.
log_prob_every_n_step : int
Log the probability of choices every N steps. Useful for visualization and debugging.
reward_metric_name : str or None
The name of the metric which is treated as reward.
This will be not effective when there's only one metric returned from evaluator.
......@@ -141,11 +148,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
Otherwise it raises an exception indicating multiple metrics are found.
""".format(
base_params=BaseOneShotLightningModule._mutation_hooks_note,
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
optimization_note=MANUAL_OPTIMIZATION_NOTE
)
__doc__ = _enas_note.format(
module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.pytorch.dataloader.ConcatLoader`.',
module_params=BaseOneShotLightningModule._inner_module_note,
)
......@@ -162,6 +170,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
baseline_decay: float = .999,
ctrl_steps_aggregate: float = 20,
ctrl_grad_clip: float = 0,
log_prob_every_n_step: int = 10,
reward_metric_name: str | None = None,
mutation_hooks: list[MutationHook] | None = None):
super().__init__(inner_module, mutation_hooks)
......@@ -181,33 +190,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.ctrl_grad_clip = ctrl_grad_clip
self.log_prob_every_n_step = log_prob_every_n_step
self.reward_metric_name = reward_metric_name
def configure_architecture_optimizers(self):
    """Create the Adam optimizer (lr = 3.5e-4) for the ENAS controller's parameters."""
    controller_params = self.controller.parameters()
    return optim.Adam(controller_params, lr=3.5e-4)
def training_step(self, batch_packed, batch_idx):
# The received batch is a tuple of (data, "train" | "val")
batch, mode = batch_packed
if mode == 'train':
# train model params
with torch.no_grad():
self.resample()
self.call_weight_optimizers('zero_grad')
step_output = self.model.training_step(batch, batch_idx)
w_step_loss = step_output['loss'] \
if isinstance(step_output, dict) else step_output
self.manual_backward(w_step_loss)
self.call_weight_optimizers('step')
w_step_loss = step_output['loss'] if isinstance(step_output, dict) else step_output
self.advance_optimization(w_step_loss, batch_idx)
else:
# train ENAS agent
arc_opt = self.architecture_optimizers()
if not isinstance(arc_opt, optim.Optimizer):
raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
arc_opt.zero_grad()
self.resample()
# Run a sample to retrieve the reward
self.resample()
step_output = self.model.validation_step(batch, batch_idx)
# use the default metric of self.model as reward function
......@@ -218,11 +223,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
if metric_name not in self.trainer.callback_metrics:
raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
'and remember to set on_step=True.')
f'Please try to set ``reward_metric_name`` to be one of the keys listed above. '
f'If it is not working use self.log to report metrics with the specified key ``{metric_name}`` '
'in validation_step, and remember to set on_step=True.')
metric = self.trainer.callback_metrics[metric_name]
reward: float = metric.item()
# Compute the loss and run back propagation
if self.entropy_weight:
reward = reward + self.entropy_weight * self.controller.sample_entropy.item() # type: ignore
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
......@@ -236,11 +243,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
if self.ctrl_grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
# Update the controller and zero out its gradients
arc_opt = cast(optim.Optimizer, self.architecture_optimizers())
arc_opt.step()
arc_opt.zero_grad()
self.advance_lr_schedulers(batch_idx)
if (batch_idx + 1) % self.log_prob_every_n_step == 0:
with torch.no_grad():
self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
return step_output
def on_train_epoch_start(self):
    """Always zero out the gradients of the ENAS controller at the beginning of epochs."""
    arc_opt = self.architecture_optimizers()
    if isinstance(arc_opt, optim.Optimizer):
        arc_opt.zero_grad()
    else:
        raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
    return self.model.on_train_epoch_start()
def resample(self):
"""Resample the architecture with ENAS controller."""
sample = self.controller.resample()
......@@ -249,6 +274,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
module.resample(memo=result)
return result
def export_probs(self):
    """Export the probability from the ENAS controller directly.

    The controller is sampled in probability mode; the per-choice probabilities
    are flattened and then broadcast to all NAS modules via ``resample``.
    """
    raw_sample = self.controller.resample(return_prob=True)
    result = self._interpret_controller_probability_result(raw_sample)
    for module in self.nas_modules:
        module.resample(memo=result)
    return result
def export(self):
"""Run one more inference of ENAS controller."""
self.controller.eval()
......@@ -261,3 +294,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
for key in list(sample.keys()):
sample[key] = space_spec[key].values[sample[key]]
return sample
def _interpret_controller_probability_result(self, sample: dict[str, list[float]]) -> dict[str, Any]:
"""Convert ``{label: [prob1, prob2, prob3]} to ``{label/choice: prob}``"""
space_spec = self.search_space_spec()
result = {}
for key in list(sample.keys()):
if len(space_spec[key].values) != len(sample[key]):
raise ValueError(f'Expect {space_spec[key].values} to be of the same length as {sample[key]}')
for value, weight in zip(space_spec[key].values, sample[key]):
result[f'{key}/{value}'] = weight
return result
......@@ -168,11 +168,11 @@ def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence
assert len(items) == len(weights) > 0
elem = items[0]
unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'
if isinstance(elem, str):
# Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
raise TypeError(unsupported_msg)
raise TypeError(unsupported_msg.format(type(elem), elem))
try:
if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
......
......@@ -56,6 +56,17 @@ class BaseSuperNetModule(nn.Module):
"""
raise NotImplementedError()
def export_probs(self, memo: dict[str, Any]) -> dict[str, Any]:
    """Export the probability / logits of every candidate choice.

    Subclasses must override this. ``memo`` caches labels that were already
    exported, so a label shared by several modules is only emitted once.
    """
    raise NotImplementedError()
def search_space_spec(self) -> dict[str, ParameterSpec]:
"""
Space specification (sample points).
......
......@@ -104,6 +104,13 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
return {} # nothing new to export
return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def export_probs(self, memo):
    """Export the softmax-normalized weight of every operator candidate.

    Returns an empty dict when this label was already exported into ``memo``.
    """
    prefix = self.label + '/'
    if any(key.startswith(prefix) for key in memo):
        # Another module sharing the same label exported it already.
        return {}
    probabilities = self._softmax(self._arch_alpha).cpu().tolist()
    return {prefix + name: prob for name, prob in zip(self.op_names, probabilities)}
def search_space_spec(self):
return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
True, size=len(self.op_names))}
......@@ -117,7 +124,8 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # this can be reinitialized later
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
memo[module.label] = alpha
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(list(module.named_children()), alpha, softmax, module.label)
......@@ -208,6 +216,13 @@ class DifferentiableMixedInput(BaseSuperNetModule):
chosen = chosen[0]
return {self.label: chosen}
def export_probs(self, memo):
    """Export the softmax weight assigned to each input candidate index."""
    prefix = self.label + '/'
    if any(key.startswith(prefix) for key in memo):
        # Already exported by another module sharing this label.
        return {}
    probabilities = self._softmax(self._arch_alpha).cpu().tolist()
    return {f'{prefix}{index}': prob for index, prob in enumerate(probabilities)}
def search_space_spec(self):
return {
self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
......@@ -225,7 +240,8 @@ class DifferentiableMixedInput(BaseSuperNetModule):
if len(alpha) != size:
raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
else:
alpha = nn.Parameter(torch.randn(size) * 1E-3) # this can be reinitialized later
alpha = nn.Parameter(torch.randn(size) * 1E-3) # the numbers in the parameter can be reinitialized later
memo[module.label] = alpha
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
......@@ -284,6 +300,7 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
memo[name] = alpha
operation._arch_alpha[name] = alpha
operation.parameters = functools.partial(self.parameters, module=operation) # bind self
......@@ -321,6 +338,16 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
result[name] = spec.values[chosen_index]
return result
def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
    """Export the softmax weight for every leaf value choice of ``operation``."""
    result: dict[str, Any] = {}
    for name, spec in operation.search_space_spec().items():
        prefix = name + '/'
        if any(key.startswith(prefix) for key in memo):
            continue  # this label was already exported elsewhere
        weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist()  # type: ignore
        for value, weight in zip(spec.values, weights):
            result[f'{name}/{value}'] = weight
    return result
def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
if name in operation.mutable_arguments:
weights: dict[str, torch.Tensor] = {
......@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
else:
alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
memo[name] = alpha
self._arch_alpha[name] = alpha
def resample(self, memo):
......@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
result[name] = spec.values[chosen_index]
return result
def export_probs(self, memo):
    """Export the softmax weight for every leaf value choice of the repeat."""
    result = {}
    for name, spec in self.search_space_spec().items():
        prefix = name + '/'
        if any(key.startswith(prefix) for key in memo):
            continue  # already exported under this label
        weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
        for value, weight in zip(spec.values, weights):
            result[f'{name}/{value}'] = weight
    return result
def search_space_spec(self):
return self._space_spec
......@@ -427,6 +465,8 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
class DifferentiableMixedCell(PathSamplingCell):
"""Implementation of Cell under differentiable context.
Similar to PathSamplingCell, this cell only handles cells of specific kinds (e.g., with loose end).
An architecture parameter is created on each edge of the full-connected graph.
"""
......@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell):
op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
if edge_label in memo:
alpha = memo[edge_label]
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
if len(alpha) != len(op) + 1:
if len(alpha) != len(op):
raise ValueError(
f'Architecture parameter size of same label {edge_label} conflict: '
f'{len(alpha)} vs. {len(op)}'
)
warnings.warn(
f'Architecture parameter size {len(alpha)} is not same as expected: {len(op) + 1}. '
'This is likely due to the label being shared by a LayerChoice inside the cell and outside.',
UserWarning
)
else:
alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
# +1 to emulate the input choice.
alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
memo[edge_label] = alpha
self._arch_alpha[edge_label] = alpha
self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
......@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell):
"""Differentiable doesn't need to resample."""
return {}
def export_probs(self, memo):
    """When exporting probability, we follow the structure in arch alpha.

    NOTE: the alpha on an edge may carry one extra entry (the emulated input
    choice); ``zip`` against ``op_names`` silently drops that trailing weight.
    """
    result = {}
    for name, parameter in self._arch_alpha.items():
        prefix = name + '/'
        if any(key.startswith(prefix) for key in memo):
            continue  # already exported
        weights = self._softmax(parameter).cpu().tolist()
        for op_name, weight in zip(self.op_names, weights):
            result[f'{name}/{op_name}'] = weight
    return result
def export(self, memo):
"""Tricky export.
Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
"""
exported = {}
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
# If label already exists, no need to re-export.
if all(f'{self.label}/op_{i}_{k}' in memo and f'{self.label}/input_{i}_{k}' in memo for k in range(self.num_ops_per_node)):
continue
# Tuple of (weight, input_index, op_name)
all_weights: list[tuple[float, int, str]] = []
for j in range(i):
for k, name in enumerate(self.op_names):
# The last appended weight is automatically skipped in export.
all_weights.append((
float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
j, name,
......@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell):
all_weights = [all_weights[k] for k in first_occurrence_index] + \
[w for j, w in enumerate(all_weights) if j not in first_occurrence_index]
_logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
_logger.info('Sorted weights in differentiable cell export (%s cell, node %d): %s', self.label, i, all_weights)
for k in range(self.num_ops_per_node):
# all_weights could be too short in case ``num_ops_per_node`` is too large.
......@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell):
for j in range(i): # for every previous tensors
op_results = torch.stack([op(states[j]) for op in ops[j].values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1) # (-1, 1, 1, 1, 1, ...)
op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
if len(op_weights) == len(op_results) + 1:
# concatenate with a zero operation, indicating this path is not chosen at all.
op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
current_state.append(edge_sum)
......
......@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy:
"""The handler of :meth:`MixedOperation.export`."""
raise NotImplementedError()
def export_probs(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
    """The handler of :meth:`MixedOperation.export_probs`. Must be overridden by concrete policies."""
    raise NotImplementedError()
def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
"""Computing the argument with ``name`` used in operation's forward.
Usually a value, or a distribution of value.
......@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule):
"""Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
return self.sampling_policy.resample(self, memo)
def export_probs(self, memo):
    """Delegates to :meth:`MixedOperationSamplingPolicy.export_probs`."""
    policy = self.sampling_policy
    return policy.export_probs(self, memo)
def export(self, memo):
"""Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
return self.sampling_policy.export(self, memo)
......
......@@ -11,7 +11,7 @@ The support remains limited. Known limitations include:
from __future__ import annotations
from typing import cast
from typing import Any, Tuple, Union, cast
import torch
import torch.nn as nn
......@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
class _ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
def _detach_tensor(tensor: Any) -> Any:
"""Recursively detach all the tensors."""
if isinstance(tensor, (list, tuple)):
return tuple(_detach_tensor(t) for t in tensor)
elif isinstance(tensor, dict):
return {k: _detach_tensor(v) for k, v in tensor.items()}
elif isinstance(tensor, torch.Tensor):
return tensor.detach()
else:
return tensor
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
def _iter_tensors(tensor: Any) -> Any:
"""Recursively iterate over all the tensors.
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
This is kept for complex outputs (like dicts / lists).
However, complex outputs are not supported by PyTorch backward hooks yet.
"""
if isinstance(tensor, torch.Tensor):
yield tensor
elif isinstance(tensor, (list, tuple)):
for t in tensor:
yield from _iter_tensors(t)
elif isinstance(tensor, dict):
for t in tensor.values():
yield from _iter_tensors(t)
def _pack_as_tuple(tensor: Any) -> tuple:
"""Return a tuple of tensor with only one element if tensor it's not a tuple."""
if isinstance(tensor, (tuple, list)):
return tuple(tensor)
return (tensor,)
def element_product_sum(tensor1: tuple[torch.Tensor, ...], tensor2: tuple[torch.Tensor, ...]) -> torch.Tensor:
    """Compute the sum of all the element-wise products of paired tensors.

    Pairs in which either side is ``None`` (e.g. missing gradients) are skipped.
    """
    assert len(tensor1) == len(tensor2), 'The number of tensors must be the same.'
    partial_sums = []
    for lhs, rhs in zip(tensor1, tensor2):
        if lhs is not None and rhs is not None:  # skip zero gradients
            partial_sums.append(torch.sum(lhs * rhs))
    if not partial_sums:
        return torch.tensor(0)
    if len(partial_sums) == 1:
        return partial_sums[0]
    return cast(torch.Tensor, sum(partial_sums))
class ProxylessContext:
    """Bookkeeping shared between a Proxyless module and its backward hook.

    It records, per forward call, the layer input / output and the index of the
    sampled path, and consumes them in :meth:`backward_hook` to estimate the
    gradients of the architecture parameters (``arch_alpha``).
    """

    def __init__(self, arch_alpha: torch.Tensor, softmax: nn.Module) -> None:
        self.arch_alpha = arch_alpha
        self.softmax = softmax
        # When a layer is called multiple times, the inputs and outputs are saved in order.
        # In backward propagation, we assume that they are used in the reversed order.
        self.layer_input: list[Any] = []
        self.layer_output: list[Any] = []
        self.layer_sample_idx: list[int] = []

    def clear_context(self) -> None:
        """Drop all saved forward records."""
        self.layer_input = []
        self.layer_output = []
        self.layer_sample_idx = []

    def save_forward_context(self, layer_input: Any, layer_output: Any, layer_sample_idx: int):
        """Detach and store one forward call's input, output and sampled path index."""
        self.layer_input.append(_detach_tensor(layer_input))
        self.layer_output.append(_detach_tensor(layer_output))
        self.layer_sample_idx.append(layer_sample_idx)

    def backward_hook(self, module: nn.Module,
                      grad_input: Union[Tuple[torch.Tensor, ...], torch.Tensor],
                      grad_output: Union[Tuple[torch.Tensor, ...], torch.Tensor]) -> None:
        """Accumulate gradients of ``arch_alpha`` for one backward pass through the layer."""
        # binary_grads is the gradient of binary gates.
        # Binary gates is a one-hot tensor where 1 is on the sampled index, and others are 0.
        # By chain rule, its gradient is grad_output times the layer_output (of the corresponding path).
        binary_grads = torch.zeros_like(self.arch_alpha)

        # Retrieve the layer input/output in reverse order.
        if not self.layer_input:
            raise ValueError('Unexpected backward call. The saved context is empty.')
        layer_input = self.layer_input.pop()
        layer_output = self.layer_output.pop()
        layer_sample_idx = self.layer_sample_idx.pop()

        with torch.no_grad():
            # Compute binary grads: rerun every non-sampled path on the saved input.
            for k in range(len(binary_grads)):
                if k != layer_sample_idx:
                    args, kwargs = layer_input
                    out_k = module.forward_path(k, *args, **kwargs)  # type: ignore
                else:
                    out_k = layer_output
                # FIXME: One limitation here is that out_k can't be complex objects like dict.
                # I think it's also a limitation of backward hook.
                binary_grads[k] = element_product_sum(
                    _pack_as_tuple(out_k),  # In case out_k is a single tensor
                    _pack_as_tuple(grad_output)
                )

            # Compute the gradient of the arch_alpha, based on binary_grads.
            if self.arch_alpha.grad is None:
                self.arch_alpha.grad = torch.zeros_like(self.arch_alpha)
            probs = self.softmax(self.arch_alpha)
            for i in range(len(self.arch_alpha)):
                for j in range(len(self.arch_alpha)):
                    # Arch alpha's gradients are accumulated for all backwards through this layer.
                    self.arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessMixedLayer(DifferentiableMixedLayer):
......@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
It resamples a single-path every time, rather than go through the softmax.
"""
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
_arch_parameter_names = ['_arch_alpha']
def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(paths, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
# Binary gates should be created here, but it's not because it's never used in the forward pass.
# self._binary_gates = nn.Parameter(torch.zeros(len(paths)))
# like sampling-based methods, it has a ``_sampled``.
self._sampled: str | None = None
self._sample_idx: int | None = None
# arch_alpha could be shared by multiple layers,
# but binary_gates is owned by the current layer.
self.ctx = ProxylessContext(alpha, softmax)
self.register_full_backward_hook(self.ctx.backward_hook)
def forward(self, *args, **kwargs):
    """Forward pass of one single (sampled) path.

    Raises
    ------
    RuntimeError
        If :meth:`resample` has not been called before the forward pass.
    """
    if self._sample_idx is None:
        raise RuntimeError('resample() needs to be called before fprop.')
    output = self.forward_path(self._sample_idx, *args, **kwargs)
    # Save the context so the backward hook can re-evaluate the other paths.
    self.ctx.save_forward_context((args, kwargs), output, self._sample_idx)
    return output
def forward_path(self, index, *args, **kwargs):
    """Run the candidate operator at position ``index`` on the given arguments."""
    chosen_name = self.op_names[index]
    return getattr(self, chosen_name)(*args, **kwargs)
def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo."""
......@@ -101,66 +174,37 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
self._sampled = self.op_names[self._sample_idx]
# set binary gates
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[self._sample_idx] = 1.0
self.ctx.clear_context()
return {self.label: self._sampled}
def export(self, memo):
    """Choose the argmax of alpha if label isn't found in memo."""
    if self.label in memo:
        return {}  # nothing new to export
    return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
def finalize_grad(self):
    """Translate the accumulated binary-gate gradients into gradients on alpha."""
    gate_grads = self._binary_gates.grad
    assert gate_grads is not None
    with torch.no_grad():
        alpha = self._arch_alpha
        if alpha.grad is None:
            alpha.grad = torch.zeros_like(alpha.data)
        probs = self._softmax(alpha)
        num_choices = len(alpha)
        for i in range(num_choices):
            for j in range(num_choices):
                alpha.grad[i] += gate_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessMixedInput(DifferentiableMixedInput):
"""Proxyless version of differentiable input choice.
See :class:`ProxylessLayerChoice` for implementation details.
See :class:`ProxylessMixedLayer` for implementation details.
"""
_arch_parameter_names = ['_arch_alpha', '_binary_gates']
def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
super().__init__(n_candidates, n_chosen, alpha, softmax, label)
self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
# We only support choosing a particular one here.
# Nevertheless, we rank the score and export the tops in export.
self._sampled: int | None = None
self.ctx = ProxylessContext(alpha, softmax)
self.register_full_backward_hook(self.ctx.backward_hook)
def forward(self, inputs):
    """Forward the single sampled input candidate.

    Raises
    ------
    RuntimeError
        If :meth:`resample` has not been called before the forward pass.
    """
    if self._sampled is None:
        raise RuntimeError('resample() needs to be called before fprop.')
    output = self.forward_path(self._sampled, inputs)
    # Record the context so the backward hook can differentiate through alpha.
    self.ctx.save_forward_context(((inputs,), {}), output, self._sampled)
    return output
def forward_path(self, index, inputs):
    """Select the candidate tensor at position ``index``."""
    return inputs[index]
def resample(self, memo):
"""Sample one path based on alpha if label is not found in memo."""
......@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput):
sample = torch.multinomial(probs, 1)[0].item()
self._sampled = int(sample)
# set binary gates
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[cast(int, self._sampled)] = 1.0
self.ctx.clear_context()
return {self.label: self._sampled}
def export(self, memo):
    """Choose the argmax of alpha if label isn't found in memo."""
    if self.label in memo:
        return {}  # nothing new to export
    return {self.label: torch.argmax(self._arch_alpha).item()}
def finalize_grad(self):
    """Translate accumulated binary-gate gradients into gradients on alpha (over ``n_candidates``)."""
    gate_grads = self._binary_gates.grad
    assert gate_grads is not None
    with torch.no_grad():
        alpha = self._arch_alpha
        if alpha.grad is None:
            alpha.grad = torch.zeros_like(alpha.data)
        probs = self._softmax(alpha)
        for i in range(self.n_candidates):
            for j in range(self.n_candidates):
                alpha.grad[i] += gate_grads[j] * probs[j] * (int(i == j) - probs[i])
......@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule):
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
"""Implementes the path sampling in mixed operation.
"""Implements the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices".
......@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule):
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
"""
Mutate only handles cells of specific configurations (e.g., with loose end).
Fallback to the default mutate if the cell is not handled here.
"""
if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None:
......
......@@ -5,6 +5,7 @@ import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import Dataset, RandomSampler
import nni
......@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.oneshot.pytorch import DartsLightningModule
from nni.retiarii.strategy import BaseStrategy
from pytorch_lightning import LightningModule, Trainer
from .test_oneshot_utils import RandomDataset
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
......@@ -338,17 +343,49 @@ def test_gumbel_darts():
_test_strategy(strategy.GumbelDARTS())
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--exp', type=str, default='all', metavar='E',
help='experiment to run, default = all')
args = parser.parse_args()
def test_optimizer_lr_scheduler():
    """Check that the one-shot wrapper steps the inner module's weight optimizers
    and LR schedulers: SGD lr 0.1 with StepLR(step_size=2, gamma=0.1) should decay
    over 10 epochs as observed in ``on_train_epoch_start``."""
    learning_rates = []

    class CustomLightningModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer1 = nn.Linear(32, 2)
            self.layer2 = nn.LayerChoice([nn.Linear(2, 2), nn.Linear(2, 2, bias=False)])

        def forward(self, x):
            return self.layer2(self.layer1(x))

        def configure_optimizers(self):
            opt1 = torch.optim.SGD(self.layer1.parameters(), lr=0.1)
            opt2 = torch.optim.Adam(self.layer2.parameters(), lr=0.2)
            return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]

        def training_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log('train_loss', loss)
            return {'loss': loss}

        def on_train_epoch_start(self) -> None:
            # Record the SGD learning rate at the start of every epoch.
            learning_rates.append(self.optimizers()[0].param_groups[0]['lr'])

        def validation_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log('valid_loss', loss)

        def test_step(self, batch, batch_idx):
            loss = self(batch).sum()
            self.log('test_loss', loss)

    train_data = RandomDataset(32, 32)
    valid_data = RandomDataset(32, 16)

    model = CustomLightningModule()
    darts_module = DartsLightningModule(model, gradient_clip_val=5)
    trainer = Trainer(max_epochs=10)
    trainer.fit(
        darts_module,
        dict(train=DataLoader(train_data, batch_size=8), val=DataLoader(valid_data, batch_size=8))
    )

    # lr trajectory: 0.1 at epoch 0, 0.01 at epoch 2, ..., 1e-5 at epoch 9.
    assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
        abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
import torch
import torch.nn as nn
from nni.nas.hub.pytorch.nasbench201 import OPS_WITH_STRIDE
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput, _iter_tensors
def test_proxyless_bp():
    """Gradients should flow to arch alpha through the Proxyless backward hook."""
    op = ProxylessMixedLayer(
        [(name, value(3, 3, 1)) for name, value in OPS_WITH_STRIDE.items()],
        nn.Parameter(torch.randn(len(OPS_WITH_STRIDE))),
        nn.Softmax(-1), 'proxyless'
    )

    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
    for _ in range(10):
        x = torch.randn(1, 3, 9, 9).requires_grad_()
        op.resample({})
        y = op(x).sum()
        optimizer.zero_grad()
        y.backward()
        # A non-zero alpha gradient proves the backward hook actually fired.
        assert op._arch_alpha.grad.abs().sum().item() != 0
def test_proxyless_input():
    """Proxyless input choice: forward and backward over a list of candidate
    inputs must run cleanly for repeated resamples."""
    inp = ProxylessMixedInput(6, 2, nn.Parameter(torch.zeros(6)), nn.Softmax(-1), 'proxyless')
    optimizer = torch.optim.SGD(inp.parameters(arch=True), 0.1)
    for _step in range(10):
        candidates = [torch.randn(1, 3, 9, 9).requires_grad_() for _ in range(6)]
        inp.resample({})
        loss = inp(candidates).sum()
        optimizer.zero_grad()
        loss.backward()
def test_iter_tensors():
    """``_iter_tensors`` flattens nested tuples/dicts/lists of tensors,
    yielding them in traversal order."""
    nested = (torch.zeros(3, 1), {'a': torch.zeros(5, 1), 'b': torch.zeros(6, 1)}, [torch.zeros(7, 1)])
    sizes = [tensor.shape[0] for tensor in _iter_tensors(nested)]
    assert sizes == [3, 5, 6, 7]
class MultiInputLayer(nn.Module):
    """Toy multi-input module: shifts ``q`` up by ``d`` and maps ``k`` to
    ``2*k - 2*d``; ``v`` and ``mask`` are passed through untouched."""

    def __init__(self, d):
        super().__init__()
        self.d = d

    def forward(self, q, k, v=None, mask=None):
        shifted = q + self.d
        scaled = 2 * k - 2 * self.d
        return shifted, scaled, v, mask
def test_proxyless_multi_input():
    """Proxyless mixed layer whose candidates take several (possibly None)
    positional and keyword inputs; arch params must still get gradients."""
    op = ProxylessMixedLayer(
        [('a', MultiInputLayer(1)), ('b', MultiInputLayer(3))],
        nn.Parameter(torch.randn(2)),
        nn.Softmax(-1), 'proxyless'
    )
    optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
    for retry in range(10):
        q = torch.randn(1, 3, 9, 9).requires_grad_()
        k = torch.randn(1, 3, 9, 8).requires_grad_()
        # Exercise the optional arguments on different schedules so every
        # combination of present/absent ``v`` and ``mask`` is covered.
        v = None if retry < 5 else torch.randn(1, 3, 9, 7).requires_grad_()
        mask = None if retry % 5 < 2 else torch.randn(1, 3, 9, 6).requires_grad_()
        op.resample({})
        out = op(q, k, v, mask=mask)
        loss = out[0].sum() + out[1].sum()
        optimizer.zero_grad()
        loss.backward()
        assert op._arch_alpha.grad.abs().sum().item() != 0, op._arch_alpha.grad
......@@ -3,7 +3,7 @@ import pytest
import numpy as np
import torch
import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.nn.pytorch import ValueChoice, LayerChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
......@@ -144,6 +144,16 @@ def test_differentiable_valuechoice():
assert set(conv.export({}).keys()) == {'123', '456'}
def test_differentiable_layerchoice_dedup():
    """Two layer choices sharing a label must be deduplicated: mutating both
    leaves exactly one entry (keyed by the label) in the memo."""
    first = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
    second = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
    memo = {}
    DifferentiableMixedLayer.mutate(first, 'x', memo, {})
    DifferentiableMixedLayer.mutate(second, 'x', memo, {})
    assert len(memo) == 1 and 'a' in memo
def _mixed_operation_sampling_sanity_check(operation, memo, *input):
for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation):
......@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
break
return mutate_op(*input)
mutate_op(*input)
mutate_op.export({})
mutate_op.export_probs({})
def test_mixed_linear():
......@@ -319,6 +331,9 @@ def test_differentiable_layer_input():
op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b']
probs = op.export_probs({})
assert len(probs) == 2
assert abs(probs['eee/a'] + probs['eee/b'] - 1) < 1e-4
assert len(list(op.parameters())) == 3
with pytest.raises(ValueError):
......@@ -328,6 +343,8 @@ def test_differentiable_layer_input():
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2
assert len(input.export_probs({})) == 5
assert 'ddd/3' in input.export_probs({})
def test_proxyless_layer_input():
......@@ -341,7 +358,8 @@ def test_proxyless_layer_input():
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5))
exported = input.export({})['ddd']
assert len(exported) == 2 and all(e in list(range(5)) for e in exported)
def test_pathsampling_repeat():
......@@ -373,6 +391,7 @@ def test_differentiable_repeat():
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1]
assert sorted(op.export_probs({}).keys()) == ['ccc/0', 'ccc/1']
class TupleModule(nn.Module):
def __init__(self, num):
......@@ -452,11 +471,16 @@ def test_differentiable_cell():
result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result_prob = {}
for module in nas_modules:
result_prob.update(module.export_probs(memo=result_prob))
ctrl_params = []
for m in nas_modules:
ctrl_params += list(m.parameters(arch=True))
if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert len(result_prob) == len(ctrl_params) * 2 # len(op_names) == 2
assert isinstance(model.cell, DifferentiableMixedCell)
else:
assert not isinstance(model.cell, DifferentiableMixedCell)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment