Commit d43fbe82 authored by quzha's avatar quzha
Browse files

Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

parents 0e3906aa bb797e10
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import PyTorchMutator
class StackedLSTMCell(nn.Module):
    """A stack of ``nn.LSTMCell`` layers advanced one time-step at a time.

    ``forward`` takes the bottom layer's input plus per-layer state lists and
    returns the new per-layer state lists.

    NOTE(review): this file orders the state tuple ``(c, h)`` while
    ``nn.LSTMCell`` itself uses ``(h, c)``. The caller (EnasMutator) applies
    the same ordering consistently, so the two names are effectively swapped
    relative to PyTorch's convention — confirm before relying on either tensor
    being the "hidden" state.
    """

    def __init__(self, layers, size, bias):
        super().__init__()
        self.lstm_num_layers = layers
        self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
                                           for _ in range(self.lstm_num_layers)])

    def forward(self, inputs, hidden):
        """Advance every layer by one step.

        Parameters
        ----------
        inputs: torch.Tensor
            Input to the bottom layer, shape ``(batch, size)``.
        hidden: tuple of (list of torch.Tensor, list of torch.Tensor)
            Per-layer previous states, each entry shaped ``(batch, size)``.

        Returns
        -------
        tuple of (list of torch.Tensor, list of torch.Tensor)
            New per-layer states in the same order as ``hidden``.
        """
        prev_c, prev_h = hidden
        next_c, next_h = [], []
        for i, m in enumerate(self.lstm_modules):
            curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
            next_c.append(curr_c)
            next_h.append(curr_h)
            # Bug fix: the original fed ``curr_h[-1]`` (the last ROW of the
            # state tensor, silently dropping the batch dimension) into the
            # next layer. The whole state tensor is the next layer's input.
            inputs = curr_h
        return next_c, next_h
class EnasMutator(PyTorchMutator):
    """ENAS controller: an LSTM that samples one architecture per forward pass.

    While sampling, it accumulates the REINFORCE statistics used to train
    itself: ``sample_log_prob``, ``sample_entropy`` and ``sample_skip_penalty``.
    """

    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, anchor_extra_step=False,
                 skip_target=0.4):
        """
        Parameters
        ----------
        model: nn.Module
            Model containing the mutables to control.
        lstm_size: int
            Hidden size of the controller LSTM.
        lstm_num_layers: int
            Number of stacked LSTM layers in the controller.
        tanh_constant: float or None
            If not None, decision logits are squashed to ``tanh_constant * tanh(logits)``.
        anchor_extra_step: bool
            Stored but currently unused.  # NOTE(review): confirm intended use
        skip_target: float
            Target probability of taking a skip connection (drives the KL penalty).
        """
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.max_layer_choice = 0  # candidate count shared by all layer choices; set on first on_init
        self.anchor_extra_step = anchor_extra_step
        self.skip_target = skip_target
        # Plain attributes must be assigned first: super().__init__ immediately
        # runs before_build / parse_search_space / after_build.
        super().__init__(model)

    def before_build(self, model):
        # Controller sub-modules, created before the search space is scanned.
        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)  # learned "start" token
        # Fixed target distribution [P(no-skip), P(skip)] for the KL penalty; not trained.
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
                                         requires_grad=False)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def after_build(self, model):
        # max_layer_choice is only known after every layer choice has been visited.
        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice)

    def before_pass(self):
        """Reset per-sample state at the start of each architecture-sampling pass."""
        super().before_pass()
        self._anchors_hid = dict()  # mutable-scope key -> top-layer hidden snapshot
        self._selected_layers = []  # layer-choice decisions (ints), for debugging
        self._selected_inputs = []  # input-choice decisions (tensors), for debugging
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self._h = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
                               device=self._inputs.device) for _ in range(self.lstm_num_layers)]
        self.sample_log_prob = 0
        self.sample_entropy = 0
        self.sample_skip_penalty = 0

    def _lstm_next_step(self):
        # One controller step; (c, h) ordering is this file's convention.
        self._c, self._h = self.lstm(self._inputs, (self._c, self._h))

    def _mark_anchor(self, key):
        # Snapshot the top-layer hidden state so later input choices can attend to it.
        self._anchors_hid[key] = self._h[-1]

    def on_init_layer_choice(self, mutable):
        # A single shared softmax head is used, so every layer choice must
        # offer the same number of candidates.
        if self.max_layer_choice == 0:
            self.max_layer_choice = mutable.length
        assert self.max_layer_choice == mutable.length, \
            "ENAS mutator requires all layer choice have the same number of candidates."

    def on_calc_layer_choice_mask(self, mutable):
        """Sample one operator index and return a fixed-width boolean one-hot mask."""
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += log_prob
        entropy = (log_prob * torch.exp(-log_prob)).detach()  # entropy of this decision
        self.sample_entropy += entropy
        self._inputs = self.embedding(branch_id)  # feed the decision back into the LSTM
        self._selected_layers.append(branch_id.item())
        # Bug fix: without num_classes, F.one_hot sizes the mask by the sampled
        # index, so the mask length would vary between samples.
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def on_calc_input_choice_mask(self, mutable, semantic_labels):
        """Sample a skip-connection mask over candidate inputs via attention over anchors."""
        if mutable.n_selected is None:
            query, anchors = [], []
            for label in semantic_labels:
                if label not in self._anchors_hid:
                    self._lstm_next_step()
                    self._mark_anchor(label)  # lazily create anchors that were never marked
                query.append(self.attn_anchor(self._anchors_hid[label]))
                anchors.append(self._anchors_hid[label])
            query = torch.cat(query, 0)
            query = torch.tanh(query + self.attn_query(self._h[-1]))
            query = self.v_attn(query)
            logit = torch.cat([-query, query], 1)  # per-candidate binary logits: [no-skip, skip]
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)
            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            # KL(skip distribution || target) discourages too many / too few skips.
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self.sample_log_prob += torch.sum(log_prob)
            entropy = (log_prob * torch.exp(-log_prob)).detach()
            self.sample_entropy += torch.sum(entropy)
            # Bug fix: the original assigned to ``self.inputs`` (a brand-new
            # attribute), leaving the controller input ``self._inputs`` stale.
            # ``view(1, -1)`` restores the batch dimension lost by matmul.
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0))
                            / (1. + torch.sum(skip))).view(1, -1)
            self._selected_inputs.append(skip)
            return skip.bool()
        else:
            assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS."
            raise NotImplementedError

    def exit_mutable_scope(self, mutable_scope):
        # Snapshot the hidden state when a scope (e.g. one layer/cell) is left.
        self._mark_anchor(mutable_scope.key)
import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup, auto_device
from .mutator import EnasMutator
class EnasTrainer(Trainer):
    """
    ENAS trainer: each epoch alternates between (1) training the shared child
    model on 90% of the training set and (2) training the controller (mutator)
    with REINFORCE on the held-out 10%.
    """

    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid, lr_scheduler=None,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                 mutator_lr=0.00035):
        """
        Parameters
        ----------
        model: nn.Module
            Shared-weight child network containing mutables.
        loss: nn.Module
            Loss for the model-training phase (moved to ``device``, so it must be a module).
        metrics: callable
            Maps ``(logits, target)`` to a dict of metric values.
        reward_function: callable
            Maps ``(logits, target)`` to the reward used by REINFORCE.
        optimizer: torch.optim.Optimizer
            Optimizer for the child model's weights.
        num_epochs: int
            Number of search epochs.
        dataset_train, dataset_valid: torch.utils.data.Dataset
            ``dataset_train`` is split 90/10 between the two phases;
            ``dataset_valid`` only feeds ``test_loader`` (currently unused).
        lr_scheduler: optional
            Stepped once per epoch for the child model's optimizer.
        mutator: EnasMutator or None
            Controller; a default ``EnasMutator`` is built when None.
        entropy_weight, skip_weight, baseline_decay: float
            Entropy bonus weight, skip-penalty weight, and moving-baseline decay.
        mutator_lr: float
            Learning rate of the controller's Adam optimizer.
        """
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.reward_function = reward_function
        self.mutator = mutator
        if self.mutator is None:
            self.mutator = EnasMutator(model)
        self.optim = optimizer
        self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.lr_scheduler = lr_scheduler
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.device = auto_device() if device is None else device
        self.log_frequency = log_frequency
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.  # moving-average reward baseline for variance reduction
        self.model.to(self.device)
        self.loss.to(self.device)
        self.mutator.to(self.device)
        # Split the training set 90/10: the larger part trains the child
        # model, the smaller provides "validation" rewards for the controller.
        n_train = len(self.dataset_train)
        split = n_train // 10
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)
        self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
                                                       batch_size=batch_size,
                                                       num_workers=workers)

    def train_epoch(self, epoch):
        """Run one search epoch: a model phase followed by a mutator phase."""
        self.model.train()
        self.mutator.train()
        for phase in ["model", "mutator"]:
            # Exactly one of model/mutator is in training mode in each phase.
            if phase == "model":
                self.model.train()
                self.mutator.eval()
            else:
                self.model.eval()
                self.mutator.train()
            loader = self.train_loader if phase == "model" else self.valid_loader
            meters = AverageMeterGroup()
            for step, (x, y) in enumerate(loader):
                x, y = x.to(self.device), y.to(self.device)
                self.optim.zero_grad()
                self.mut_optim.zero_grad()
                # Sample one architecture for this batch, then run the child model with it.
                with self.mutator.forward_pass():
                    logits = self.model(x)
                metrics = self.metrics(logits, y)
                if phase == "model":
                    loss = self.loss(logits, y)
                    loss.backward()
                    self.optim.step()
                else:
                    reward = self.reward_function(logits, y)
                    if self.entropy_weight is not None:
                        # Entropy bonus encourages exploration.
                        # NOTE(review): this turns ``reward`` into a tensor,
                        # which then flows into the meters below — confirm.
                        reward += self.entropy_weight * self.mutator.sample_entropy
                    self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                    self.baseline = self.baseline.detach().item()
                    # REINFORCE: log-prob of the sampled architecture weighted by advantage.
                    loss = self.mutator.sample_log_prob * (reward - self.baseline)
                    if self.skip_weight:
                        loss += self.skip_weight * self.mutator.sample_skip_penalty
                    loss.backward()
                    self.mut_optim.step()
                    metrics["reward"] = reward
                metrics["loss"] = loss.item()
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(), step,
                                                               len(loader), meters))
        if self.lr_scheduler is not None:
            # Child-model LR schedule advances once per epoch.
            self.lr_scheduler.step()

    def validate_epoch(self, epoch):
        # NOTE(review): validation is not implemented yet.
        pass

    def train(self):
        """Run the full search: train then validate for every epoch."""
        for epoch in range(self.num_epochs):
            # training
            print("Epoch {} Training".format(epoch))
            self.train_epoch(epoch)
            # validation
            print("Epoch {} Validating".format(epoch))
            self.validate_epoch(epoch)

    def export(self):
        # NOTE(review): exporting the best architecture is not implemented yet.
        pass
