Unverified Commit 8b2eb425 authored by v-fangdong's avatar v-fangdong Committed by GitHub
Browse files

Lightning implementation for retiarii oneshot nas (#4479)

parent 99818fba
...@@ -5,4 +5,6 @@ from .darts import DartsTrainer ...@@ -5,4 +5,6 @@ from .darts import DartsTrainer
from .enas import EnasTrainer from .enas import EnasTrainer
from .proxyless import ProxylessTrainer from .proxyless import ProxylessTrainer
from .random import SinglePathTrainer, RandomTrainer from .random import SinglePathTrainer, RandomTrainer
from .utils import replace_input_choice, replace_layer_choice from .differentiable import DartsModule, ProxylessModule, SNASModule
from .sampling import EnasModule, RandomSampleModule
from .utils import InterleavedTrainValDataLoader, ConcatenateTrainValDataLoader
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import pytorch_lightning as pl
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import _LRScheduler
def _replace_module_with_type(root_module, replace_dict, modules):
"""
Replace xxxChoice in user's model with NAS modules.
Parameters
----------
root_module : nn.Module
User-defined module with xxxChoice in it. In fact, since this method is called in the ``__init__`` of
``BaseOneShotLightningModule``, this will be a pl.LightningModule.
replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]]
Functions to replace xxxChoice modules. Keys should be xxxChoice type and values should be a
function that return an nn.module.
modules : List[nn.Module]
The replace result. This is also the return value of this function.
Returns
----------
modules : List[nn.Module]
The replace result.
"""
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
child_type = type(child)
if child_type in replace_dict.keys():
setattr(m, name, replace_dict[child_type](child))
modules.append((child.key, getattr(m, name)))
else:
apply(child)
apply(root_module)
return modules
class BaseOneShotLightningModule(pl.LightningModule):
    """
    The base class for all one-shot NAS modules. Essential functions such as preprocessing user's model, redirecting
    lightning hooks to user's model, configuring optimizers and exporting NAS results are implemented in this class.

    Attributes
    ----------
    nas_modules : List[nn.Module]
        The replace result of a specific NAS method. xxxChoice will be replaced with some other modules with respect to the
        NAS method.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model
        will be wrapped by this model.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice types and values should be callables that return an
        ``nn.Module``. This custom replace dict will override the default replace dict of each NAS method.
    """
    # NAS methods drive architecture/weight optimizer steps explicitly.
    automatic_optimization = False

    def __init__(self, base_model, custom_replace_dict=None):
        super().__init__()
        assert isinstance(base_model, pl.LightningModule)
        self.model = base_model

        # Replace xxxChoice with respect to the NAS alg.
        # Replaced modules are stored in self.nas_modules.
        self.nas_modules = []
        choice_replace_dict = self.default_replace_dict
        if custom_replace_dict is not None:
            for k, v in custom_replace_dict.items():
                # Fix: values are factories (e.g. the replacement class itself) that
                # build an nn.Module from the original choice, not module instances,
                # so the previous ``isinstance(v, nn.Module)`` rejected valid entries.
                assert callable(v)
                choice_replace_dict[k] = v
        _replace_module_with_type(self.model, choice_replace_dict, self.nas_modules)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # You can use self.architecture_optimizers or self.user_optimizers to get optimizers in
        # your own training step.
        return self.model.training_step(batch, batch_idx)

    def configure_optimizers(self):
        """
        Combine architecture optimizers and user's model optimizers.
        You can overwrite ``configure_architecture_optimizers`` if architecture optimizers are needed in your NAS algorithm.

        By now ``self.model`` is currently a :class:`nni.retiarii.evaluator.pytorch.lightning._SupervisedLearningModule`
        and it only returns 1 optimizer. But for extendibility, codes for other return value types are also implemented.
        """
        # pylint: disable=assignment-from-none
        arc_optimizers = self.configure_architecture_optimizers()
        if arc_optimizers is None:
            return self.model.configure_optimizers()

        if isinstance(arc_optimizers, optim.Optimizer):
            arc_optimizers = [arc_optimizers]
        self.arc_optim_count = len(arc_optimizers)

        # The return values ``frequency`` and ``monitor`` are ignored because lightning requires
        # ``len(optimizers) == len(frequency)``, and gradient backward is handled manually.
        # For data structure of variables below, please see pytorch lightning docs of ``configure_optimizers``.
        w_optimizers, lr_schedulers, self.frequencies, monitor = \
            self.trainer._configure_optimizers(self.model.configure_optimizers())
        lr_schedulers = self.trainer._configure_schedulers(lr_schedulers, monitor, not self.automatic_optimization)
        if any(sch["scheduler"].optimizer not in w_optimizers for sch in lr_schedulers):
            raise Exception(
                "Some schedulers are attached with an optimizer that wasn't returned from `configure_optimizers`."
            )

        # variables used to handle optimizer frequency
        self.cur_optimizer_step = 0
        self.cur_optimizer_index = 0

        # Architecture optimizers come first so ``self.optimizers()[:arc_optim_count]``
        # are the architecture ones (see ``architecture_optimizers`` below).
        return arc_optimizers + w_optimizers, lr_schedulers

    def on_train_start(self):
        # Expose the trainer and the logging entry point to the wrapped model so its
        # own hooks/``self.log`` calls keep working.
        self.model.trainer = self.trainer
        self.model.log = self.log
        return self.model.on_train_start()

    def on_train_end(self):
        return self.model.on_train_end()

    def on_fit_start(self):
        # Fix: previously delegated to ``on_train_start``/``on_train_end``, which
        # double-fired the train hooks and never fired the fit hooks.
        return self.model.on_fit_start()

    def on_fit_end(self):
        return self.model.on_fit_end()

    def on_train_batch_start(self, batch, batch_idx, unused=0):
        return self.model.on_train_batch_start(batch, batch_idx, unused)

    def on_train_batch_end(self, outputs, batch, batch_idx, unused=0):
        return self.model.on_train_batch_end(outputs, batch, batch_idx, unused)

    def on_epoch_start(self):
        return self.model.on_epoch_start()

    def on_epoch_end(self):
        return self.model.on_epoch_end()

    def on_train_epoch_start(self):
        return self.model.on_train_epoch_start()

    def on_train_epoch_end(self):
        return self.model.on_train_epoch_end()

    def on_before_backward(self, loss):
        return self.model.on_before_backward(loss)

    def on_after_backward(self):
        return self.model.on_after_backward()

    def configure_gradient_clipping(self, optimizer, optimizer_idx, gradient_clip_val=None, gradient_clip_algorithm=None):
        return self.model.configure_gradient_clipping(optimizer, optimizer_idx, gradient_clip_val, gradient_clip_algorithm)

    def configure_architecture_optimizers(self):
        """
        Hook kept for subclasses. A specific NAS method inheriting this base class should return its architecture optimizers
        here if architecture parameters are needed. Note that lr schedulers are not supported now for architecture_optimizers.

        Returns
        ----------
        arc_optimizers : List[Optimizer], Optimizer
            Optimizers used by a specific NAS algorithm. Return None if no architecture optimizers are needed.
        """
        return None

    @property
    def default_replace_dict(self):
        """
        Default xxxChoice replace dict. This is called in ``__init__`` to get the default replace functions for your NAS
        algorithm. Note that your default replace functions may be overridden by user-defined custom_replace_dict.

        Returns
        ----------
        replace_dict : Dict[Type, Callable[nn.Module, nn.Module]]
            Same as ``custom_replace_dict`` in ``__init__``, but this will be overridden if users define their own
            replace functions.
        """
        replace_dict = {}
        return replace_dict

    def call_lr_schedulers(self, batch_index):
        """
        Function that imitates lightning trainer's behaviour of calling user's lr schedulers. Since auto_optimization is
        turned off by this class, you can use this function to make schedulers behave as they were automatically handled
        by the lightning trainer.

        Parameters
        ----------
        batch_index : int
            batch index
        """
        def apply(lr_scheduler):
            # single scheduler is called every epoch
            if isinstance(lr_scheduler, _LRScheduler) and \
                    self.trainer.is_last_batch:
                # Fix: step the scheduler passed to this closure. The previous
                # ``lr_schedulers.step()`` stepped the outer variable, which raises
                # AttributeError when ``self.lr_schedulers()`` returns a list.
                lr_scheduler.step()
            # lr_scheduler_config is called as configured
            elif isinstance(lr_scheduler, dict):
                interval = lr_scheduler['interval']
                frequency = lr_scheduler['frequency']
                if (
                    interval == 'step' and
                    batch_index % frequency == 0
                ) or \
                (
                    interval == 'epoch' and
                    self.trainer.is_last_batch and
                    (self.trainer.current_epoch + 1) % frequency == 0
                ):
                    lr_scheduler.step()

        lr_schedulers = self.lr_schedulers()

        if isinstance(lr_schedulers, list):
            for lr_scheduler in lr_schedulers:
                apply(lr_scheduler)
        else:
            apply(lr_schedulers)

    def call_user_optimizers(self, method):
        """
        Function that imitates lightning trainer's behaviour of calling user's optimizers. Since auto_optimization is turned
        off by this class, you can use this function to make user optimizers behave as they were automatically handled by
        the lightning trainer.

        Parameters
        ----------
        method : str
            Method to call. Only 'step' and 'zero_grad' are supported now.
        """
        def apply_method(optimizer, method):
            if method == 'step':
                optimizer.step()
            elif method == 'zero_grad':
                optimizer.zero_grad()

        optimizers = self.user_optimizers
        if optimizers is None:
            return
        if len(self.frequencies) > 0:
            # Round-robin over optimizers respecting the per-optimizer frequency
            # returned by the trainer in ``configure_optimizers``.
            self.cur_optimizer_step += 1
            if self.frequencies[self.cur_optimizer_index] == self.cur_optimizer_step:
                self.cur_optimizer_step = 0
                self.cur_optimizer_index = self.cur_optimizer_index + 1 \
                    if self.cur_optimizer_index + 1 < len(optimizers) \
                    else 0
            apply_method(optimizers[self.cur_optimizer_index], method)
        else:
            for optimizer in optimizers:
                apply_method(optimizer, method)

    @property
    def architecture_optimizers(self):
        """
        Get architecture optimizers from all optimizers. Use this to get your architecture optimizers in ``training_step``.

        Returns
        ----------
        opts : List[Optimizer], Optimizer, None
            Architecture optimizers defined in ``configure_architecture_optimizers``. This will be None if there is no
            architecture optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            arc_opts = opts[:self.arc_optim_count]
            if len(arc_opts) == 1:
                arc_opts = arc_opts[0]
            return arc_opts
        # If there is only 1 optimizer and it is the architecture optimizer
        if self.arc_optim_count == 1:
            return opts
        return None

    @property
    def user_optimizers(self):
        """
        Get user optimizers from all optimizers. Use this to get user optimizers in ``training_step``.

        Returns
        ----------
        opts : List[Optimizer], Optimizer, None
            Optimizers defined by user's model. This will be None if there is no user optimizers.
        """
        opts = self.optimizers()
        if isinstance(opts, list):
            # pylint: disable=unsubscriptable-object
            return opts[self.arc_optim_count:]
        # If there is only 1 optimizer and no architecture optimizer
        if self.arc_optim_count == 0:
            return opts
        return None

    def export(self):
        """
        Export the NAS result, ideally the best choice of each nas_modules.
        You may implement an ``export`` method for your customized nas_module.

        Returns
        --------
        result : Dict[str, int]
            Keys are names of nas_modules, and values are the choice indices of them.
        """
        result = {}
        for name, module in self.nas_modules:
            # modules sharing a label share one decision; export it once
            if name not in result:
                result[name] = module.export()
        return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from collections import OrderedDict
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from .base_lightning import BaseOneShotLightningModule
class DartsLayerChoice(nn.Module):
    """
    DARTS-style replacement of ``LayerChoice``. Every candidate op runs on each
    forward and the outputs are blended with softmax weights over the
    architecture parameter ``alpha``.

    ``alpha`` is deliberately hidden from :meth:`parameters` /
    :meth:`named_parameters` so that the model-weight optimizer does not train
    it; the architecture optimizer reads the attribute directly.

    Parameters
    ----------
    layer_choice : LayerChoice
        The choice to replace; candidate names/modules and the label are copied from it.
    """
    def __init__(self, layer_choice):
        super(DartsLayerChoice, self).__init__()
        self.name = layer_choice.label
        self.op_choices = nn.ModuleDict(OrderedDict([(name, layer_choice[name]) for name in layer_choice.names]))
        # small random init keeps the initial softmax close to uniform
        self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)

    def forward(self, *args, **kwargs):
        op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
        return torch.sum(op_results * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self, recurse=True):
        """Yield all parameters except ``alpha``.

        Fix: accept and honor ``recurse`` like ``nn.Module.parameters``.
        """
        for _, p in self.named_parameters(recurse=recurse):
            yield p

    def named_parameters(self, prefix='', recurse=True):
        """Yield ``(name, parameter)`` pairs for all parameters except ``alpha``.

        Fix: the previous override defaulted ``recurse=False`` and then ignored it,
        diverging from the ``nn.Module`` signature; ``prefix`` and ``recurse`` are
        now forwarded, and ``alpha`` is filtered by identity so it is excluded
        regardless of prefixing.
        """
        for name, p in super(DartsLayerChoice, self).named_parameters(prefix=prefix, recurse=recurse):
            if p is self.alpha:
                continue
            yield name, p

    def export(self):
        """Return the name of the candidate with the largest ``alpha``."""
        return list(self.op_choices.keys())[torch.argmax(self.alpha).item()]
class DartsInputChoice(nn.Module):
    """
    DARTS-style replacement of ``InputChoice``. Candidate inputs are blended
    with softmax weights over the architecture parameter ``alpha``.

    ``alpha`` is hidden from :meth:`parameters` / :meth:`named_parameters` so
    the model-weight optimizer does not train it.

    Parameters
    ----------
    input_choice : InputChoice
        The choice to replace; the label, candidate count and ``n_chosen`` are copied from it.
    """
    def __init__(self, input_choice):
        super(DartsInputChoice, self).__init__()
        self.name = input_choice.label
        # small random init keeps the initial softmax close to uniform
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
        self.n_chosen = input_choice.n_chosen or 1

    def forward(self, inputs):
        inputs = torch.stack(inputs)
        alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
        return torch.sum(inputs * F.softmax(self.alpha, -1).view(*alpha_shape), 0)

    def parameters(self, recurse=True):
        """Yield all parameters except ``alpha``.

        Fix: accept and honor ``recurse`` like ``nn.Module.parameters``.
        """
        for _, p in self.named_parameters(recurse=recurse):
            yield p

    def named_parameters(self, prefix='', recurse=True):
        """Yield ``(name, parameter)`` pairs for all parameters except ``alpha``.

        Fix: the previous override defaulted ``recurse=False`` and then ignored it,
        diverging from the ``nn.Module`` signature; ``prefix`` and ``recurse`` are
        now forwarded, and ``alpha`` is filtered by identity.
        """
        for name, p in super(DartsInputChoice, self).named_parameters(prefix=prefix, recurse=recurse):
            if p is self.alpha:
                continue
            yield name, p

    def export(self):
        """Return the indices of the ``n_chosen`` candidates with the largest ``alpha``."""
        return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
class DartsModule(BaseOneShotLightningModule):
    """
    The DARTS module. Each iteration consists of 2 training phases. The phase 1 is architecture step, in which model
    parameters are frozen and the architecture parameters are trained. The phase 2 is model step, in which architecture
    parameters are frozen and model parameters are trained. See [darts] for details.

    The DARTS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.

    Reference
    ----------
    .. [darts] H. Liu, K. Simonyan, and Y. Yang, “DARTS: Differentiable Architecture Search,” presented at the
        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=S1eYHoC5FX
    """
    def training_step(self, batch, batch_idx):
        # grad manually (automatic_optimization is False on the base class)
        arc_optim = self.architecture_optimizers
        # The InterleavedTrainValDataLoader yields both train and val data in a batch
        trn_batch, val_batch = batch

        # phase 1: architecture step
        # The _resample hook is kept for some darts-based NAS methods like proxyless.
        # See code of those methods for details.
        self._resample()
        arc_optim.zero_grad()
        # NOTE(review): batch indices are doubled — presumably so the two phases
        # report distinct step indices to the user's training_step; confirm.
        arc_step_loss = self.model.training_step(val_batch, 2 * batch_idx)
        if isinstance(arc_step_loss, dict):
            arc_step_loss = arc_step_loss['loss']
        self.manual_backward(arc_step_loss)
        # let subclasses (Proxyless) convert gate gradients into alpha gradients
        self.finalize_grad()
        arc_optim.step()

        # phase 2: model step
        self._resample()
        self.call_user_optimizers('zero_grad')
        loss_and_metrics = self.model.training_step(trn_batch, 2 * batch_idx + 1)
        w_step_loss = loss_and_metrics['loss'] \
            if isinstance(loss_and_metrics, dict) else loss_and_metrics
        self.manual_backward(w_step_loss)
        self.call_user_optimizers('step')
        # schedulers are stepped manually since auto optimization is off
        self.call_lr_schedulers(batch_idx)

        return loss_and_metrics

    def _resample(self):
        # Note: This hook is kept for following darts-based NAS algs.
        pass

    def finalize_grad(self):
        # Note: This hook is currently kept for Proxyless NAS.
        pass

    @property
    def default_replace_dict(self):
        # Each DARTS choice carries a trainable ``alpha`` architecture parameter.
        return {
            LayerChoice : DartsLayerChoice,
            InputChoice : DartsInputChoice
        }

    def configure_architecture_optimizers(self):
        # The alpha in DartsXXXChoices is the architecture parameter of DARTS. All alphas share one optimizer.
        ctrl_params = {}
        for _, m in self.nas_modules:
            if m.name in ctrl_params:
                # choices sharing a label share a single alpha tensor
                assert m.alpha.size() == ctrl_params[m.name].size(), 'Size of parameters with the same label should be same.'
                m.alpha = ctrl_params[m.name]
            else:
                ctrl_params[m.name] = m.alpha
        ctrl_optim = torch.optim.Adam(list(ctrl_params.values()), 3.e-4, betas=(0.5, 0.999),
                                      weight_decay=1.0E-3)
        return ctrl_optim
class _ArchGradientFunction(torch.autograd.Function):
    """
    Custom autograd function used by the Proxyless modules. The forward pass
    runs ``run_func`` on a detached copy of the input; the backward pass
    combines the normal input gradient (re-differentiated through the inner
    graph) with gate gradients computed by ``backward_func``.
    """
    @staticmethod
    def forward(ctx, x, binary_gates, run_func, backward_func):
        ctx.run_func = run_func
        ctx.backward_func = backward_func

        # detach so the inner forward builds its own small graph that can be
        # re-differentiated in backward without touching the outer graph
        detached_x = x.detach()
        detached_x.requires_grad = x.requires_grad
        with torch.enable_grad():
            output = run_func(detached_x)
        ctx.save_for_backward(detached_x, output)
        return output.data

    @staticmethod
    def backward(ctx, grad_output):
        detached_x, output = ctx.saved_tensors
        # gradient w.r.t. the input, through the inner graph saved in forward
        grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
        # compute gradients w.r.t. binary_gates
        binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
        # run_func/backward_func themselves receive no gradient (None, None)
        return grad_x[0], binary_grads, None, None
class ProxylessLayerChoice(nn.Module):
    """
    Proxyless version of ``LayerChoice``. In training only the sampled op runs;
    gradients for the architecture distribution ``alpha`` are estimated through
    the one-hot gate ``_binary_gates`` via :class:`_ArchGradientFunction`.

    Parameters
    ----------
    ops : Iterable[nn.Module]
        Candidate operations.
    """
    def __init__(self, ops):
        super(ProxylessLayerChoice, self).__init__()
        self.ops = nn.ModuleList(ops)
        self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
        self.sampled = None  # index of the currently sampled op, set by resample()

    def forward(self, *args, **kwargs):
        if self.training:
            def run_function(ops, active_id, **kwargs):
                # forward runs only the sampled op
                def forward(_x):
                    return ops[active_id](_x, **kwargs)
                return forward

            def backward_function(ops, active_id, binary_gates, **kwargs):
                def backward(_x, _output, grad_output):
                    # gate gradient for op k is <output_k, grad_output>; the
                    # non-sampled ops are re-run without grad just for this
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(len(ops)):
                            if k != active_id:
                                out_k = ops[k](_x.data, **kwargs)
                            else:
                                out_k = _output.data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward

            assert len(args) == 1
            x = args[0]
            return _ArchGradientFunction.apply(
                x, self._binary_gates, run_function(self.ops, self.sampled, **kwargs),
                backward_function(self.ops, self.sampled, self._binary_gates, **kwargs)
            )
        # Fix: evaluation previously fell through to ``super().forward`` — i.e.
        # ``nn.Module.forward``, which raises NotImplementedError. Run the
        # currently sampled op instead (call ``resample`` before evaluating).
        return self.ops[self.sampled](*args, **kwargs)

    def resample(self):
        """Sample one op from softmax(alpha) and set its binary gate to 1."""
        probs = F.softmax(self.alpha, dim=-1)
        sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0

    def finalize_grad(self):
        """Translate accumulated gate gradients into gradients on ``alpha``."""
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            for i in range(len(self.ops)):
                for j in range(len(self.ops)):
                    # d(prob_j)/d(alpha_i) = prob_j * (delta_ij - prob_i)
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])

    def export(self):
        """Return the index of the op with the largest ``alpha``."""
        return torch.argmax(self.alpha).item()

    def export_prob(self):
        """Return the softmax distribution over ops."""
        return F.softmax(self.alpha, dim=-1)
class ProxylessInputChoice(nn.Module):
    """
    Proxyless version of ``InputChoice``. ``alpha`` holds the architecture
    distribution over input candidates; ``_binary_gates`` is the one-hot gate
    that receives gradients through :class:`_ArchGradientFunction`.

    Parameters
    ----------
    input_choice : InputChoice
        The choice to replace; only ``n_candidates`` is read from it.
    """
    def __init__(self, input_choice):
        super().__init__()
        self.num_input_candidates = input_choice.n_candidates
        self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
        self._binary_gates = nn.Parameter(torch.randn(input_choice.n_candidates) * 1E-3)
        self.sampled = None  # index of the currently sampled input, set by resample()

    def forward(self, inputs):
        if self.training:
            def run_function(active_sample):
                # forward passes through only the sampled input
                return lambda x: x[active_sample]

            def backward_function(binary_gates):
                def backward(_x, _output, grad_output):
                    # gate gradient for input k is <input_k, grad_output>
                    binary_grads = torch.zeros_like(binary_gates.data)
                    with torch.no_grad():
                        for k in range(self.num_input_candidates):
                            out_k = _x[k].data
                            grad_k = torch.sum(out_k * grad_output)
                            binary_grads[k] = grad_k
                    return binary_grads
                return backward

            inputs = torch.stack(inputs, 0)
            return _ArchGradientFunction.apply(
                inputs, self._binary_gates, run_function(self.sampled),
                backward_function(self._binary_gates)
            )
        # Fix: evaluation previously fell through to ``super().forward`` — i.e.
        # ``nn.Module.forward``, which raises NotImplementedError. Forward the
        # currently sampled input instead (call ``resample`` before evaluating).
        return inputs[self.sampled]

    def resample(self, sample=None):
        """Sample an input index (or use the given one) and set its binary gate to 1."""
        if sample is None:
            probs = F.softmax(self.alpha, dim=-1)
            sample = torch.multinomial(probs, 1)[0].item()
        self.sampled = sample
        with torch.no_grad():
            self._binary_gates.zero_()
            self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
            self._binary_gates.data[sample] = 1.0
        return self.sampled

    def finalize_grad(self):
        """Translate accumulated gate gradients into gradients on ``alpha``."""
        binary_grads = self._binary_gates.grad
        with torch.no_grad():
            if self.alpha.grad is None:
                self.alpha.grad = torch.zeros_like(self.alpha.data)
            probs = F.softmax(self.alpha, dim=-1)
            for i in range(self.num_input_candidates):
                for j in range(self.num_input_candidates):
                    # d(prob_j)/d(alpha_i) = prob_j * (delta_ij - prob_i)
                    self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
class ProxylessModule(DartsModule):
    """
    The Proxyless Module. A DARTS-based method that resamples one active path per
    step to reduce memory consumption.

    The Proxyless Module should be trained with
    :class:`nni.retiarii.oneshot.pytorch.utils.InterleavedTrainValDataLoader`.

    Reference
    ----------
    .. [proxyless] H. Cai, L. Zhu, and S. Han, “ProxylessNAS: Direct Neural Architecture Search on Target Task and Hardware,”
        presented at the International Conference on Learning Representations, Sep. 2018.
        Available: https://openreview.net/forum?id=HylVB3AqYm
    """
    @property
    def default_replace_dict(self):
        return {
            LayerChoice: ProxylessLayerChoice,
            InputChoice: ProxylessInputChoice
        }

    def configure_architecture_optimizers(self):
        # one Adam optimizer over every choice's alpha parameters
        alphas = [module.alpha for _, module in self.nas_modules]
        return torch.optim.Adam(alphas, 3.e-4, weight_decay=0, betas=(0, 0.999), eps=1e-8)

    def _resample(self):
        # draw a fresh active path in every replaced choice
        for _, module in self.nas_modules:
            module.resample()

    def finalize_grad(self):
        # turn accumulated binary-gate gradients into alpha gradients
        for _, module in self.nas_modules:
            module.finalize_grad()
class SNASLayerChoice(DartsLayerChoice):
    """DARTS layer choice whose mixing weights are a gumbel-softmax sample instead of a plain softmax."""
    def forward(self, *args, **kwargs):
        # Sample a (nearly) one-hot weight vector at the current temperature
        # (``self.temp`` is pushed in by SNASModule.on_epoch_start).
        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
        candidate_outputs = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
        broadcast_shape = [-1] + [1] * (candidate_outputs.dim() - 1)
        return (candidate_outputs * self.one_hot.view(*broadcast_shape)).sum(0)
class SNASInputChoice(DartsInputChoice):
    """DARTS input choice whose mixing weights are a gumbel-softmax sample instead of a plain softmax."""
    def forward(self, inputs):
        # Sample a (nearly) one-hot weight vector at the current temperature
        # (``self.temp`` is pushed in by SNASModule.on_epoch_start).
        self.one_hot = F.gumbel_softmax(self.alpha, self.temp)
        stacked_inputs = torch.stack(inputs)
        broadcast_shape = [-1] + [1] * (stacked_inputs.dim() - 1)
        return (stacked_inputs * self.one_hot.view(*broadcast_shape)).sum(0)
class SNASModule(DartsModule):
    """
    The SNAS Module. A DARTS-based method that uses gumble-softmax to simulate a
    one-hot distribution over candidates.

    The SNAS Module should be trained with :class:`nni.retiarii.oneshot.utils.InterleavedTrainValDataLoader`.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model
        will be wrapped by this model.
    gumble_temperature : float
        The initial temperature used in gumble-softmax.
    use_temp_anneal : bool
        True: a linear annealing will be applied to gumble_temperature. False: run at a fixed temperature.
        See [snas] for details.
    min_temp : float
        The minimal temperature for annealing. No need to set this if you set ``use_temp_anneal`` False.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``.
        This custom replace dict will override the default replace dict of each NAS method.

    Reference
    ----------
    .. [snas] S. Xie, H. Zheng, C. Liu, and L. Lin, “SNAS: stochastic neural architecture search,” presented at the
        International Conference on Learning Representations, Sep. 2018. Available: https://openreview.net/forum?id=rylqooRqK7
    """
    def __init__(self, base_model, gumble_temperature = 1., use_temp_anneal = False,
                 min_temp = .33, custom_replace_dict=None):
        super().__init__(base_model, custom_replace_dict)
        self.temp = gumble_temperature
        self.init_temp = gumble_temperature
        self.use_temp_anneal = use_temp_anneal
        self.min_temp = min_temp

    def on_epoch_start(self):
        if self.use_temp_anneal:
            # decay linearly from init_temp towards min_temp over max_epochs, clamped below
            progress = self.trainer.current_epoch / self.trainer.max_epochs
            self.temp = (1 - progress) * (self.init_temp - self.min_temp) + self.min_temp
            self.temp = max(self.temp, self.min_temp)

        # push the current temperature into every replaced choice module
        for _, nas_module in self.nas_modules:
            nas_module.temp = self.temp

        return self.model.on_epoch_start()

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: SNASLayerChoice,
            InputChoice: SNASInputChoice
        }
...@@ -145,7 +145,7 @@ class ReinforceController(nn.Module): ...@@ -145,7 +145,7 @@ class ReinforceController(nn.Module):
else: else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device) self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
sampled = sampled.detach().numpy().tolist() sampled = sampled.detach().cpu().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob) self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy) self.sample_entropy += self.entropy_reduction(entropy)
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import random
import torch
import torch.nn as nn
import torch.optim as optim
from nni.retiarii.nn.pytorch.api import LayerChoice, InputChoice
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .base_lightning import BaseOneShotLightningModule
from .enas import ReinforceController, ReinforceField
class EnasModule(BaseOneShotLightningModule):
    """
    The ENAS module. There are 2 steps in an epoch. 1: training model parameters. 2: training ENAS RL agent. The agent
    will produce a sample of model architecture to get the best reward.

    The ENASModule should be trained with :class:`nni.retiarii.oneshot.utils.ConcatenateTrainValDataloader`.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model
        will be wrapped by this model.
    ctrl_kwargs : dict
        Optional kwargs that will be passed to :class:`ReinforceController`.
    entropy_weight : float
        Weight of sample entropy loss.
    skip_weight : float
        Weight of skip penalty loss.
    baseline_decay : float
        Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
    ctrl_steps_aggregate : int
        Number of steps that will be aggregated into one mini-batch for RL controller.
    grad_clip : float
        Gradient clipping value.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``.
        This custom replace dict will override the default replace dict of each NAS method.

    Reference
    ----------
    .. [enas] H. Pham, M. Guan, B. Zoph, Q. Le, and J. Dean, “Efficient Neural Architecture Search via Parameters Sharing,”
        in Proceedings of the 35th International Conference on Machine Learning, Jul. 2018, pp. 4095-4104.
        Available: https://proceedings.mlr.press/v80/pham18a.html
    """
    def __init__(self, base_model, ctrl_kwargs = None,
                 entropy_weight = 1e-4, skip_weight = .8, baseline_decay = .999,
                 ctrl_steps_aggregate = 20, grad_clip = 0, custom_replace_dict = None):
        super().__init__(base_model, custom_replace_dict)
        # One ReinforceField per replaced choice so the controller knows how many
        # candidates each decision has. LayerChoice (and InputChoice with
        # n_chosen == 1) selects exactly one candidate.
        self.nas_fields = [ReinforceField(name, len(module),
                                          isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
                           for name, module in self.nas_modules]
        self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))

        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        # moving-average reward baseline, updated in training_step
        self.baseline = 0.
        self.ctrl_steps_aggregate = ctrl_steps_aggregate
        self.grad_clip = grad_clip

    def configure_architecture_optimizers(self):
        # the RL controller holds all architecture parameters
        return optim.Adam(self.controller.parameters(), lr=3.5e-4)

    @property
    def default_replace_dict(self):
        # path-sampling choices run only the currently sampled candidate
        return {
            LayerChoice : PathSamplingLayerChoice,
            InputChoice : PathSamplingInputChoice
        }

    def training_step(self, batch, batch_idx):
        # The ConcatenateTrainValDataloader yields both data and which dataloader it comes from.
        batch, source = batch

        if source == 'train':
            # step 1: train model params
            self._resample()
            self.call_user_optimizers('zero_grad')
            loss_and_metrics = self.model.training_step(batch, batch_idx)
            w_step_loss = loss_and_metrics['loss'] \
                if isinstance(loss_and_metrics, dict) else loss_and_metrics
            self.manual_backward(w_step_loss)
            self.call_user_optimizers('step')
            return loss_and_metrics

        if source == 'val':
            # step 2: train ENAS agent
            x, y = batch
            arc_opt = self.architecture_optimizers
            arc_opt.zero_grad()
            self._resample()
            # reward is computed on a frozen model; no weight gradients needed
            with torch.no_grad():
                logits = self.model(x)
            # use the default metric of self.model as reward function
            if len(self.model.metrics) == 1:
                _, metric = next(iter(self.model.metrics.items()))
            else:
                if 'default' not in self.model.metrics.keys():
                    raise KeyError('model.metrics should contain a ``default`` key when' \
                        'there are multiple metrics')
                metric = self.model.metrics['default']
            reward = metric(logits, y)
            # entropy bonus encourages exploration
            if self.entropy_weight:
                reward = reward + self.entropy_weight * self.controller.sample_entropy.item()
            # moving-average baseline reduces REINFORCE gradient variance
            self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
            rnn_step_loss = self.controller.sample_log_prob * (reward - self.baseline)
            if self.skip_weight:
                rnn_step_loss = rnn_step_loss + self.skip_weight * self.controller.sample_skip_penalty

            # gradients are accumulated over ctrl_steps_aggregate mini-batches
            rnn_step_loss = rnn_step_loss / self.ctrl_steps_aggregate
            self.manual_backward(rnn_step_loss)

            if (batch_idx + 1) % self.ctrl_steps_aggregate == 0:
                if self.grad_clip > 0:
                    nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
                arc_opt.step()
                arc_opt.zero_grad()

    def _resample(self):
        """
        Resample the architecture as ENAS result. This doesn't require an ``export`` method in nas_modules to work.
        """
        result = self.controller.resample()
        for name, module in self.nas_modules:
            module.sampled = result[name]

    def export(self):
        """
        Export the architecture sampled by the trained controller (in eval mode, without gradients).
        """
        self.controller.eval()
        with torch.no_grad():
            return self.controller.resample()
class RandomSampleModule(BaseOneShotLightningModule):
    """
    Random Sampling NAS algorithm. On every training step a uniformly random
    architecture is sampled and the model parameters are trained under it. The
    exported result is likewise a random sample of the search space.

    Parameters
    ----------
    base_model : pl.LightningModule
        The evaluator in ``nni.retiarii.evaluator.lightning``. User defined model is wrapped by base_model, and base_model
        will be wrapped by this model.
    custom_replace_dict : Dict[Type[nn.Module], Callable[[nn.Module], nn.Module]], default = None
        The custom xxxChoice replace method. Keys should be xxxChoice type and values should return an ``nn.module``.
        This custom replace dict will override the default replace dict of each NAS method.
    """
    # No architecture optimizer exists, so lightning manages optimization itself.
    automatic_optimization = True

    def training_step(self, batch, batch_idx):
        # draw a fresh random architecture before every weight update
        self._resample()
        return self.model.training_step(batch, batch_idx)

    @property
    def default_replace_dict(self):
        return {
            LayerChoice: PathSamplingLayerChoice,
            InputChoice: PathSamplingInputChoice
        }

    def _resample(self):
        """
        Resample the architecture as RandomSample result. This is simply a uniform sampling that doesn't require an
        ``export`` method in nas_modules to work. Choices sharing a label receive the same sample.
        """
        sampled = {}
        for label, module in self.nas_modules:
            if label not in sampled:
                sampled[label] = random.randint(0, len(module) - 1)
            module.sampled = sampled[label]
        return sampled

    def export(self):
        return self._resample()
...@@ -6,6 +6,7 @@ from collections import OrderedDict ...@@ -6,6 +6,7 @@ from collections import OrderedDict
import numpy as np import numpy as np
import torch import torch
from torch.utils.data import DataLoader
import nni.retiarii.nn.pytorch as nn import nni.retiarii.nn.pytorch as nn
from nni.nas.pytorch.mutables import InputChoice, LayerChoice from nni.nas.pytorch.mutables import InputChoice, LayerChoice
...@@ -127,7 +128,6 @@ class AverageMeter: ...@@ -127,7 +128,6 @@ class AverageMeter:
def _replace_module_with_type(root_module, init_fn, type_name, modules): def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None: if modules is None:
modules = [] modules = []
def apply(m): def apply(m):
for name, child in m.named_children(): for name, child in m.named_children():
if isinstance(child, type_name): if isinstance(child, type_name):
...@@ -180,3 +180,119 @@ def replace_input_choice(root_module, init_fn, modules=None): ...@@ -180,3 +180,119 @@ def replace_input_choice(root_module, init_fn, modules=None):
A list from layer choice keys (names) and replaced modules. A list from layer choice keys (names) and replaced modules.
""" """
return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules) return _replace_module_with_type(root_module, init_fn, (InputChoice, nn.InputChoice), modules)
class InterleavedTrainValDataLoader(DataLoader):
    """
    Dataloader that yields both train data and validation data in a batch, with an order of
    ``(train_batch, val_batch)``. The shorter one will be upsampled (repeated) to the length of
    the longer one, and the tail of the last repeat will be dropped. This enables users to train
    both model parameters and architecture parameters in parallel in an epoch.
    Some NAS algorithms, e.g. DARTS and Proxyless, require this type of dataloader.

    Parameters
    ----------
    train_dataloader : DataLoader
        training dataloader
    val_dataloader : DataLoader
        validation dataloader

    Example
    --------
    Fit your dataloaders into a parallel one.
    >>> para_loader = InterleavedTrainValDataLoader(train_dataloader, val_dataloader)
    Then you can use the ``para_loader`` as a normal training loader.
    """
    def __init__(self, train_dataloader, val_dataloader):
        self.train_loader = train_dataloader
        self.val_loader = val_dataloader
        # cache the length relationship once so __next__ can cheaply decide
        # which side terminates the epoch and which side gets upsampled
        self.equal_len = len(train_dataloader) == len(val_dataloader)
        self.train_longer = len(train_dataloader) > len(val_dataloader)
        super().__init__(None)

    def __iter__(self):
        self.train_iter = iter(self.train_loader)
        self.val_iter = iter(self.val_loader)
        return self

    def __next__(self):
        try:
            train_batch = next(self.train_iter)
        except StopIteration:
            # training data is used up
            if self.equal_len or self.train_longer:
                # if training is the longer one or equal, stop iteration
                raise StopIteration()
            # if training is the shorter one, upsample it (restart its iterator)
            self.train_iter = iter(self.train_loader)
            train_batch = next(self.train_iter)
        try:
            val_batch = next(self.val_iter)
        except StopIteration:
            # validation data is used up
            if not self.train_longer:
                # if validation is the longer one (the equal case is
                # covered above), stop iteration
                raise StopIteration()
            # if validation is the shorter one, upsample it
            self.val_iter = iter(self.val_loader)
            val_batch = next(self.val_iter)
        return train_batch, val_batch

    def __len__(self) -> int:
        # one epoch spans the longer of the two loaders
        return max(len(self.train_loader), len(self.val_loader))
class ConcatenateTrainValDataLoader(DataLoader):
    """
    Dataloader that yields validation data after training data in an epoch. You will get a batch
    with the form of ``(batch, source)`` in the training step, where ``source`` is a string which
    is either ``'train'`` or ``'val'``, indicating which dataloader the batch comes from. This
    enables users to train model parameters first in an epoch, and then train architecture
    parameters.
    Some NAS algorithms, e.g. ENAS, may require this type of dataloader.

    Parameters
    ----------
    train_dataloader : DataLoader
        training dataloader
    val_dataloader : DataLoader
        validation dataloader

    Warnings
    ----------
    If you set ``limit_train_batches`` of the trainer, the validation batches may be skipped.
    Consider downsampling the train dataset and the validation dataset instead if you want to
    shorten the length of data.

    Example
    --------
    Fit your dataloaders into a concatenated one.
    >>> concat_loader = ConcatenateTrainValDataLoader(train_dataloader, val_dataloader)
    Then you can use the ``concat_loader`` as a normal training loader.
    """
    def __init__(self, train_dataloader, val_dataloader):
        self.train_loader = train_dataloader
        self.val_loader = val_dataloader
        super().__init__(None)

    def __iter__(self):
        # every epoch starts over from the training loader
        self.cur_iter = iter(self.train_loader)
        self.source = 'train'
        return self

    def __next__(self):
        try:
            batch = next(self.cur_iter)
        except StopIteration:
            # training data is used up, change to validation data
            if self.source == 'train':
                self.cur_iter = iter(self.val_loader)
                self.source = 'val'
                return next(self)
            raise StopIteration()
        else:
            return batch, self.source

    def __len__(self):
        return len(self.train_loader) + len(self.val_loader)
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import pytorch_lightning as pl
import pytest
from torchvision import transforms
from torchvision.datasets import MNIST
from torch.utils.data.sampler import RandomSampler
from nni.retiarii.evaluator.pytorch.lightning import Classification, DataLoader
from nni.retiarii.nn.pytorch import LayerChoice, InputChoice
from nni.retiarii.oneshot.pytorch import (ConcatenateTrainValDataLoader,
DartsModule, EnasModule, SNASModule,
InterleavedTrainValDataLoader,
ProxylessModule, RandomSampleModule)
class DepthwiseSeparableConv(nn.Module):
    """A 3x3 depthwise convolution followed by a 1x1 pointwise convolution."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # one 3x3 filter per input channel (groups == in_ch), no padding
        self.depthwise = nn.Conv2d(in_ch, in_ch, kernel_size=3, groups=in_ch)
        # 1x1 convolution mixes channels and projects to ``out_ch``
        self.pointwise = nn.Conv2d(in_ch, out_ch, kernel_size=1)

    def forward(self, x):
        depthwise_out = self.depthwise(x)
        return self.pointwise(depthwise_out)
class Net(pl.LightningModule):
    """Small MNIST classifier with layer/input choices, used to exercise one-shot NAS."""

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        # choose between a plain conv and a depthwise-separable conv
        self.conv2 = LayerChoice([
            nn.Conv2d(32, 64, 3, 1),
            DepthwiseSeparableConv(32, 64),
        ])
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        # pick exactly one of the two dropout branches
        self.dropout_choice = InputChoice(2, 1)
        # choose the hidden width of the classification head
        self.fc = LayerChoice([
            nn.Sequential(nn.Linear(9216, hidden), nn.ReLU(), nn.Linear(hidden, 10))
            for hidden in (64, 128, 256)
        ])
        self.rpfc = nn.Linear(10, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.max_pool2d(self.conv2(x), 2)
        branch1 = torch.flatten(self.dropout1(x), 1)
        branch2 = torch.flatten(self.dropout2(x), 1)
        x = self.dropout_choice([branch1, branch2])
        x = self.rpfc(self.fc(x))
        return F.log_softmax(x, dim=1)
def prepare_model_data():
    """Build the base model, MNIST train/val loaders, and trainer kwargs for the tests.

    Returns
    -------
    tuple
        ``(base_model, train_loader, valid_loader, trainer_kwargs)``.

    Note: the previous ``pytest.mark.skipif`` decorator on this helper had no effect,
    because marks only apply to collected test items; the version guard stays on the
    actual test functions.
    """
    base_model = Net()
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    # draw 10% random subsets (with replacement) of each split to keep the run fast
    train_dataset = MNIST('data/mnist', train=True, download=True, transform=transform)
    train_random_sampler = RandomSampler(train_dataset, True, int(len(train_dataset) / 10))
    train_loader = DataLoader(train_dataset, 64, sampler=train_random_sampler)
    valid_dataset = MNIST('data/mnist', train=False, download=True, transform=transform)
    valid_random_sampler = RandomSampler(valid_dataset, True, int(len(valid_dataset) / 10))
    valid_loader = DataLoader(valid_dataset, 64, sampler=valid_random_sampler)
    trainer_kwargs = {
        'max_epochs': 1,
    }
    return base_model, train_loader, valid_loader, trainer_kwargs
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_darts():
    """DARTS trains with train/val batches interleaved within each step."""
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    darts_model = DartsModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(darts_model, para_loader)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_proxyless():
    """ProxylessNAS trains with train/val batches interleaved within each step."""
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    proxyless_model = ProxylessModule(cls.module)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(proxyless_model, para_loader)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_enas():
    """ENAS trains on all train batches first, then val batches, each epoch."""
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    enas_model = EnasModule(cls.module)
    concat_loader = ConcatenateTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(enas_model, concat_loader)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_random():
    """Random sampling needs no special dataloader; plain train/val loaders suffice."""
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    random_model = RandomSampleModule(cls.module)
    cls.trainer.fit(random_model, cls.train_dataloader, cls.val_dataloaders)
@pytest.mark.skipif(pl.__version__ < '1.0', reason='Incompatible APIs')
def test_snas():
    """SNAS (with temperature annealing) trains with interleaved train/val batches."""
    base_model, train_loader, valid_loader, trainer_kwargs = prepare_model_data()
    cls = Classification(train_dataloader=train_loader, val_dataloaders=valid_loader, **trainer_kwargs)
    cls.module.set_model(base_model)
    # local was previously misnamed ``proxyless_model`` (copy-paste leftover)
    snas_model = SNASModule(cls.module, 1, use_temp_anneal=True)
    para_loader = InterleavedTrainValDataLoader(cls.train_dataloader, cls.val_dataloaders)
    cls.trainer.fit(snas_model, para_loader)
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--exp', type=str, default='all', metavar='E',
                        help='exp to run, default = all')
    args = parser.parse_args()
    # 'all' runs every experiment in declaration order; anything else runs one
    exp_names = ['darts', 'proxyless', 'enas', 'random', 'snas'] if args.exp == 'all' else [args.exp]
    for exp_name in exp_names:
        globals()[f'test_{exp_name}']()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment