Unverified Commit cd3a912a authored by SparkSnail, committed by GitHub

Merge pull request #218 from microsoft/master

merge master
parents a0846f2a e9cba778
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
metis_tuner.py
"""
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import defaultdict
...
@@ -184,7 +167,7 @@ class MsgDispatcher(MsgDispatcherBase):
        """
        id_ = data['parameter_id']
        value = data['value']
-        if not id_ or id_ in _customized_parameter_ids:
+        if id_ is None or id_ in _customized_parameter_ids:
            if not hasattr(self.tuner, '_accept_customized'):
                self.tuner._accept_customized = False
            if not self.tuner._accept_customized:
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import threading
import logging
...
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
from nni.nas.pytorch.utils import StructuredMutableTreeNode
logger = logging.getLogger(__name__)
class BaseMutator(nn.Module):
"""
A mutator is responsible for mutating a graph: it obtains the search space from the network and implements
the callbacks that are invoked in the ``forward`` of Mutables.
"""
def __init__(self, model):
super().__init__()
self.__dict__["model"] = model
self._structured_mutables = self._parse_search_space(self.model)
def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
def forward(self, *inputs):
raise RuntimeError("Forward is undefined for mutators.")
def __setattr__(self, name, value):
if name == "model":
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include you network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope : MutableScope
"""
pass
def exit_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is exited.
Parameters
----------
mutable_scope : MutableScope
"""
pass
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callbacks of forward in LayerChoice.
Parameters
----------
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list):
"""
Callbacks of forward in InputChoice.
Parameters
----------
mutable : InputChoice
tensor_list : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def export(self):
"""
Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
network can be fully determined with these decisions for further training from scratch.
Returns
-------
dict
"""
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from abc import ABC, abstractmethod
class BaseTrainer(ABC):
@abstractmethod
def train(self):
raise NotImplementedError
@abstractmethod
def validate(self):
raise NotImplementedError
@abstractmethod
def export(self, file):
raise NotImplementedError
@abstractmethod
def checkpoint(self):
raise NotImplementedError
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import os
_logger = logging.getLogger(__name__)
class Callback:
def __init__(self):
self.model = None
self.mutator = None
self.trainer = None
def build(self, model, mutator, trainer):
self.model = model
self.mutator = mutator
self.trainer = trainer
def on_epoch_begin(self, epoch):
pass
def on_epoch_end(self, epoch):
pass
def on_batch_begin(self, epoch):
pass
def on_batch_end(self, epoch):
pass
class LRSchedulerCallback(Callback):
def __init__(self, scheduler, mode="epoch"):
super().__init__()
assert mode == "epoch"
self.scheduler = scheduler
self.mode = mode
def on_epoch_end(self, epoch):
self.scheduler.step()
class ArchitectureCheckpoint(Callback):
def __init__(self, checkpoint_dir, every="epoch"):
super().__init__()
assert every == "epoch"
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
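The callbacks above are meant to be handed to a trainer. Below is a minimal, purely illustrative sketch (not part of this commit) of how they might be wired up; the throwaway model and optimizer exist only so the scheduler can be constructed, and `LRSchedulerCallback` / `ArchitectureCheckpoint` are the classes defined just above.
import torch.nn as nn
import torch.optim as optim

# throwaway model/optimizer so a scheduler can be built
net = nn.Linear(4, 2)
optimizer = optim.SGD(net.parameters(), lr=0.05, momentum=0.9)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

callbacks = [
    LRSchedulerCallback(scheduler),          # steps the LR scheduler at the end of every epoch
    ArchitectureCheckpoint("./arch_ckpts"),  # writes epoch_<n>.json via trainer.export() after each epoch
]
# the list is then passed to a trainer, e.g. DartsTrainer(..., callbacks=callbacks)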
from .mutator import get_and_apply_next_architecture
import os
import sys
import json
import logging
import torch
import nni
from nni.env_vars import trial_env_vars
from nni.nas.pytorch.base_mutator import BaseMutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
logger = logging.getLogger(__name__)
def get_and_apply_next_architecture(model):
"""
Wrapper of ClassicMutator that gives it a more meaningful name,
analogous to ```get_next_parameter``` in HPO.
Parameters
----------
model : pytorch model
user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
"""
ClassicMutator(model)
class ClassicMutator(BaseMutator):
"""
This mutator applies the architecture chosen by the tuner.
It implements the forward functions of LayerChoice and InputChoice
so that only the chosen candidates are activated.
"""
def __init__(self, model):
"""
Generate search space based on ```model```.
If the env variable ```NNI_GEN_SEARCH_SPACE``` exists, this is a dry run that only
generates the search space for the experiment.
Otherwise there are two modes. One is NNI experiment mode, where users
start an experiment with ```nnictl``` and the architecture comes from the tuner.
The other is standalone mode, where users run the trial command directly;
this mode chooses the first candidate(s) for each LayerChoice and InputChoice.
Parameters
----------
model : pytorch model
user's model with search space (e.g., LayerChoice, InputChoice) embedded in it
"""
super(ClassicMutator, self).__init__(model)
self.chosen_arch = {}
self.search_space = self._generate_search_space()
if 'NNI_GEN_SEARCH_SPACE' in os.environ:
# dry run for only generating search space
self._dump_search_space(self.search_space, os.environ.get('NNI_GEN_SEARCH_SPACE'))
sys.exit(0)
# get chosen arch from tuner
self.chosen_arch = nni.get_next_parameter()
if not self.chosen_arch and trial_env_vars.NNI_PLATFORM is None:
logger.warning('Running in standalone mode; choosing the first candidate(s) for each mutable')
self.chosen_arch = self._standalone_generate_chosen()
self._validate_chosen_arch()
def _validate_chosen_arch(self):
pass
def _standalone_generate_chosen(self):
"""
Generate the chosen architecture for standalone mode,
i.e., choose the first one(s) for LayerChoice and InputChoice
{ key_name: {'_value': "conv1",
'_idx': 0} }
{ key_name: {'_value': ["in1"],
'_idx': [0]} }
Returns
-------
dict
the chosen architecture
"""
chosen_arch = {}
for key, val in self.search_space.items():
if val['_type'] == 'layer_choice':
choices = val['_value']
chosen_arch[key] = {'_value': choices[0], '_idx': 0}
elif val['_type'] == 'input_choice':
choices = val['_value']['candidates']
n_chosen = val['_value']['n_chosen']
chosen_arch[key] = {'_value': choices[:n_chosen], '_idx': list(range(n_chosen))}
else:
raise ValueError('Unknown key %s and value %s' % (key, val))
return chosen_arch
def _generate_search_space(self):
"""
Generate search space from mutables.
Here is the search space format:
{ key_name: {'_type': 'layer_choice',
'_value': ["conv1", "conv2"]} }
{ key_name: {'_type': 'input_choice',
'_value': {'candidates': ["in1", "in2"],
'n_chosen': 1}} }
Returns
-------
dict
the generated search space
"""
search_space = {}
for mutable in self.mutables:
# for now we only generate flattened search space
if isinstance(mutable, LayerChoice):
key = mutable.key
val = [repr(choice) for choice in mutable.choices]
search_space[key] = {"_type": "layer_choice", "_value": val}
elif isinstance(mutable, InputChoice):
key = mutable.key
search_space[key] = {"_type": "input_choice",
"_value": {"candidates": mutable.choose_from,
"n_chosen": mutable.n_chosen}}
else:
raise TypeError('Unsupported mutable type: %s.' % type(mutable))
return search_space
def _dump_search_space(self, search_space, file_path):
with open(file_path, 'w') as ss_file:
json.dump(search_space, ss_file)
def _tensor_reduction(self, reduction_type, tensor_list):
if tensor_list == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def on_forward_layer_choice(self, mutable, *inputs):
"""
Implement the forward of LayerChoice
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
tuple
return of the chosen op, the index of the chosen op
"""
assert mutable.key in self.chosen_arch
val = self.chosen_arch[mutable.key]
assert isinstance(val, dict)
idx = val['_idx']
assert self.search_space[mutable.key]['_value'][idx] == val['_value']
return mutable.choices[idx](*inputs), idx
def on_forward_input_choice(self, mutable, tensor_list):
"""
Implement the forward of InputChoice
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
Returns
-------
tuple of torch.Tensor and list
reduced tensor, mask list
"""
assert mutable.key in self.chosen_arch
val = self.chosen_arch[mutable.key]
assert isinstance(val, dict)
mask = [0 for _ in range(mutable.n_candidates)]
out = []
for i, idx in enumerate(val['_idx']):
# check whether idx matches the chosen candidate name
assert self.search_space[mutable.key]['_value']['candidates'][idx] == val['_value'][i]
out.append(tensor_list[idx])
mask[idx] = 1
return self._tensor_reduction(mutable.reduction, out), mask
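As a usage illustration (a hedged sketch, not part of this diff): a trial script embeds mutables in its model and calls get_and_apply_next_architecture before training, much like nni.get_next_parameter() in HPO. The import path nni.nas.pytorch.classic_nas is assumed from the package __init__ shown above, and the one-layer model with key "demo_layer" is purely illustrative.
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.classic_nas import get_and_apply_next_architecture  # assumed import path

model = nn.Sequential(
    LayerChoice([nn.Linear(8, 8), nn.Identity()], key="demo_layer"),
    nn.Linear(8, 2),
)
get_and_apply_next_architecture(model)   # tuner-chosen architecture, or first candidates in standalone mode
out = model(torch.randn(1, 8))           # forward now runs with the fixed choices
# ... train the model and report results with nni.report_final_result(...) as in any trial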
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import DartsMutator
from .trainer import DartsTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
_logger = logging.getLogger(__name__)
class DartsMutator(Mutator):
def __init__(self, model):
super().__init__(model)
self.choices = nn.ParameterDict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
def device(self):
for v in self.choices.values():
return v.device
def sample_search(self):
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
elif isinstance(mutable, InputChoice):
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
return result
def sample_final(self):
result = dict()
edges_max = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice) and mutable.n_chosen is not None:
weights = []
for src_key in mutable.choose_from:
if src_key not in edges_max:
_logger.warning("InputChoice.NO_KEY in '%s' is weighted 0 when selecting inputs.", mutable.key)
weights.append(edges_max.get(src_key, 0.))
weights = torch.tensor(weights) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
logger = logging.getLogger(__name__)
class DartsTrainer(Trainer):
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
n_train = len(self.dataset_train)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
self.model.train()
self.mutator.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
self.mutator.eval()
meters = AverageMeterGroup()
with torch.no_grad():
self.mutator.reset()
for step, (X, y) in enumerate(self.test_loader):
X, y = X.to(self.device), y.to(self.device)
logits = self.model(X)
metrics = self.metrics(logits, y)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.test_loader), meters)
def _logits_and_loss(self, X, y):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.optimizer.param_groups[0]["lr"]
momentum = self.optimizer.param_groups[0]["momentum"]
weight_decay = self.optimizer.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.optimizer.state[w].get("momentum_buffer", 0.)
w -= lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
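A hedged usage sketch for DartsTrainer follows (not part of this commit). It assumes the base Trainer imported above provides a train() loop that drives train_one_epoch/validate_one_epoch and handles device placement; the tiny synthetic model, metrics function, and dataset are placeholders, not part of the library.
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset
from nni.nas.pytorch.mutables import LayerChoice

def accuracy(logits, target):
    # metrics callable: maps (logits, target) to a dict of scalars
    return {"acc": (logits.argmax(dim=-1) == target).float().mean().item()}

model = nn.Sequential(
    LayerChoice([nn.Linear(8, 16), nn.Sequential(nn.Linear(8, 16), nn.ReLU())], key="fc"),
    nn.ReLU(),
    nn.Linear(16, 2),
)
data = TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,)))
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)

trainer = DartsTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                       optimizer=optimizer, num_epochs=2,
                       dataset_train=data, dataset_valid=data,
                       batch_size=32, log_frequency=10)
trainer.train()                     # assumed base-class loop alternating architecture and weight steps
trainer.export("final_arch.json")   # same export used by ArchitectureCheckpoint above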
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import EnasMutator
from .trainer import EnasTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
return next_c, next_h
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choice have the same number of candidates."
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
for choice in mutable.choices])
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def sample_search(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label)  # anchor not found for this label (e.g. NO_KEY); take an extra step to create it
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else:
assert mutable.n_chosen == 1, "InputChoice must have n_chosen set to 1 or None in ENAS."
logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
return skip.bool()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.mutator_steps_aggregate = mutator_steps_aggregate
self.mutator_steps = mutator_steps
self.aux_weight = aux_weight
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
# Sample model and train
self.model.train()
self.mutator.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
with torch.no_grad():
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y)
else:
aux_loss = 0.
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss = loss + self.aux_weight * aux_loss
loss.backward()
self.optimizer.step()
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
# Train sampler (mutator)
self.model.eval()
self.mutator.train()
meters = AverageMeterGroup()
mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
while mutator_step < total_mutator_steps:
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
metrics["reward"] = reward
metrics["loss"] = loss.item()
metrics["ent"] = self.mutator.sample_entropy.item()
metrics["baseline"] = self.baseline
metrics["skip"] = self.mutator.sample_skip_penalty
loss = loss / self.mutator_steps_aggregate
loss.backward()
meters.update(metrics)
if mutator_step % self.mutator_steps_aggregate == 0:
self.mutator_optim.step()
self.mutator_optim.zero_grad()
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
mutator_step += 1
if mutator_step >= total_mutator_steps:
break
def validate_one_epoch(self, epoch):
pass
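For reference, a hypothetical reward_function compatible with the loop above: it takes (logits, target) and should return a scalar tensor (the running baseline derived from it is later detach()-ed), for example top-1 accuracy.
import torch

def accuracy_reward(logits, target):
    # fraction of correct top-1 predictions, as a scalar tensor
    return (logits.argmax(dim=-1) == target).float().mean()

# passed as the 4th argument: EnasTrainer(model, loss, metrics, accuracy_reward, optimizer,
#                                         num_epochs, dataset_train, dataset_valid, ...)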
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import json
import torch
from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator):
def __init__(self, model, fixed_arc, strict=True):
"""
Initialize a fixed architecture mutator.
Parameters
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict : bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
self._fixed_arc = fixed_arc
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
if fixed_arc_keys - mutable_keys:
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
def sample_search(self):
return self._fixed_arc
def sample_final(self):
return self._fixed_arc
def _encode_tensor(data, device):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v, device) for k, v in data.items()}
return data
def apply_fixed_architecture(model, fixed_arc_path, device=None):
"""
Load architecture from `fixed_arc_path` and apply to model.
Parameters
----------
model : torch.nn.Module
Model with mutables.
fixed_arc_path : str
Path to the JSON that stores the architecture.
device : torch.device
Architecture weights will be transferred to `device`.
Returns
-------
FixedArchitecture
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc, device)
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
return architecture
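A hedged retraining sketch (not part of this commit), assuming this module is importable as nni.nas.pytorch.fixed. The JSON is written by hand here only to keep the sketch self-contained: a boolean mask per mutable key, the format that _encode_tensor above accepts; normally the file comes from ArchitectureCheckpoint or trainer.export(), and the model must use the same mutable keys as during the search.
import json
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.fixed import apply_fixed_architecture  # assumed import path

model = nn.Sequential(
    LayerChoice([nn.Linear(8, 8), nn.Identity()], key="demo_layer"),
    nn.Linear(8, 2),
)
# hand-written stand-in for an exported architecture: one boolean mask per LayerChoice key
with open("demo_arch.json", "w") as f:
    json.dump({"demo_layer": [True, False]}, f)

apply_fixed_architecture(model, "demo_arch.json")
out = model(torch.randn(1, 8))   # forward now always takes the first candidate
# from here on, `model` can be trained from scratch like an ordinary network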
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting
logger = logging.getLogger(__name__)
class Mutable(nn.Module):
"""
Mutable is designed to function as a normal layer, with all necessary operators' weights.
States and weights of architectures should be included in mutator, instead of the layer itself.
Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
decisions among different mutables. In mutator's implementation, mutators should use the key to
distinguish different mutables. Mutables that share the same key should be "similar" to each other.
Currently the default scope for keys is global.
"""
def __init__(self, key=None):
super().__init__()
if key is not None:
if not isinstance(key, str):
key = str(key)
logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
self._key = key
else:
self._key = self.__class__.__name__ + str(global_mutable_counting())
self.init_hook = self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if "mutator" in self.__dict__:
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
"Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
raise NotImplementedError
@property
def key(self):
return self._key
@property
def name(self):
return self._name if hasattr(self, "_name") else self._key
@name.setter
def name(self, name):
self._name = name
def _check_built(self):
if not hasattr(self, "mutator"):
raise ValueError(
"Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
"Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(Mutable):
"""
Mutable scope marks a subgraph/submodule to help mutators make better decisions.
Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
Mutable scopes are themselves mutables and are listed among the mutables (the search space).
"""
def __init__(self, key):
super().__init__(key=key)
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
self.choices = nn.ModuleList(op_candidates)
self.reduction = reduction
self.return_mask = return_mask
def forward(self, *inputs):
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
class InputChoice(Mutable):
"""
Input choice selects `n_chosen` inputs from `choose_from` (which contains `n_candidates` keys). For beginners,
using `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
learn about `choose_from`.
The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
"""
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="sum", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates : int
Number of inputs to choose from.
choose_from : list of str
List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be set.
If `n_candidates` is given but `choose_from` is None, `choose_from` defaults to `n_candidates`
copies of the empty string.
n_chosen : int
Number of inputs to choose. If None, the mutator is instructed to select any number.
reduction : str
`mean`, `concat`, `sum` or `none`.
return_mask : bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
"""
super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from` " \
"must not be None."
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
assert n_candidates > 0, "Number of candidates must be greater than 0."
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
"than number of candidates."
self.n_candidates = n_candidates
self.choose_from = choose_from
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def forward(self, optional_inputs):
"""
Forward method of InputChoice.
Parameters
----------
optional_inputs : list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
`choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
`choose_from`.
Returns
-------
tuple of torch.Tensor and torch.Tensor or torch.Tensor
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), \
"Optional input list must be a list, not a {}.".format(type(optional_input_list))
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
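To make the two mutables concrete, here is a hypothetical toy cell (not part of this commit). Constructing it works on its own; calling forward() additionally requires a mutator to be applied first, e.g. get_and_apply_next_architecture or apply_fixed_architecture shown earlier.
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice, InputChoice

class ToyCell(nn.Module):
    def __init__(self, channels=16):
        super().__init__()
        # search over two candidate operations for this layer
        self.op = LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ], key="toy_op")
        # pick one of two inputs: the raw input or the op output
        self.skip = InputChoice(n_candidates=2, n_chosen=1, key="toy_skip")

    def forward(self, x):
        out = self.op(x)
        return self.skip([x, out])

cell = ToyCell()   # mutables are registered; a mutator fixes or weights the choices before use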
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
from nni.nas.pytorch.base_mutator import BaseMutator
logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = dict()
def sample_search(self):
"""
Override this method to iterate over mutables and make decisions.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def sample_final(self):
"""
Override this method to iterate over mutables and make decisions that are final
for export and retraining.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def reset(self):
"""
Reset the mutator by calling `sample_search` to resample (for search).
Returns
-------
None
"""
self._cache = self.sample_search()
def export(self):
"""
Resample (for final) and return results.
Returns
-------
dict
"""
return self.sample_final()
def on_forward_layer_choice(self, mutable, *inputs):
"""
By default, this method retrieves the cached decision for `mutable.key` (see :meth:`_get_decision`) as a mask on
how to choose between layers (either by switch or by weights), then reduces the list of all candidate outputs
with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
"""
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._get_decision(mutable)
assert len(mask) == len(mutable.choices), \
"Invalid mask, expected {} to be of length {}.".format(mask, len(mutable.choices))
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
"""
By default, this method retrieves the cached decision for `mutable.key` (see :meth:`_get_decision`) as a mask on
how to choose between inputs (either by switch or by weights), then reduces the list of input tensors
with the policy specified in `mutable.reduction`.
Parameters
----------
mutable : InputChoice
tensor_list : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
"""
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates, \
"Invalid mask, expected {} to be of length {}.".format(mask, mutable.n_candidates)
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
else:
raise ValueError("Unrecognized mask")
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _get_decision(self, mutable):
"""
By default, this method looks up `mutable.key` in the decision cache and returns the cached result,
raising an error if the key is not found.
Parameters
----------
mutable : Mutable
Returns
-------
object
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result
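As an illustration of the contract above (a hedged sketch, not part of this diff): a custom Mutator only needs sample_search/sample_final returning one mask per mutable key; boolean masks select candidates and float masks weight them. A uniformly random mutator might look like this:
import torch
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutator import Mutator

class RandomMutator(Mutator):
    """Samples candidates uniformly at random for every mutable."""

    def sample_search(self):
        result = {}
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                idx = torch.randint(high=mutable.length, size=(1,))
                result[mutable.key] = F.one_hot(idx, num_classes=mutable.length).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                n_chosen = mutable.n_chosen if mutable.n_chosen is not None else 1
                mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                mask[torch.randperm(mutable.n_candidates)[:n_chosen]] = True
                result[mutable.key] = mask
        return result

    def sample_final(self):
        return self.sample_search()

# usage: mutator = RandomMutator(model); mutator.reset(); output = model(x)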
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import PdartsTrainer