import torch.nn as nn
from nni.nas.utils import global_mutable_counting
class PyTorchMutable(nn.Module):
    """
    Mutable is designed to function as a normal layer, with all necessary operators' weights.
    States and weights of architectures should be included in mutator, instead of the layer itself.

    Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
    decisions among different mutables. In mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global.
    """

    def __init__(self, key=None):
        super().__init__()
        if key is not None:
            # Coerce to string so key-based dict lookups in mutators behave consistently.
            if not isinstance(key, str):
                key = str(key)
                print("Warning: key \"{}\" is not string, converted to string.".format(key))
            self._key = key
        else:
            # Auto-generated key: class name + a process-wide monotonic counter.
            self._key = self.__class__.__name__ + str(global_mutable_counting())
        self.name = self.key  # may later be overwritten with the module path by the mutator

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __enter__(self):
        # NOTE(review): nn.Module defines no __enter__, so the super() call
        # below would raise AttributeError if this were ever reached.
        # MutableScope overrides __enter__, which is presumably why this has
        # gone unnoticed — confirm before using a bare mutable in a `with`.
        self._check_built()
        return super().__enter__()

    def __call__(self, *args, **kwargs):
        # Fail fast if the mutable is used before a mutator has claimed it.
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        # Write directly into __dict__ to bypass nn.Module.__setattr__, so the
        # mutator is not registered as a submodule of this mutable.
        self.__dict__["mutator"] = mutator

    def forward(self, *inputs):
        raise NotImplementedError("Mutable forward must be implemented.")

    @property
    def key(self):
        # Read-only identity of this mutable.
        return self._key

    def similar(self, other):
        """Whether two mutables are interchangeable. Default: object identity."""
        return self == other

    def _check_built(self):
        # ``mutator`` is injected by the mutator during search-space parsing;
        # a missing attribute means no mutator has ever seen this mutable.
        if not hasattr(self, "mutator"):
            raise ValueError(
                "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return "{} ({})".format(self.name, self.key)
class MutableScope(PyTorchMutable):
    """
    Labels a subgraph so mutators can reason about groups of mutables.

    Used as a context manager: entering and leaving the ``with`` block notifies
    the bound mutator through ``enter_mutable_scope`` / ``exit_mutable_scope``,
    letting it dump or update its status at scope boundaries.
    """

    def __init__(self, key):
        super().__init__(key=key)

    def __enter__(self):
        self.mutator.enter_mutable_scope(self)

    def __exit__(self, exc_kind, exc_value, trace):
        self.mutator.exit_mutable_scope(self)
class LayerChoice(PyTorchMutable):
    """
    A mutable choosing among a fixed list of candidate operators.

    The bound mutator decides which candidate(s) run and how their outputs are
    combined (per ``reduction``); optionally the decision mask is returned too.
    """

    def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
        super().__init__(key=key)
        self.length = len(op_candidates)
        self.choices = nn.ModuleList(op_candidates)
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, *inputs):
        # Delegate the actual choice and reduction to the mutator.
        reduced, mask = self.mutator.on_forward(self, *inputs)
        if self.return_mask:
            return reduced, mask
        return reduced

    def similar(self, other):
        """Interchangeable iff same class and same number of candidates."""
        if type(self) != type(other):
            return False
        return self.length == other.length
class InputChoice(PyTorchMutable):
    """
    A mutable that selects from a list of candidate input tensors.

    The bound mutator chooses which inputs pass through and reduces them
    according to ``reduction``; optionally the decision mask is returned too.
    """

    def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
        """
        Parameters
        ----------
        n_candidates: int
            Number of candidate inputs offered at forward time. Must be positive.
        n_selected: int or None
            Number of inputs to select; None means "any number".
        reduction: str
            Policy used by the mutator to combine the selected tensors.
        return_mask: bool
            When True, ``forward`` also returns the decision mask.
        key: str or None
            Optional shared identity; see ``PyTorchMutable``.
        """
        super().__init__(key=key)
        assert n_candidates > 0, "Number of candidates must be greater than 0."
        self.n_candidates = n_candidates
        self.n_selected = n_selected
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, optional_inputs, semantic_labels=None):
        assert len(optional_inputs) == self.n_candidates, \
            "Length of the input list must be equal to number of candidates."
        if semantic_labels is None:
            # Anonymous inputs share one placeholder label.
            semantic_labels = ["default_label"] * self.n_candidates
        out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
        if self.return_mask:
            return out, mask
        return out

    def similar(self, other):
        # Bug fix: the original tested the truthiness of ``self.n_selected and
        # other.n_selected`` instead of comparing them for equality — with the
        # default n_selected=None the chain was always falsy.
        return type(self) == type(other) and \
            self.n_candidates == other.n_candidates and self.n_selected == other.n_selected
