Unverified commit f77db747 authored by Yuge Zhang, committed by GitHub

Enhancement of one-shot NAS (v2.9) (#5049)

parent 125ec21f
@@ -70,7 +70,7 @@ class NasBench201Cell(nn.Module):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features))
                                          for key, cls in op_candidates.items()])
-                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))  # put __ here to be compatible with base engine
+                node_ops.append(LayerChoice(op_choices, label=f'{self._label}/{j}_{tid}'))
            self.layers.append(node_ops)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
...
@@ -179,7 +179,7 @@ class NasBench201(nn.Module):
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                ops: Dict[str, Callable[[int, int], nn.Module]] = {
-                    prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
+                    prim: self._make_op_factory(prim) for prim in PRIMITIVES
                }
                cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
            self.cells.append(cell)
@@ -192,6 +192,9 @@ class NasBench201(nn.Module):
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

+    def _make_op_factory(self, prim):
+        return lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
+
    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
...
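The `_make_op_factory` helper introduced above works around Python's late-binding closure pitfall: a lambda written directly inside the dict comprehension captures the loop variable `prim` by reference, so every factory would end up building the last primitive in `PRIMITIVES`. A minimal standalone illustration of the pitfall and the fix (plain Python, hypothetical names, not NNI code):

# Late binding: every lambda sees the final value of `prim`.
factories_buggy = {prim: (lambda: prim) for prim in ['conv', 'pool', 'skip']}
print(factories_buggy['conv']())  # prints 'skip'

# Fix: bind `prim` through an extra function scope, which is what _make_op_factory does.
def make_factory(prim):
    return lambda: prim

factories_fixed = {prim: make_factory(prim) for prim in ['conv', 'pool', 'skip']}
print(factories_fixed['conv']())  # prints 'conv'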
@@ -9,7 +9,7 @@ import pytorch_lightning as pl
import torch
import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import BaseOneShotLightningModule, MANUAL_OPTIMIZATION_NOTE, MutationHook, no_default_hook
from .supermodule.differentiable import (
    DifferentiableMixedLayer, DifferentiableMixedInput,
    MixedOpDifferentiablePolicy, GumbelSoftmax,
@@ -28,6 +28,7 @@ class DartsLightningModule(BaseOneShotLightningModule):
    DARTS repeats iterations, where each iteration consists of 2 training phases.
    The phase 1 is architecture step, in which model parameters are frozen and the architecture parameters are trained.
    The phase 2 is model step, in which architecture parameters are frozen and model parameters are trained.
+    In both phases, ``training_step`` of the Lightning evaluator will be used.

    The current implementation corresponds to DARTS (1st order) in paper.
    Second order (unrolled 2nd-order derivatives) is not supported yet.
@@ -49,15 +50,20 @@ class DartsLightningModule(BaseOneShotLightningModule):
    {{module_notes}}

+    {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _darts_note.format(
@@ -85,8 +91,10 @@ class DartsLightningModule(BaseOneShotLightningModule):
    def __init__(self, inner_module: pl.LightningModule,
                 mutation_hooks: list[MutationHook] | None = None,
-                 arc_learning_rate: float = 3.0E-4):
+                 arc_learning_rate: float = 3.0E-4,
+                 gradient_clip_val: float | None = None):
        self.arc_learning_rate = arc_learning_rate
+        self.gradient_clip_val = gradient_clip_val
        super().__init__(inner_module, mutation_hooks=mutation_hooks)

    def training_step(self, batch, batch_idx):
@@ -108,33 +116,32 @@ class DartsLightningModule(BaseOneShotLightningModule):
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
-        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self.resample()
-        self.call_weight_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
-        w_step_loss = loss_and_metrics['loss'] \
-            if isinstance(loss_and_metrics, dict) else loss_and_metrics
-        self.manual_backward(w_step_loss)
-        self.call_weight_optimizers('step')
-
-        self.call_lr_schedulers(batch_idx)
-
-        return loss_and_metrics
-
-    def finalize_grad(self):
-        # Note: This hook is currently kept for Proxyless NAS.
-        pass
+        w_step_loss = loss_and_metrics['loss'] if isinstance(loss_and_metrics, dict) else loss_and_metrics
+        self.advance_optimization(w_step_loss, batch_idx, self.gradient_clip_val)
+
+        # Update learning rates
+        self.advance_lr_schedulers(batch_idx)
+
+        self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
+
+        return loss_and_metrics

    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices are the architecture parameters of DARTS. They share one optimizer.
        ctrl_params = []
        for m in self.nas_modules:
            ctrl_params += list(m.parameters(arch=True))  # type: ignore
-        ctrl_optim = torch.optim.Adam(list(set(ctrl_params)), 3.e-4, betas=(0.5, 0.999),
-                                      weight_decay=1.0E-3)
+        # Follow the hyper-parameters used in
+        # https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/architect.py#L17
+        params = list(set(ctrl_params))
+        if not params:
+            raise ValueError('No architecture parameters found. Nothing to search.')
+        ctrl_optim = torch.optim.Adam(params, 3.e-4, betas=(0.5, 0.999), weight_decay=1.0E-3)
        return ctrl_optim
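For orientation, the control flow that the rewritten `training_step` drives is the usual first-order DARTS alternation: an architecture step on the validation half of the batch, then a weight step on the training half, both under manual optimization. A schematic sketch of that loop, assuming a placeholder `model.loss(...)` and pre-built optimizers (not the NNI implementation):

import torch

def darts_first_order_step(model, arch_optim, weight_optim, val_batch, trn_batch):
    # Phase 1: architecture step, update alpha on the validation half of the batch.
    arch_optim.zero_grad()
    arch_loss = model.loss(val_batch)   # placeholder for the evaluator's training_step
    arch_loss.backward()
    arch_optim.step()

    # Phase 2: model step, update weights on the training half of the batch.
    weight_optim.zero_grad()
    weight_loss = model.loss(trn_batch)
    weight_loss.backward()
    # Optional clipping, mirroring gradient_clip_val in the module above.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
    weight_optim.step()
    return weight_loss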
@@ -153,13 +160,20 @@ class ProxylessLightningModule(DartsLightningModule):
    {{module_notes}}

+    {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
    {base_params}
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
-    """.format(base_params=BaseOneShotLightningModule._mutation_hooks_note)
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
+    """.format(
+        base_params=BaseOneShotLightningModule._mutation_hooks_note,
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
+    )

    __doc__ = _proxyless_note.format(
        module_notes='This module should be trained with :class:`pytorch_lightning.trainer.supporters.CombinedLoader`.',
@@ -176,10 +190,6 @@ class ProxylessLightningModule(DartsLightningModule):
        # FIXME: no support for mixed operation currently
        return hooks

-    def finalize_grad(self):
-        for m in self.nas_modules:
-            m.finalize_grad()  # type: ignore
-

class GumbelDartsLightningModule(DartsLightningModule):
    _gumbel_darts_note = """
@@ -207,6 +217,8 @@ class GumbelDartsLightningModule(DartsLightningModule):
    {{module_notes}}

+    {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
@@ -216,13 +228,17 @@ class GumbelDartsLightningModule(DartsLightningModule):
    use_temp_anneal : bool
        If true, a linear annealing will be applied to ``gumbel_temperature``.
        Otherwise, run at a fixed temperature. See `SNAS <https://arxiv.org/abs/1812.09926>`__ for details.
+        Default is false.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    arc_learning_rate : float
        Learning rate for architecture optimizer. Default: 3.0e-4
+    gradient_clip_val : float
+        Clip gradients before optimizing models at each step. Default: None
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
    )
    def mutate_kwargs(self):
@@ -235,22 +251,25 @@ class GumbelDartsLightningModule(DartsLightningModule):
    def __init__(self, inner_module,
                 mutation_hooks: list[MutationHook] | None = None,
                 arc_learning_rate: float = 3.0e-4,
+                 gradient_clip_val: float | None = None,
                 gumbel_temperature: float = 1.,
                 use_temp_anneal: bool = False,
                 min_temp: float = .33):
-        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate)
+        super().__init__(inner_module, mutation_hooks, arc_learning_rate=arc_learning_rate, gradient_clip_val=gradient_clip_val)
        self.temp = gumbel_temperature
        self.init_temp = gumbel_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

-    def on_train_epoch_end(self):
+    def on_train_epoch_start(self):
        if self.use_temp_anneal:
            self.temp = (1 - self.trainer.current_epoch / self.trainer.max_epochs) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)

+        self.log('gumbel_temperature', self.temp)
+
        for module in self.nas_modules:
-            if hasattr(module, '_softmax'):
-                module._softmax.temp = self.temp  # type: ignore
-        return self.model.on_train_epoch_end()
+            if hasattr(module, '_softmax') and isinstance(module, GumbelSoftmax):
+                module._softmax.tau = self.temp  # type: ignore
+        return self.model.on_train_epoch_start()
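The schedule applied in `on_train_epoch_start` is a linear interpolation from the initial Gumbel temperature down to `min_temp` over the training run, clamped from below. A small self-contained sketch of the same formula with illustrative values:

def annealed_temperature(epoch: int, max_epochs: int, init_temp: float, min_temp: float) -> float:
    """Linearly decay the Gumbel-Softmax temperature, clamped at min_temp."""
    temp = (1 - epoch / max_epochs) * (init_temp - min_temp) + min_temp
    return max(temp, min_temp)

# Example: init_temp=1.0, min_temp=0.33, 10 epochs (values approximate).
print([round(annealed_temperature(e, 10, 1.0, 0.33), 3) for e in range(10)])
# [1.0, 0.933, 0.866, 0.799, 0.732, 0.665, 0.598, 0.531, 0.464, 0.397]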
@@ -94,11 +94,11 @@ class ReinforceController(nn.Module):
            field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
        })

-    def resample(self):
+    def resample(self, return_prob=False):
        self._initialize()
        result = dict()
        for field in self.fields:
-            result[field.name] = self._sample_single(field)
+            result[field.name] = self._sample_single(field, return_prob=return_prob)
        return result

    def _initialize(self):
@@ -116,7 +116,7 @@ class ReinforceController(nn.Module):
    def _lstm_next_step(self):
        self._h, self._c = self.lstm(self._inputs, (self._h, self._c))

-    def _sample_single(self, field):
+    def _sample_single(self, field, return_prob):
        self._lstm_next_step()
        logit = self.soft[field.name](self._h[-1])
        if self.temperature is not None:
@@ -124,10 +124,12 @@ class ReinforceController(nn.Module):
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if field.choose_one:
+            sampled_dist = F.softmax(logit, dim=-1)
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            log_prob = self.cross_entropy_loss(logit, sampled)
            self._inputs = self.embedding[field.name](sampled)
        else:
+            sampled_dist = torch.sigmoid(logit)
            logit = logit.view(-1, 1)
            logit = torch.cat([-logit, logit], 1)  # pylint: disable=invalid-unary-operand-type
            sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
@@ -147,4 +149,7 @@ class ReinforceController(nn.Module):
        self.sample_entropy += self.entropy_reduction(entropy)
        if len(sampled) == 1:
            sampled = sampled[0]

+        if return_prob:
+            return sampled_dist.flatten().detach().cpu().numpy().tolist()
        return sampled
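The `return_prob` branch added above hands back the post-softmax (or sigmoid) distribution instead of a sampled index. The relation between the two modes, in a tiny standalone snippet with made-up logits:

import torch
import torch.nn.functional as F

logit = torch.tensor([2.0, 0.5, -1.0])

# Sampling mode: draw one index from the categorical distribution.
probs = F.softmax(logit, dim=-1)
sampled = torch.multinomial(probs, 1).view(-1)

# Probability mode: export the distribution itself as a plain Python list.
prob_list = probs.flatten().detach().cpu().numpy().tolist()
print(sampled.item(), [round(p, 3) for p in prob_list])  # e.g. 0 [0.786, 0.175, 0.039]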
@@ -5,14 +5,14 @@
from __future__ import annotations

import warnings
-from typing import Any
+from typing import Any, cast

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.optim as optim

-from .base_lightning import BaseOneShotLightningModule, MutationHook, no_default_hook
+from .base_lightning import MANUAL_OPTIMIZATION_NOTE, BaseOneShotLightningModule, MutationHook, no_default_hook
from .supermodule.operation import NATIVE_MIXED_OPERATIONS, NATIVE_SUPPORTED_OP_NAMES
from .supermodule.sampling import (
    PathSamplingInput, PathSamplingLayer, MixedOpPathSamplingPolicy,
@@ -37,6 +37,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
    * :class:`nni.retiarii.nn.pytorch.Cell`.
    * :class:`nni.retiarii.nn.pytorch.NasBench201Cell`.

+    This strategy assumes inner evaluator has set
+    `automatic optimization <https://pytorch-lightning.readthedocs.io/en/stable/common/optimization.html>`__ to true.
+
    Parameters
    ----------
    {{module_params}}
@@ -73,9 +76,9 @@ class RandomSamplingLightningModule(BaseOneShotLightningModule):
        'mixed_op_sampling': MixedOpPathSamplingPolicy
    }

-    def training_step(self, batch, batch_idx):
+    def training_step(self, *args, **kwargs):
        self.resample()
-        return self.model.training_step(batch, batch_idx)
+        return self.model.training_step(*args, **kwargs)
    def export(self) -> dict[str, Any]:
        """
@@ -115,6 +118,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
    {{module_notes}}

+    {optimization_note}
+
    Parameters
    ----------
    {{module_params}}
@@ -133,6 +138,8 @@ class EnasLightningModule(RandomSamplingLightningModule):
        before updating the weights of RL controller.
    ctrl_grad_clip : float
        Gradient clipping value of controller.
+    log_prob_every_n_step : int
+        Log the probability of choices every N steps. Useful for visualization and debugging.
    reward_metric_name : str or None
        The name of the metric which is treated as reward.
        This will be not effective when there's only one metric returned from evaluator.
@@ -141,11 +148,12 @@ class EnasLightningModule(RandomSamplingLightningModule):
        Otherwise it raises an exception indicating multiple metrics are found.
    """.format(
        base_params=BaseOneShotLightningModule._mutation_hooks_note,
-        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES)
+        supported_ops=', '.join(NATIVE_SUPPORTED_OP_NAMES),
+        optimization_note=MANUAL_OPTIMIZATION_NOTE
    )

    __doc__ = _enas_note.format(
-        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.',
+        module_notes='``ENASModule`` should be trained with :class:`nni.retiarii.oneshot.pytorch.dataloader.ConcatLoader`.',
        module_params=BaseOneShotLightningModule._inner_module_note,
    )
@@ -162,6 +170,7 @@ class EnasLightningModule(RandomSamplingLightningModule):
                 baseline_decay: float = .999,
                 ctrl_steps_aggregate: float = 20,
                 ctrl_grad_clip: float = 0,
+                 log_prob_every_n_step: int = 10,
                 reward_metric_name: str | None = None,
                 mutation_hooks: list[MutationHook] | None = None):
        super().__init__(inner_module, mutation_hooks)
@@ -181,33 +190,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.ctrl_grad_clip = ctrl_grad_clip
+        self.log_prob_every_n_step = log_prob_every_n_step
        self.reward_metric_name = reward_metric_name
    def configure_architecture_optimizers(self):
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    def training_step(self, batch_packed, batch_idx):
+        # The received batch is a tuple of (data, "train" | "val")
        batch, mode = batch_packed

        if mode == 'train':
            # train model params
            with torch.no_grad():
                self.resample()
-            self.call_weight_optimizers('zero_grad')
            step_output = self.model.training_step(batch, batch_idx)
-            w_step_loss = step_output['loss'] \
-                if isinstance(step_output, dict) else step_output
-            self.manual_backward(w_step_loss)
-            self.call_weight_optimizers('step')
+            w_step_loss = step_output['loss'] if isinstance(step_output, dict) else step_output
+            self.advance_optimization(w_step_loss, batch_idx)

        else:
            # train ENAS agent
-            arc_opt = self.architecture_optimizers()
-            if not isinstance(arc_opt, optim.Optimizer):
-                raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
-            arc_opt.zero_grad()
-            self.resample()
+            # Run a sample to retrieve the reward
+            self.resample()
            step_output = self.model.validation_step(batch, batch_idx)

            # use the default metric of self.model as reward function
@@ -218,11 +223,13 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if metric_name not in self.trainer.callback_metrics:
                raise KeyError(f'Model reported metrics should contain a ``{metric_name}`` key but '
                               f'found multiple (or zero) metrics without default: {list(self.trainer.callback_metrics.keys())}. '
-                               f'Try to use self.log to report metrics with the specified key ``{metric_name}`` in validation_step, '
-                               'and remember to set on_step=True.')
+                               f'Please try to set ``reward_metric_name`` to be one of the keys listed above. '
+                               f'If it is not working use self.log to report metrics with the specified key ``{metric_name}`` '
+                               'in validation_step, and remember to set on_step=True.')
            metric = self.trainer.callback_metrics[metric_name]
            reward: float = metric.item()

+            # Compute the loss and run back propagation
            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()  # type: ignore
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
@@ -236,11 +243,29 @@ class EnasLightningModule(RandomSamplingLightningModule):
            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.ctrl_grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.ctrl_grad_clip)
+                # Update the controller and zero out its gradients
+                arc_opt = cast(optim.Optimizer, self.architecture_optimizers())
                arc_opt.step()
                arc_opt.zero_grad()

+        self.advance_lr_schedulers(batch_idx)
+
+        if (batch_idx + 1) % self.log_prob_every_n_step == 0:
+            with torch.no_grad():
+                self.log_dict({'prob/' + k: v for k, v in self.export_probs().items()})
+
        return step_output

+    def on_train_epoch_start(self):
+        # Always zero out the gradients of ENAS controller at the beginning of epochs.
+        arc_opt = self.architecture_optimizers()
+        if not isinstance(arc_opt, optim.Optimizer):
+            raise TypeError(f'Expect arc_opt to be a single Optimizer, but found: {arc_opt}')
+        arc_opt.zero_grad()
+
+        return self.model.on_train_epoch_start()
+
    def resample(self):
        """Resample the architecture with ENAS controller."""
        sample = self.controller.resample()
@@ -249,6 +274,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
            module.resample(memo=result)
        return result

+    def export_probs(self):
+        """Export the probability from ENAS controller directly."""
+        sample = self.controller.resample(return_prob=True)
+        result = self._interpret_controller_probability_result(sample)
+        for module in self.nas_modules:
+            module.resample(memo=result)
+        return result
+
    def export(self):
        """Run one more inference of ENAS controller."""
        self.controller.eval()
@@ -261,3 +294,14 @@ class EnasLightningModule(RandomSamplingLightningModule):
        for key in list(sample.keys()):
            sample[key] = space_spec[key].values[sample[key]]
        return sample
+
+    def _interpret_controller_probability_result(self, sample: dict[str, list[float]]) -> dict[str, Any]:
+        """Convert ``{label: [prob1, prob2, prob3]} to ``{label/choice: prob}``"""
+        space_spec = self.search_space_spec()
+        result = {}
+        for key in list(sample.keys()):
+            if len(space_spec[key].values) != len(sample[key]):
+                raise ValueError(f'Expect {space_spec[key].values} to be of the same length as {sample[key]}')
+            for value, weight in zip(space_spec[key].values, sample[key]):
+                result[f'{key}/{value}'] = weight
+        return result
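`_interpret_controller_probability_result` flattens each per-label probability list into one scalar per `label/choice` key, which is the shape `self.log_dict` expects. Roughly, with made-up labels and choice values:

# Controller output in probability mode: one distribution per label.
sample = {'cell/op_2_0': [0.6, 0.3, 0.1]}
choices = {'cell/op_2_0': ['conv3x3', 'maxpool', 'skip']}  # hypothetical search-space values

flat = {f'{label}/{value}': weight
        for label, probs in sample.items()
        for value, weight in zip(choices[label], probs)}
print(flat)
# {'cell/op_2_0/conv3x3': 0.6, 'cell/op_2_0/maxpool': 0.3, 'cell/op_2_0/skip': 0.1}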
@@ -168,11 +168,11 @@ def weighted_sum(items: list[T], weights: Sequence[float | None] = cast(Sequence
    assert len(items) == len(weights) > 0
    elem = items[0]
-    unsupported_msg = f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}'
+    unsupported_msg = 'Unsupported element type in weighted sum: {}. Value is: {}'

    if isinstance(elem, str):
        # Need to check this first. Otherwise it goes into sequence and causes infinite recursion.
-        raise TypeError(unsupported_msg)
+        raise TypeError(unsupported_msg.format(type(elem), elem))

    try:
        if isinstance(elem, (torch.Tensor, np.ndarray, float, int, np.number)):
...
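The `weighted_sum` utility touched here combines structured items (numbers, tensors, nested lists and dicts) with scalar weights, rejecting strings early because iterating over them would recurse forever. A trimmed-down sketch of that idea (simplified signature, not the NNI function):

from typing import Any, Sequence

def weighted_sum_sketch(items: Sequence[Any], weights: Sequence[float]) -> Any:
    """Recursively compute sum_i weights[i] * items[i] over matching nested structures."""
    elem = items[0]
    if isinstance(elem, str):
        raise TypeError(f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}')
    if isinstance(elem, (int, float)):
        return sum(w * it for w, it in zip(weights, items))
    if isinstance(elem, (list, tuple)):
        return type(elem)(weighted_sum_sketch([it[i] for it in items], weights) for i in range(len(elem)))
    if isinstance(elem, dict):
        return {k: weighted_sum_sketch([it[k] for it in items], weights) for k in elem}
    raise TypeError(f'Unsupported element type in weighted sum: {type(elem)}. Value is: {elem}')

print(weighted_sum_sketch([{'a': 1.0, 'b': [2.0]}, {'a': 3.0, 'b': [4.0]}], [0.25, 0.75]))
# {'a': 2.5, 'b': [3.5]}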
@@ -56,6 +56,17 @@ class BaseSuperNetModule(nn.Module):
        """
        raise NotImplementedError()

+    def export_probs(self, memo: dict[str, Any]) -> dict[str, Any]:
+        """
+        Export the probability / logits of every choice got chosen.
+
+        Parameters
+        ----------
+        memo : dict[str, Any]
+            Use memo to avoid the same label gets exported multiple times.
+        """
+        raise NotImplementedError()
+
    def search_space_spec(self) -> dict[str, ParameterSpec]:
        """
        Space specification (sample points).
...
@@ -104,6 +104,13 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            return {}  # nothing new to export
        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}

+    def export_probs(self, memo):
+        if any(k.startswith(self.label + '/') for k in memo):
+            return {}  # nothing new
+        weights = self._softmax(self._arch_alpha).cpu().tolist()
+        ret = {f'{self.label}/{name}': value for name, value in zip(self.op_names, weights)}
+        return ret
+
    def search_space_spec(self):
        return {self.label: ParameterSpec(self.label, 'choice', self.op_names, (self.label, ),
                                          True, size=len(self.op_names))}
@@ -117,7 +124,8 @@ class DifferentiableMixedLayer(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+            memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(list(module.named_children()), alpha, softmax, module.label)
@@ -208,6 +216,13 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            chosen = chosen[0]
        return {self.label: chosen}

+    def export_probs(self, memo):
+        if any(k.startswith(self.label + '/') for k in memo):
+            return {}  # nothing new
+        weights = self._softmax(self._arch_alpha).cpu().tolist()
+        ret = {f'{self.label}/{index}': value for index, value in enumerate(weights)}
+        return ret
+
    def search_space_spec(self):
        return {
            self.label: ParameterSpec(self.label, 'choice', list(range(self.n_candidates)),
@@ -225,7 +240,8 @@ class DifferentiableMixedInput(BaseSuperNetModule):
            if len(alpha) != size:
                raise ValueError(f'Architecture parameter size of same label {module.label} conflict: {len(alpha)} vs. {size}')
        else:
-            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # this can be reinitialized later
+            alpha = nn.Parameter(torch.randn(size) * 1E-3)  # the numbers in the parameter can be reinitialized later
+            memo[module.label] = alpha

        softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
        return cls(module.n_candidates, module.n_chosen, alpha, softmax, module.label)
@@ -284,6 +300,7 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
                raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+                memo[name] = alpha
            operation._arch_alpha[name] = alpha

        operation.parameters = functools.partial(self.parameters, module=operation)  # bind self
@@ -321,6 +338,16 @@ class MixedOpDifferentiablePolicy(MixedOperationSamplingPolicy):
            result[name] = spec.values[chosen_index]
        return result

+    def export_probs(self, operation: MixedOperation, memo: dict[str, Any]):
+        """Export the weight for every leaf value choice."""
+        ret = {}
+        for name, spec in operation.search_space_spec().items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = operation._softmax(operation._arch_alpha[name]).cpu().tolist()  # type: ignore
+            ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+        return ret
+
    def forward_argument(self, operation: MixedOperation, name: str) -> dict[Any, float] | Any:
        if name in operation.mutable_arguments:
            weights: dict[str, torch.Tensor] = {
@@ -360,6 +387,7 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
                raise ValueError(f'Architecture parameter size of same label {name} conflict: {len(alpha)} vs. {spec.size}')
            else:
                alpha = nn.Parameter(torch.randn(spec.size) * 1E-3)
+                memo[name] = alpha
            self._arch_alpha[name] = alpha

    def resample(self, memo):
@@ -376,6 +404,16 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):
            result[name] = spec.values[chosen_index]
        return result

+    def export_probs(self, memo):
+        """Export the weight for every leaf value choice."""
+        ret = {}
+        for name, spec in self.search_space_spec().items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = self._softmax(self._arch_alpha[name]).cpu().tolist()
+            ret.update({f'{name}/{value}': weight for value, weight in zip(spec.values, weights)})
+        return ret
+
    def search_space_spec(self):
        return self._space_spec
@@ -427,6 +465,8 @@ class DifferentiableMixedRepeat(BaseSuperNetModule):

class DifferentiableMixedCell(PathSamplingCell):
    """Implementation of Cell under differentiable context.

+    Similar to PathSamplingCell, this cell only handles cells of specific kinds (e.g., with loose end).
+
    An architecture parameter is created on each edge of the full-connected graph.
    """
@@ -450,13 +490,21 @@ class DifferentiableMixedCell(PathSamplingCell):
                op = cast(List[Dict[str, nn.Module]], self.ops[i - self.num_predecessors])[j]
                if edge_label in memo:
                    alpha = memo[edge_label]
-                    if len(alpha) != len(op):
-                        raise ValueError(
-                            f'Architecture parameter size of same label {edge_label} conflict: '
-                            f'{len(alpha)} vs. {len(op)}'
-                        )
+                    if len(alpha) != len(op) + 1:
+                        if len(alpha) != len(op):
+                            raise ValueError(
+                                f'Architecture parameter size of same label {edge_label} conflict: '
+                                f'{len(alpha)} vs. {len(op)}'
+                            )
+                        warnings.warn(
+                            f'Architecture parameter size {len(alpha)} is not same as expected: {len(op) + 1}. '
+                            'This is likely due to the label being shared by a LayerChoice inside the cell and outside.',
+                            UserWarning
+                        )
                else:
-                    alpha = nn.Parameter(torch.randn(len(op)) * 1E-3)
+                    # +1 to emulate the input choice.
+                    alpha = nn.Parameter(torch.randn(len(op) + 1) * 1E-3)
+                    memo[edge_label] = alpha
                self._arch_alpha[edge_label] = alpha

        self._softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
@@ -465,18 +513,32 @@ class DifferentiableMixedCell(PathSamplingCell):
        """Differentiable doesn't need to resample."""
        return {}

+    def export_probs(self, memo):
+        """When export probability, we follow the structure in arch alpha."""
+        ret = {}
+        for name, parameter in self._arch_alpha.items():
+            if any(k.startswith(name + '/') for k in memo):
+                continue
+            weights = self._softmax(parameter).cpu().tolist()
+            ret.update({f'{name}/{value}': weight for value, weight in zip(self.op_names, weights)})
+        return ret
+
    def export(self, memo):
        """Tricky export.

        Reference: https://github.com/quark0/darts/blob/f276dd346a09ae3160f8e3aca5c7b193fda1da37/cnn/model_search.py#L135
+        We don't avoid selecting operations like ``none`` here, because it looks like a different search space.
        """
        exported = {}
        for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
+            # If label already exists, no need to re-export.
+            if all(f'{self.label}/op_{i}_{k}' in memo and f'{self.label}/input_{i}_{k}' in memo for k in range(self.num_ops_per_node)):
+                continue
            # Tuple of (weight, input_index, op_name)
            all_weights: list[tuple[float, int, str]] = []
            for j in range(i):
                for k, name in enumerate(self.op_names):
+                    # The last appended weight is automatically skipped in export.
                    all_weights.append((
                        float(self._arch_alpha[f'{self.label}/{i}_{j}'][k].item()),
                        j, name,
@@ -497,7 +559,7 @@ class DifferentiableMixedCell(PathSamplingCell):
            all_weights = [all_weights[k] for k in first_occurrence_index] + \
                [w for j, w in enumerate(all_weights) if j not in first_occurrence_index]

-            _logger.info('Sorted weights in differentiable cell export (node %d): %s', i, all_weights)
+            _logger.info('Sorted weights in differentiable cell export (%s cell, node %d): %s', self.label, i, all_weights)

            for k in range(self.num_ops_per_node):
                # all_weights could be too short in case ``num_ops_per_node`` is too large.
@@ -515,7 +577,11 @@ class DifferentiableMixedCell(PathSamplingCell):
            for j in range(i):  # for every previous tensors
                op_results = torch.stack([op(states[j]) for op in ops[j].values()])
-                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
+                alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)  # (-1, 1, 1, 1, 1, ...)
+                op_weights = self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}'])
+                if len(op_weights) == len(op_results) + 1:
+                    # concatenate with a zero operation, indicating this path is not chosen at all.
+                    op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)
                edge_sum = torch.sum(op_results * self._softmax(self._arch_alpha[f'{self.label}/{i}_{j}']).view(*alpha_shape), 0)
                current_state.append(edge_sum)
...
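The `+1` slot in each edge's architecture parameter acts as a "path not chosen" option: when the softmax has one more entry than there are operators, an all-zero tensor is appended so the extra weight simply down-weights the whole edge. A standalone sketch of that mixture with illustrative shapes:

import torch

op_results = torch.randn(3, 2, 4)   # 3 candidate ops, batch 2, feature 4
alpha = torch.randn(4)              # 3 ops + 1 "zero op" slot
weights = torch.softmax(alpha, dim=-1)

if len(weights) == len(op_results) + 1:
    # Append an all-zero "operation"; its weight contributes nothing to the sum.
    op_results = torch.cat((op_results, torch.zeros_like(op_results[:1])), 0)

alpha_shape = [-1] + [1] * (op_results.dim() - 1)   # (-1, 1, 1)
edge_sum = torch.sum(op_results * weights.view(*alpha_shape), 0)
print(edge_sum.shape)  # torch.Size([2, 4])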
@@ -71,6 +71,10 @@ class MixedOperationSamplingPolicy:
        """The handler of :meth:`MixedOperation.export`."""
        raise NotImplementedError()

+    def export_probs(self, operation: 'MixedOperation', memo: dict[str, Any]) -> dict[str, Any]:
+        """The handler of :meth:`MixedOperation.export_probs`."""
+        raise NotImplementedError()
+
    def forward_argument(self, operation: 'MixedOperation', name: str) -> Any:
        """Computing the argument with ``name`` used in operation's forward.
        Usually a value, or a distribution of value.
@@ -162,6 +166,10 @@ class MixedOperation(BaseSuperNetModule):
        """Delegates to :meth:`MixedOperationSamplingPolicy.resample`."""
        return self.sampling_policy.resample(self, memo)

+    def export_probs(self, memo):
+        """Delegates to :meth:`MixedOperationSamplingPolicy.export_probs`."""
+        return self.sampling_policy.export_probs(self, memo)
+
    def export(self, memo):
        """Delegates to :meth:`MixedOperationSamplingPolicy.export`."""
        return self.sampling_policy.export(self, memo)
...
@@ -11,7 +11,7 @@ The support remains limited. Known limitations include:
from __future__ import annotations

-from typing import cast
+from typing import Any, Tuple, Union, cast

import torch
import torch.nn as nn
@@ -21,28 +21,115 @@ from .differentiable import DifferentiableMixedLayer, DifferentiableMixedInput
__all__ = ['ProxylessMixedLayer', 'ProxylessMixedInput']
-class _ArchGradientFunction(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, x, binary_gates, run_func, backward_func):
-        ctx.run_func = run_func
-        ctx.backward_func = backward_func
-
-        detached_x = x.detach()
-        detached_x.requires_grad = x.requires_grad
-        with torch.enable_grad():
-            output = run_func(detached_x)
-        ctx.save_for_backward(detached_x, output)
-        return output.data
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        detached_x, output = ctx.saved_tensors
-
-        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
-        # compute gradients w.r.t. binary_gates
-        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
-
-        return grad_x[0], binary_grads, None, None
+def _detach_tensor(tensor: Any) -> Any:
+    """Recursively detach all the tensors."""
+    if isinstance(tensor, (list, tuple)):
+        return tuple(_detach_tensor(t) for t in tensor)
+    elif isinstance(tensor, dict):
+        return {k: _detach_tensor(v) for k, v in tensor.items()}
+    elif isinstance(tensor, torch.Tensor):
+        return tensor.detach()
+    else:
+        return tensor
+
+
+def _iter_tensors(tensor: Any) -> Any:
+    """Recursively iterate over all the tensors.
+
+    This is kept for complex outputs (like dicts / lists).
+    However, complex outputs are not supported by PyTorch backward hooks yet.
+    """
+    if isinstance(tensor, torch.Tensor):
+        yield tensor
+    elif isinstance(tensor, (list, tuple)):
+        for t in tensor:
+            yield from _iter_tensors(t)
+    elif isinstance(tensor, dict):
+        for t in tensor.values():
+            yield from _iter_tensors(t)
+
+
+def _pack_as_tuple(tensor: Any) -> tuple:
+    """Return a tuple of tensor with only one element if tensor it's not a tuple."""
+    if isinstance(tensor, (tuple, list)):
+        return tuple(tensor)
+    return (tensor,)
+
+
+def element_product_sum(tensor1: tuple[torch.Tensor, ...], tensor2: tuple[torch.Tensor, ...]) -> torch.Tensor:
+    """Compute the sum of all the element-wise product."""
+    assert len(tensor1) == len(tensor2), 'The number of tensors must be the same.'
+    # Skip zero gradients
+    ret = [torch.sum(t1 * t2) for t1, t2 in zip(tensor1, tensor2) if t1 is not None and t2 is not None]
+    if not ret:
+        return torch.tensor(0)
+    if len(ret) == 1:
+        return ret[0]
+    return cast(torch.Tensor, sum(ret))
+
+
+class ProxylessContext:
+
+    def __init__(self, arch_alpha: torch.Tensor, softmax: nn.Module) -> None:
+        self.arch_alpha = arch_alpha
+        self.softmax = softmax
+
+        # When a layer is called multiple times, the inputs and outputs are saved in order.
+        # In backward propagation, we assume that they are used in the reversed order.
+        self.layer_input: list[Any] = []
+        self.layer_output: list[Any] = []
+        self.layer_sample_idx: list[int] = []
+
+    def clear_context(self) -> None:
+        self.layer_input = []
+        self.layer_output = []
+        self.layer_sample_idx = []
+
+    def save_forward_context(self, layer_input: Any, layer_output: Any, layer_sample_idx: int):
+        self.layer_input.append(_detach_tensor(layer_input))
+        self.layer_output.append(_detach_tensor(layer_output))
+        self.layer_sample_idx.append(layer_sample_idx)
+
+    def backward_hook(self, module: nn.Module,
+                      grad_input: Union[Tuple[torch.Tensor, ...], torch.Tensor],
+                      grad_output: Union[Tuple[torch.Tensor, ...], torch.Tensor]) -> None:
+        # binary_grads is the gradient of binary gates.
+        # Binary gates is a one-hot tensor where 1 is on the sampled index, and others are 0.
+        # By chain rule, it's gradient is grad_output times the layer_output (of the corresponding path).
+        binary_grads = torch.zeros_like(self.arch_alpha)
+
+        # Retrieve the layer input/output in reverse order.
+        if not self.layer_input:
+            raise ValueError('Unexpected backward call. The saved context is empty.')
+        layer_input = self.layer_input.pop()
+        layer_output = self.layer_output.pop()
+        layer_sample_idx = self.layer_sample_idx.pop()
+
+        with torch.no_grad():
+            # Compute binary grads.
+            for k in range(len(binary_grads)):
+                if k != layer_sample_idx:
+                    args, kwargs = layer_input
+                    out_k = module.forward_path(k, *args, **kwargs)  # type: ignore
+                else:
+                    out_k = layer_output
+                # FIXME: One limitation here is that out_k can't be complex objects like dict.
+                # I think it's also a limitation of backward hook.
+                binary_grads[k] = element_product_sum(
+                    _pack_as_tuple(out_k),  # In case out_k is a single tensor
+                    _pack_as_tuple(grad_output)
+                )
+
+            # Compute the gradient of the arch_alpha, based on binary_grads.
+            if self.arch_alpha.grad is None:
+                self.arch_alpha.grad = torch.zeros_like(self.arch_alpha)
+            probs = self.softmax(self.arch_alpha)
+            for i in range(len(self.arch_alpha)):
+                for j in range(len(self.arch_alpha)):
+                    # Arch alpha's gradients are accumulated for all backwards through this layer.
+                    self.arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
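The nested loop at the end of `backward_hook` is the ProxylessNAS estimator for the architecture gradient: given binary-gate gradients g_j and softmax probabilities p_j, it accumulates dL/dalpha_i = sum_j g_j * p_j * ((i == j) - p_i), which is the chain rule through the softmax. A tiny numeric check of that identity against autograd (standalone, illustrative values):

import torch

alpha = torch.tensor([0.2, -0.1, 0.5], requires_grad=True)
g = torch.tensor([1.0, -2.0, 0.5])          # stands in for the binary-gate gradients

# Reference: differentiate sum_j g_j * softmax(alpha)_j with autograd.
loss = torch.sum(g * torch.softmax(alpha, -1))
loss.backward()

# Manual accumulation, as in backward_hook above.
probs = torch.softmax(alpha.detach(), -1)
manual = torch.zeros_like(probs)
for i in range(len(alpha)):
    for j in range(len(alpha)):
        manual[i] += g[j] * probs[j] * (int(i == j) - probs[i])

print(torch.allclose(alpha.grad, manual))  # True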
class ProxylessMixedLayer(DifferentiableMixedLayer):
@@ -50,46 +137,32 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
    It resamples a single-path every time, rather than go through the softmax.
    """

-    _arch_parameter_names = ['_arch_alpha', '_binary_gates']
+    _arch_parameter_names = ['_arch_alpha']

    def __init__(self, paths: list[tuple[str, nn.Module]], alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(paths, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(len(paths)) * 1E-3)
+        # Binary gates should be created here, but it's not because it's never used in the forward pass.
+        # self._binary_gates = nn.Parameter(torch.zeros(len(paths)))

        # like sampling-based methods, it has a ``_sampled``.
        self._sampled: str | None = None
        self._sample_idx: int | None = None

+        # arch_alpha could be shared by multiple layers,
+        # but binary_gates is owned by the current layer.
+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)
+
    def forward(self, *args, **kwargs):
-        def run_function(ops, active_id, **kwargs):
-            def forward(_x):
-                return ops[active_id](_x, **kwargs)
-            return forward
-
-        def backward_function(ops, active_id, binary_gates, **kwargs):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(len(ops)):
-                        if k != active_id:
-                            out_k = ops[k](_x.data, **kwargs)
-                        else:
-                            out_k = _output.data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        assert len(args) == 1, 'ProxylessMixedLayer only supports exactly one input argument.'
-        x = args[0]
-
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-        list_ops = [getattr(self, op) for op in self.op_names]
-
-        return _ArchGradientFunction.apply(
-            x, self._binary_gates, run_function(list_ops, self._sample_idx, **kwargs),
-            backward_function(list_ops, self._sample_idx, self._binary_gates, **kwargs)
-        )
+        """Forward pass of one single path."""
+        if self._sample_idx is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sample_idx, *args, **kwargs)
+        self.ctx.save_forward_context((args, kwargs), output, self._sample_idx)
+        return output
+
+    def forward_path(self, index, *args, **kwargs):
+        return getattr(self, self.op_names[index])(*args, **kwargs)

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
@@ -101,66 +174,37 @@ class ProxylessMixedLayer(DifferentiableMixedLayer):
        self._sample_idx = int(torch.multinomial(probs, 1)[0].item())
        self._sampled = self.op_names[self._sample_idx]

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[self._sample_idx] = 1.0
+        self.ctx.clear_context()

        return {self.label: self._sampled}

-    def export(self, memo):
-        """Chose the argmax if label isn't found in memo."""
-        if self.label in memo:
-            return {}  # nothing new to export
-        return {self.label: self.op_names[int(torch.argmax(self._arch_alpha).item())]}
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(len(self._arch_alpha)):
-                for j in range(len(self._arch_alpha)):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
-

class ProxylessMixedInput(DifferentiableMixedInput):
    """Proxyless version of differentiable input choice.

-    See :class:`ProxylessLayerChoice` for implementation details.
+    See :class:`ProxylessMixedLayer` for implementation details.
    """

    _arch_parameter_names = ['_arch_alpha', '_binary_gates']

    def __init__(self, n_candidates: int, n_chosen: int | None, alpha: torch.Tensor, softmax: nn.Module, label: str):
        super().__init__(n_candidates, n_chosen, alpha, softmax, label)
-        self._binary_gates = nn.Parameter(torch.randn(n_candidates) * 1E-3)
+        # We only support choosing a particular one here.
+        # Nevertheless, we rank the score and export the tops in export.
        self._sampled: int | None = None
+        self.ctx = ProxylessContext(alpha, softmax)
+        self.register_full_backward_hook(self.ctx.backward_hook)

    def forward(self, inputs):
-        def run_function(active_sample):
-            return lambda x: x[active_sample]
-
-        def backward_function(binary_gates):
-            def backward(_x, _output, grad_output):
-                binary_grads = torch.zeros_like(binary_gates.data)
-                with torch.no_grad():
-                    for k in range(self.n_candidates):
-                        out_k = _x[k].data
-                        grad_k = torch.sum(out_k * grad_output)
-                        binary_grads[k] = grad_k
-                return binary_grads
-            return backward
-
-        inputs = torch.stack(inputs, 0)
-        assert self._sampled is not None, 'Need to call resample() before running fprop.'
-
-        return _ArchGradientFunction.apply(
-            inputs, self._binary_gates, run_function(self._sampled),
-            backward_function(self._binary_gates)
-        )
+        """Choose one single input."""
+        if self._sampled is None:
+            raise RuntimeError('resample() needs to be called before fprop.')
+        output = self.forward_path(self._sampled, inputs)
+        self.ctx.save_forward_context(((inputs,), {}), output, self._sampled)
+        return output
+
+    def forward_path(self, index, inputs):
+        return inputs[index]

    def resample(self, memo):
        """Sample one path based on alpha if label is not found in memo."""
@@ -171,27 +215,6 @@ class ProxylessMixedInput(DifferentiableMixedInput):
        sample = torch.multinomial(probs, 1)[0].item()
        self._sampled = int(sample)

-        # set binary gates
-        with torch.no_grad():
-            self._binary_gates.zero_()
-            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
-            self._binary_gates.data[cast(int, self._sampled)] = 1.0
+        self.ctx.clear_context()

        return {self.label: self._sampled}
-
-    def export(self, memo):
-        """Chose the argmax if label isn't found in memo."""
-        if self.label in memo:
-            return {}  # nothing new to export
-        return {self.label: torch.argmax(self._arch_alpha).item()}
-
-    def finalize_grad(self):
-        binary_grads = self._binary_gates.grad
-        assert binary_grads is not None
-        with torch.no_grad():
-            if self._arch_alpha.grad is None:
-                self._arch_alpha.grad = torch.zeros_like(self._arch_alpha.data)
-            probs = self._softmax(self._arch_alpha)
-            for i in range(self.n_candidates):
-                for j in range(self.n_candidates):
-                    self._arch_alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
...@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule): ...@@ -169,7 +169,7 @@ class PathSamplingInput(BaseSuperNetModule):
class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy): class MixedOpPathSamplingPolicy(MixedOperationSamplingPolicy):
"""Implementes the path sampling in mixed operation. """Implements the path sampling in mixed operation.
One mixed operation can have multiple value choices in its arguments. One mixed operation can have multiple value choices in its arguments.
Each value choice can be further decomposed into "leaf value choices". Each value choice can be further decomposed into "leaf value choices".
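To illustrate the docstring above: a mixed operation is a single module whose constructor arguments are value-choice expressions, and each expression bottoms out in one or more leaf value choices that the sampling policy actually samples. A sketch using the ValueChoice/Conv2d wrappers that appear later in this diff (labels and sizes are made up):

from nni.retiarii.nn.pytorch import ValueChoice, Conv2d

# Two leaf value choices, each identified by its own label.
width = ValueChoice([8, 16], label='width')
kernel = ValueChoice([3, 5], label='kernel')

# One mixed operation (Conv2d) with value-choice arguments. The expression
# ``width * 2`` decomposes into the single leaf 'width'; a policy such as
# MixedOpPathSamplingPolicy samples the leaves ('width', 'kernel') and then
# resolves every argument from the sampled leaf values.
conv = Conv2d(3, width * 2, kernel_size=kernel)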
...@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule): ...@@ -388,6 +388,10 @@ class PathSamplingCell(BaseSuperNetModule):
@classmethod @classmethod
def mutate(cls, module, name, memo, mutate_kwargs): def mutate(cls, module, name, memo, mutate_kwargs):
"""
Mutate only handles cells of specific configurations (e.g., with loose end).
Fall back to the default mutate if the cell is not handled here.
"""
if isinstance(module, Cell): if isinstance(module, Cell):
op_factory = None # not all the cells need to be replaced op_factory = None # not all the cells need to be replaced
if module.op_candidates_factory is not None: if module.op_candidates_factory is not None:
......
...@@ -5,6 +5,7 @@ import pytorch_lightning as pl ...@@ -5,6 +5,7 @@ import pytorch_lightning as pl
import pytest import pytest
from torchvision import transforms from torchvision import transforms
from torchvision.datasets import MNIST from torchvision.datasets import MNIST
from torch import nn
from torch.utils.data import Dataset, RandomSampler from torch.utils.data import Dataset, RandomSampler
import nni import nni
...@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit ...@@ -13,7 +14,11 @@ from nni.retiarii import strategy, model_wrapper, basic_unit
from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment from nni.retiarii.experiment.pytorch import RetiariiExeConfig, RetiariiExperiment
from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader from nni.retiarii.evaluator.pytorch.lightning import Classification, Regression, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice from nni.retiarii.nn.pytorch import LayerChoice, InputChoice, ValueChoice
from nni.retiarii.oneshot.pytorch import DartsLightningModule
from nni.retiarii.strategy import BaseStrategy from nni.retiarii.strategy import BaseStrategy
from pytorch_lightning import LightningModule, Trainer
from .test_oneshot_utils import RandomDataset
pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs') pytestmark = pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
...@@ -338,17 +343,49 @@ def test_gumbel_darts(): ...@@ -338,17 +343,49 @@ def test_gumbel_darts():
_test_strategy(strategy.GumbelDARTS()) _test_strategy(strategy.GumbelDARTS())
if __name__ == '__main__': def test_optimizer_lr_scheduler():
parser = argparse.ArgumentParser() learning_rates = []
parser.add_argument('--exp', type=str, default='all', metavar='E',
help='experiment to run, default = all')
args = parser.parse_args()
if args.exp == 'all': class CustomLightningModule(LightningModule):
test_darts() def __init__(self):
test_proxyless() super().__init__()
test_enas() self.layer1 = nn.Linear(32, 2)
test_random() self.layer2 = LayerChoice([nn.Linear(2, 2), nn.Linear(2, 2, bias=False)])
test_gumbel_darts()
else: def forward(self, x):
globals()[f'test_{args.exp}']() return self.layer2(self.layer1(x))
def configure_optimizers(self):
opt1 = torch.optim.SGD(self.layer1.parameters(), lr=0.1)
opt2 = torch.optim.Adam(self.layer2.parameters(), lr=0.2)
return [opt1, opt2], [torch.optim.lr_scheduler.StepLR(opt1, step_size=2, gamma=0.1)]
def training_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('train_loss', loss)
return {'loss': loss}
def on_train_epoch_start(self) -> None:
learning_rates.append(self.optimizers()[0].param_groups[0]['lr'])
def validation_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('valid_loss', loss)
def test_step(self, batch, batch_idx):
loss = self(batch).sum()
self.log('test_loss', loss)
train_data = RandomDataset(32, 32)
valid_data = RandomDataset(32, 16)
model = CustomLightningModule()
darts_module = DartsLightningModule(model, gradient_clip_val=5)
trainer = Trainer(max_epochs=10)
trainer.fit(
darts_module,
dict(train=DataLoader(train_data, batch_size=8), val=DataLoader(valid_data, batch_size=8))
)
assert len(learning_rates) == 10 and abs(learning_rates[0] - 0.1) < 1e-5 and \
abs(learning_rates[2] - 0.01) < 1e-5 and abs(learning_rates[-1] - 1e-5) < 1e-6
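For clarity on the assertion just above: StepLR(step_size=2, gamma=0.1) divides the first optimizer's learning rate by 10 every two epochs, so the values recorded at each epoch start over 10 epochs are 0.1, 0.1, 0.01, 0.01, ..., 1e-5, 1e-5. A plain-PyTorch sketch (no NNI involved) reproducing the sequence the test expects:

import torch

layer = torch.nn.Linear(2, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)
sched = torch.optim.lr_scheduler.StepLR(opt, step_size=2, gamma=0.1)

lrs = []
for _ in range(10):
    lrs.append(opt.param_groups[0]['lr'])  # value seen in on_train_epoch_start
    opt.step()    # stands in for one training epoch
    sched.step()  # the scheduler is stepped once per epoch

assert abs(lrs[0] - 0.1) < 1e-5 and abs(lrs[2] - 0.01) < 1e-5 and abs(lrs[-1] - 1e-5) < 1e-6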
import torch
import torch.nn as nn
from nni.nas.hub.pytorch.nasbench201 import OPS_WITH_STRIDE
from nni.nas.oneshot.pytorch.supermodule.proxyless import ProxylessMixedLayer, ProxylessMixedInput, _iter_tensors
def test_proxyless_bp():
op = ProxylessMixedLayer(
[(name, value(3, 3, 1)) for name, value in OPS_WITH_STRIDE.items()],
nn.Parameter(torch.randn(len(OPS_WITH_STRIDE))),
nn.Softmax(-1), 'proxyless'
)
optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
for _ in range(10):
x = torch.randn(1, 3, 9, 9).requires_grad_()
op.resample({})
y = op(x).sum()
optimizer.zero_grad()
y.backward()
assert op._arch_alpha.grad.abs().sum().item() != 0
def test_proxyless_input():
inp = ProxylessMixedInput(6, 2, nn.Parameter(torch.zeros(6)), nn.Softmax(-1), 'proxyless')
optimizer = torch.optim.SGD(inp.parameters(arch=True), 0.1)
for _ in range(10):
x = [torch.randn(1, 3, 9, 9).requires_grad_() for _ in range(6)]
inp.resample({})
y = inp(x).sum()
optimizer.zero_grad()
y.backward()
def test_iter_tensors():
a = (torch.zeros(3, 1), {'a': torch.zeros(5, 1), 'b': torch.zeros(6, 1)}, [torch.zeros(7, 1)])
ret = []
for x in _iter_tensors(a):
ret.append(x.shape[0])
assert ret == [3, 5, 6, 7]
class MultiInputLayer(nn.Module):
def __init__(self, d):
super().__init__()
self.d = d
def forward(self, q, k, v=None, mask=None):
return q + self.d, 2 * k - 2 * self.d, v, mask
def test_proxyless_multi_input():
op = ProxylessMixedLayer(
[
('a', MultiInputLayer(1)),
('b', MultiInputLayer(3))
],
nn.Parameter(torch.randn(2)),
nn.Softmax(-1), 'proxyless'
)
optimizer = torch.optim.SGD(op.parameters(arch=True), 0.1)
for retry in range(10):
q = torch.randn(1, 3, 9, 9).requires_grad_()
k = torch.randn(1, 3, 9, 8).requires_grad_()
v = None if retry < 5 else torch.randn(1, 3, 9, 7).requires_grad_()
mask = None if retry % 5 < 2 else torch.randn(1, 3, 9, 6).requires_grad_()
op.resample({})
y = op(q, k, v, mask=mask)
y = y[0].sum() + y[1].sum()
optimizer.zero_grad()
y.backward()
assert op._arch_alpha.grad.abs().sum().item() != 0, op._arch_alpha.grad
...@@ -3,7 +3,7 @@ import pytest ...@@ -3,7 +3,7 @@ import pytest
import numpy as np import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from nni.retiarii.nn.pytorch import ValueChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention from nni.retiarii.nn.pytorch import ValueChoice, LayerChoice, Conv2d, BatchNorm2d, LayerNorm, Linear, MultiheadAttention
from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules from nni.retiarii.oneshot.pytorch.base_lightning import traverse_and_mutate_submodules
from nni.retiarii.oneshot.pytorch.supermodule.differentiable import ( from nni.retiarii.oneshot.pytorch.supermodule.differentiable import (
MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax, MixedOpDifferentiablePolicy, DifferentiableMixedLayer, DifferentiableMixedInput, GumbelSoftmax,
...@@ -144,6 +144,16 @@ def test_differentiable_valuechoice(): ...@@ -144,6 +144,16 @@ def test_differentiable_valuechoice():
assert set(conv.export({}).keys()) == {'123', '456'} assert set(conv.export({}).keys()) == {'123', '456'}
def test_differentiable_layerchoice_dedup():
layerchoice1 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
layerchoice2 = LayerChoice([Conv2d(3, 3, 3), Conv2d(3, 3, 3)], label='a')
memo = {}
DifferentiableMixedLayer.mutate(layerchoice1, 'x', memo, {})
DifferentiableMixedLayer.mutate(layerchoice2, 'x', memo, {})
assert len(memo) == 1 and 'a' in memo
def _mixed_operation_sampling_sanity_check(operation, memo, *input): def _mixed_operation_sampling_sanity_check(operation, memo, *input):
for native_op in NATIVE_MIXED_OPERATIONS: for native_op in NATIVE_MIXED_OPERATIONS:
if native_op.bound_type == type(operation): if native_op.bound_type == type(operation):
...@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input): ...@@ -160,7 +170,9 @@ def _mixed_operation_differentiable_sanity_check(operation, *input):
mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy}) mutate_op = native_op.mutate(operation, 'dummy', {}, {'mixed_op_sampling': MixedOpDifferentiablePolicy})
break break
return mutate_op(*input) mutate_op(*input)
mutate_op.export({})
mutate_op.export_probs({})
def test_mixed_linear(): def test_mixed_linear():
...@@ -319,6 +331,9 @@ def test_differentiable_layer_input(): ...@@ -319,6 +331,9 @@ def test_differentiable_layer_input():
op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee') op = DifferentiableMixedLayer([('a', Linear(2, 3, bias=False)), ('b', Linear(2, 3, bias=True))], nn.Parameter(torch.randn(2)), nn.Softmax(-1), 'eee')
assert op(torch.randn(4, 2)).size(-1) == 3 assert op(torch.randn(4, 2)).size(-1) == 3
assert op.export({})['eee'] in ['a', 'b'] assert op.export({})['eee'] in ['a', 'b']
probs = op.export_probs({})
assert len(probs) == 2
assert abs(probs['eee/a'] + probs['eee/b'] - 1) < 1e-4
assert len(list(op.parameters())) == 3 assert len(list(op.parameters())) == 3
with pytest.raises(ValueError): with pytest.raises(ValueError):
...@@ -328,6 +343,8 @@ def test_differentiable_layer_input(): ...@@ -328,6 +343,8 @@ def test_differentiable_layer_input():
input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') input = DifferentiableMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2 assert input([torch.randn(4, 2) for _ in range(5)]).size(-1) == 2
assert len(input.export({})['ddd']) == 2 assert len(input.export({})['ddd']) == 2
assert len(input.export_probs({})) == 5
assert 'ddd/3' in input.export_probs({})
def test_proxyless_layer_input(): def test_proxyless_layer_input():
...@@ -341,7 +358,8 @@ def test_proxyless_layer_input(): ...@@ -341,7 +358,8 @@ def test_proxyless_layer_input():
input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd') input = ProxylessMixedInput(5, 2, nn.Parameter(torch.zeros(5)), GumbelSoftmax(-1), 'ddd')
assert input.resample({})['ddd'] in list(range(5)) assert input.resample({})['ddd'] in list(range(5))
assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2]) assert input([torch.randn(4, 2) for _ in range(5)]).size() == torch.Size([4, 2])
assert input.export({})['ddd'] in list(range(5)) exported = input.export({})['ddd']
assert len(exported) == 2 and all(e in list(range(5)) for e in exported)
def test_pathsampling_repeat(): def test_pathsampling_repeat():
...@@ -373,6 +391,7 @@ def test_differentiable_repeat(): ...@@ -373,6 +391,7 @@ def test_differentiable_repeat():
assert op(torch.randn(2, 8)).size() == torch.Size([2, 16]) assert op(torch.randn(2, 8)).size() == torch.Size([2, 16])
sample = op.export({}) sample = op.export({})
assert 'ccc' in sample and sample['ccc'] in [0, 1] assert 'ccc' in sample and sample['ccc'] in [0, 1]
assert sorted(op.export_probs({}).keys()) == ['ccc/0', 'ccc/1']
class TupleModule(nn.Module): class TupleModule(nn.Module):
def __init__(self, num): def __init__(self, num):
...@@ -452,11 +471,16 @@ def test_differentiable_cell(): ...@@ -452,11 +471,16 @@ def test_differentiable_cell():
result.update(module.export(memo=result)) result.update(module.export(memo=result))
assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2 assert len(result) == model.cell.num_nodes * model.cell.num_ops_per_node * 2
result_prob = {}
for module in nas_modules:
result_prob.update(module.export_probs(memo=result_prob))
ctrl_params = [] ctrl_params = []
for m in nas_modules: for m in nas_modules:
ctrl_params += list(m.parameters(arch=True)) ctrl_params += list(m.parameters(arch=True))
if cell_cls in [CellLooseEnd, CellOpFactory]: if cell_cls in [CellLooseEnd, CellOpFactory]:
assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2 assert len(ctrl_params) == model.cell.num_nodes * (model.cell.num_nodes + 3) // 2
assert len(result_prob) == len(ctrl_params) * 2 # len(op_names) == 2
assert isinstance(model.cell, DifferentiableMixedCell) assert isinstance(model.cell, DifferentiableMixedCell)
else: else:
assert not isinstance(model.cell, DifferentiableMixedCell) assert not isinstance(model.cell, DifferentiableMixedCell)
......