import logging
from contextlib import contextmanager
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import PyTorchMutable
from nni.nas.utils import to_snake_case
logger = logging.getLogger(__name__)
class PyTorchMutator(nn.Module):
    """
    Base class of PyTorch mutators.

    A mutator owns all architecture decisions for the mutables found in a
    model. Subclasses create their own parameters in ``before_build`` /
    ``after_build`` and implement the ``on_calc_*_mask`` callbacks to make
    decisions during a ``forward_pass()`` context.
    """

    def __init__(self, model):
        super().__init__()
        # Fixed order: subclasses create modules before the scan and can
        # finalize them (knowing the full search space) afterwards.
        self.before_build(model)
        self.parse_search_space(model)
        self.after_build(model)

    def before_build(self, model):
        """Hook invoked before the search space is parsed. Default: no-op."""
        pass

    def after_build(self, model):
        """Hook invoked after the search space is parsed. Default: no-op."""
        pass

    def named_mutables(self, model):
        """
        Yield ``(name, mutable, distinct)`` for every mutable in ``model``.

        ``distinct`` is True only for the first mutable carrying a given key;
        mutables sharing a key must be mutually ``similar``.
        """
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, PyTorchMutable):
                distinct = False
                if module.key in key2module:
                    assert key2module[module.key].similar(module), \
                        "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
                else:
                    distinct = True
                    key2module[module.key] = module
                yield name, module, distinct

    def __setattr__(self, key, value):
        # Guard: assigning the searched network onto the mutator would register
        # it as a submodule and train it with the mutator's optimizer.
        if key in ["model", "net", "network"]:
            logger.warning("Think twice if you are including the network into mutator.")
        return super().__setattr__(key, value)

    def parse_search_space(self, model):
        """Wire every mutable in ``model`` to this mutator and run per-type init callbacks."""
        for name, mutable, distinct in self.named_mutables(model):
            mutable.name = name
            mutable.set_mutator(self)
            if not distinct:
                continue  # shared-key mutables are initialized only once
            init_method_name = "on_init_{}".format(to_snake_case(mutable.__class__.__name__))
            if hasattr(self, init_method_name) and callable(getattr(self, init_method_name)):
                getattr(self, init_method_name)(mutable)
            else:
                # fallback to general init
                self.on_init_general(mutable)

    def on_init_general(self, mutable):
        """Fallback init callback for mutable types without a dedicated ``on_init_*``."""
        pass

    @contextmanager
    def forward_pass(self):
        """Context manager bracketing one architecture-sampling forward pass."""
        self.before_pass()
        try:
            yield self
        finally:
            self.after_pass()

    def before_pass(self):
        self._in_forward_pass = True
        self._cache = dict()  # mutable key -> decision mask for this pass

    def after_pass(self):
        self._in_forward_pass = False

    def enter_mutable_scope(self, mutable_scope):
        pass

    def exit_mutable_scope(self, mutable_scope):
        pass

    def forward(self, *inputs):
        raise NotImplementedError("Mutator is not forward-able")

    def on_forward(self, mutable, *inputs):
        """Dispatch a mutable's forward to the matching ``on_forward_*`` callback."""
        if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
            raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
                             "super().before_pass() and after_pass() in your override method?")
        forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
        if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
            return getattr(self, forward_method_name)(mutable, *inputs)
        else:
            # fallback to general forward
            return self.on_forward_general(mutable, *inputs)

    def on_forward_general(self, mutable, *inputs):
        raise NotImplementedError("Forward has to be implemented")

    def on_forward_layer_choice(self, mutable, *inputs):
        """
        Callback of layer choice forward. Override if you are an advanced user.

        On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
        (either by switch or by weights), then reduces the list of all tensor outputs with the policy specified
        in ``mutable.reduction``. The mask is cached under ``mutable.key``.

        Parameters
        ----------
        mutable: LayerChoice
        inputs: list of torch.Tensor

        Returns
        -------
        tuple of (torch.Tensor, torch.Tensor)
            Reduced output and the decision mask.
        """
        def _map_fn(op, *inputs):
            return op(*inputs)

        # Bug fix: dict.setdefault would evaluate on_calc_layer_choice_mask
        # (which may advance controller state) even on a cache hit; compute
        # the mask only on a genuine miss.
        if mutable.key not in self._cache:
            self._cache[mutable.key] = self.on_calc_layer_choice_mask(mutable)
        mask = self._cache[mutable.key]
        out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
        """
        Callback of input choice forward. Override if you are an advanced user.

        On default, this method calls :meth:`on_calc_input_choice_mask` with ``semantic_labels``
        to get a mask on how to choose between inputs (either by switch or by weights), then reduces
        the list of all tensor outputs with the policy specified in ``mutable.reduction``. The mask is
        cached under ``mutable.key``.

        Parameters
        ----------
        mutable: InputChoice
        tensor_list: list of torch.Tensor
        semantic_labels: list of str

        Returns
        -------
        tuple of (torch.Tensor, torch.Tensor)
            Reduced output and the decision mask.
        """
        # Same cache-miss-only evaluation as in on_forward_layer_choice.
        if mutable.key not in self._cache:
            self._cache[mutable.key] = self.on_calc_input_choice_mask(mutable, semantic_labels)
        mask = self._cache[mutable.key]
        out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask)
        return self._tensor_reduction(mutable.reduction, out), mask

    def on_calc_layer_choice_mask(self, mutable):
        """
        Recommended to override. Calculate a mask tensor for a layer choice.

        Parameters
        ----------
        mutable: LayerChoice
            Corresponding layer choice object.

        Returns
        -------
        torch.Tensor
            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
            the numbers are treated as switch.
        """
        raise NotImplementedError("Layer choice mask calculation must be implemented")

    def on_calc_input_choice_mask(self, mutable, semantic_labels):
        """
        Recommended to override. Calculate a mask tensor for an input choice.

        Parameters
        ----------
        mutable: InputChoice
            Corresponding input choice object.
        semantic_labels: list of string
            The name of labels of input tensors given by user. Usually it's a
            :class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.

        Returns
        -------
        torch.Tensor
            Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
            the numbers are treated as switch.
        """
        raise NotImplementedError("Input choice mask calculation must be implemented")

    def _select_with_mask(self, map_fn, candidates, mask):
        # Bool masks act as hard switches; float masks as soft weights.
        if "BoolTensor" in mask.type():
            out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
        elif "FloatTensor" in mask.type():
            out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
        else:
            raise ValueError("Unrecognized mask")
        return out

    def _tensor_reduction(self, reduction_type, tensor_list):
        # Bug fix: the original compared ``tensor_list`` (a list) against the
        # string "none", so the "none" policy was unreachable; the policy is a
        # property of ``reduction_type``.
        if reduction_type == "none":
            return tensor_list
        if not tensor_list:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            return tensor_list[0]
        if reduction_type == "sum":
            return sum(tensor_list)
        if reduction_type == "mean":
            return sum(tensor_list) / len(tensor_list)
        if reduction_type == "concat":
            return torch.cat(tensor_list, dim=1)
        raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
from abc import ABC, abstractmethod
class Trainer(ABC):
    """Abstract interface that every NAS trainer must implement."""

    @abstractmethod
    def train(self):
        """Run the full search / training procedure."""
        raise NotImplementedError

    @abstractmethod
    def export(self):
        """Export the result of the search."""
        raise NotImplementedError
import re
from collections import OrderedDict
import torch
# Process-wide counter backing auto-generated mutable keys.
_counter = 0


def global_mutable_counting():
    """Return the next value of a process-wide counter, starting from 1."""
    global _counter
    _counter = _counter + 1
    return _counter
def to_snake_case(camel_case):
    """Convert a CamelCase name to snake_case, e.g. ``"LayerChoice"`` -> ``"layer_choice"``.

    Consecutive capitals are treated as one group (``"HTTPServer"`` -> ``"h_ttpserver"``).
    """
    with_underscores = re.sub('(?!^)([A-Z]+)', r'_\1', camel_case)
    return with_underscores.lower()
def auto_device():
    """Pick the CUDA device when available, otherwise fall back to CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
class AverageMeterGroup(object):
    """A keyed collection of ``AverageMeter`` objects, created lazily on update."""

    def __init__(self):
        self.meters = OrderedDict()

    def update(self, data):
        """Fold a dict of metric values into the group, creating meters on first use."""
        for metric_name, value in data.items():
            if metric_name not in self.meters:
                self.meters[metric_name] = AverageMeter(metric_name, ":4f")
            self.meters[metric_name].update(value)

    def __str__(self):
        return " ".join(str(meter) for meter in self.meters.values())
class AverageMeter(object):
    """Tracks the latest value and the running average of a single metric."""

    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val``, treated as an average over ``n`` samples."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count

    def __str__(self):
        template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment