Unverified Commit d07f7280 authored by chicm-ms's avatar chicm-ms Committed by GitHub
Browse files

Merge pull request #1769 from microsoft/dev-nas-refactor

NAS refactor merge back to master (DO NOT SQUASH)
parents 503a3579 17ea5f0a
import logging
import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
from nni.nas.pytorch.utils import StructuredMutableTreeNode
logger = logging.getLogger(__name__)
class BaseMutator(nn.Module):
"""
A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
callbacks that are called in ``forward`` in Mutables.
"""
def __init__(self, model):
super().__init__()
self.__dict__["model"] = model
self._structured_mutables = self._parse_search_space(self.model)
def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
def forward(self, *inputs):
raise RuntimeError("Forward is undefined for mutators.")
def __setattr__(self, name, value):
if name == "model":
raise AttributeError("Attribute `model` can be set at most once, and you shouldn't use `self.model = model` to "
"include you network, as it will include all parameters in model into the mutator.")
return super().__setattr__(name, value)
def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope : MutableScope
"""
pass
def exit_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is exited.
Parameters
----------
mutable_scope : MutableScope
"""
pass
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callbacks of forward in LayerChoice.
Parameters
----------
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list):
"""
Callbacks of forward in InputChoice.
Parameters
----------
mutable : InputChoice
tensor_list : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def export(self):
"""
Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
network can be fully determined with these decisions for further training from scratch.
Returns
-------
dict
"""
raise NotImplementedError
from abc import ABC, abstractmethod
class BaseTrainer(ABC):
@abstractmethod
def train(self):
raise NotImplementedError
@abstractmethod
def validate(self):
raise NotImplementedError
@abstractmethod
def export(self, file):
raise NotImplementedError
@abstractmethod
def checkpoint(self):
raise NotImplementedError
import logging
import os
_logger = logging.getLogger(__name__)
class Callback:
def __init__(self):
self.model = None
self.mutator = None
self.trainer = None
def build(self, model, mutator, trainer):
self.model = model
self.mutator = mutator
self.trainer = trainer
def on_epoch_begin(self, epoch):
pass
def on_epoch_end(self, epoch):
pass
def on_batch_begin(self, epoch):
pass
def on_batch_end(self, epoch):
pass
class LRSchedulerCallback(Callback):
def __init__(self, scheduler, mode="epoch"):
super().__init__()
assert mode == "epoch"
self.scheduler = scheduler
self.mode = mode
def on_epoch_end(self, epoch):
self.scheduler.step()
class ArchitectureCheckpoint(Callback):
def __init__(self, checkpoint_dir, every="epoch"):
super().__init__()
assert every == "epoch"
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def on_epoch_end(self, epoch):
self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
from .mutator import DartsMutator
from .trainer import DartsTrainer
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class DartsMutator(Mutator):
def __init__(self, model):
super().__init__(model)
self.choices = nn.ParameterDict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
def device(self):
for v in self.choices.values():
return v.device
def sample_search(self):
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
elif isinstance(mutable, InputChoice):
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
return result
def sample_final(self):
result = dict()
edges_max = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice):
weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result
import copy
import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
logger = logging.getLogger(__name__)
class DartsTrainer(Trainer):
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None, arc_learning_rate=3.0E-4, unrolled=True):
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
n_train = len(self.dataset_train)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
self.model.train()
self.mutator.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
nn.utils.clip_grad_norm_(self.model.parameters(), 5.) # gradient clipping
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def validate_one_epoch(self, epoch):
self.model.eval()
self.mutator.eval()
meters = AverageMeterGroup()
with torch.no_grad():
self.mutator.reset()
for step, (X, y) in enumerate(self.test_loader):
X, y = X.to(self.device), y.to(self.device)
logits = self.model(X)
metrics = self.metrics(logits, y)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.test_loader), meters)
def _logits_and_loss(self, X, y):
self.mutator.reset()
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.optimizer.param_groups[0]["lr"]
momentum = self.optimizer.param_groups[0]["momentum"]
weight_decay = self.optimizer.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple(self.mutator.parameters())
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.optimizer.state[w].get("momentum_buffer", 0.)
w = w - lr * (momentum * m + g + weight_decay * w)
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
logger.warning("In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.", norm.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, self.mutator.parameters()))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
from .mutator import EnasMutator
from .trainer import EnasTrainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
return next_c, next_h
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False) # pylint: disable=not-callable
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
self.max_layer_choice = 0
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choice have the same number of candidates."
# We are judging by keys and module types to add biases to layer choices. Needs refactor.
if "reduce" in mutable.key:
def is_conv(choice):
return "conv" in str(type(choice)).lower()
bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias # pylint: disable=not-callable
for choice in mutable.choices])
self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def sample_search(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if mutable.key in self.bias_dict:
logit += self.bias_dict[mutable.key]
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label) # empty loop, fill not found
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else:
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
return skip.bool()
import logging
import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import EnasMutator
logger = logging.getLogger(__name__)
class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.mutator_steps_aggregate = mutator_steps_aggregate
self.mutator_steps = mutator_steps
self.aux_weight = aux_weight
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
# Sample model and train
self.model.train()
self.mutator.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
with torch.no_grad():
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
logits, aux_logits = logits
aux_loss = self.loss(aux_logits, y)
else:
aux_loss = 0.
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss = loss + self.aux_weight * aux_loss
loss.backward()
self.optimizer.step()
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("Model Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
# Train sampler (mutator)
self.model.eval()
self.mutator.train()
meters = AverageMeterGroup()
mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
while mutator_step < total_mutator_steps:
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
metrics["reward"] = reward
metrics["loss"] = loss.item()
metrics["ent"] = self.mutator.sample_entropy.item()
metrics["baseline"] = self.baseline
metrics["skip"] = self.mutator.sample_skip_penalty
loss = loss / self.mutator_steps_aggregate
loss.backward()
meters.update(metrics)
if mutator_step % self.mutator_steps_aggregate == 0:
self.mutator_optim.step()
self.mutator_optim.zero_grad()
if self.log_frequency is not None and step % self.log_frequency == 0:
logger.info("RL Epoch [%s/%s] Step [%s/%s] %s", epoch + 1, self.num_epochs,
mutator_step // self.mutator_steps_aggregate + 1, self.mutator_steps, meters)
mutator_step += 1
if mutator_step >= total_mutator_steps:
break
def validate_one_epoch(self, epoch):
pass
import json
import torch
from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator):
def __init__(self, model, fixed_arc, strict=True):
"""
Initialize a fixed architecture mutator.
Parameters
----------
model : nn.Module
A mutable network.
fixed_arc : str or dict
Path to the architecture checkpoint (a string), or preloaded architecture object (a dict).
strict : bool
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
self._fixed_arc = fixed_arc
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
if fixed_arc_keys - mutable_keys:
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
def sample_search(self):
return self._fixed_arc
def sample_final(self):
return self._fixed_arc
def _encode_tensor(data, device):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: _encode_tensor(v, device) for k, v in data.items()}
return data
def apply_fixed_architecture(model, fixed_arc_path, device=None):
"""
Load architecture from `fixed_arc_path` and apply to model.
Parameters
----------
model : torch.nn.Module
Model with mutables.
fixed_arc_path : str
Path to the JSON that stores the architecture.
device : torch.device
Architecture weights will be transfered to `device`.
Returns
-------
FixedArchitecture
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc, device)
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
return architecture
import logging
import torch.nn as nn
from nni.nas.pytorch.utils import global_mutable_counting
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class Mutable(nn.Module):
"""
Mutable is designed to function as a normal layer, with all necessary operators' weights.
States and weights of architectures should be included in mutator, instead of the layer itself.
Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
decisions among different mutables. In mutator's implementation, mutators should use the key to
distinguish different mutables. Mutables that share the same key should be "similar" to each other.
Currently the default scope for keys is global.
"""
def __init__(self, key=None):
super().__init__()
if key is not None:
if not isinstance(key, str):
key = str(key)
logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
self._key = key
else:
self._key = self.__class__.__name__ + str(global_mutable_counting())
self.init_hook = self.forward_hook = None
def __deepcopy__(self, memodict=None):
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
if "mutator" in self.__dict__:
raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
"Or did you apply multiple fixed architectures?")
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
raise NotImplementedError
@property
def key(self):
return self._key
@property
def name(self):
return self._name if hasattr(self, "_name") else "_key"
@name.setter
def name(self, name):
self._name = name
def _check_built(self):
if not hasattr(self, "mutator"):
raise ValueError(
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(Mutable):
"""
Mutable scope marks a subgraph/submodule to help mutators make better decisions.
Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
MutableScope are also mutables that are listed in the mutables (search space).
"""
def __init__(self, key):
super().__init__(key=key)
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
self.mutator.exit_mutable_scope(self)
class LayerChoice(Mutable):
def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(op_candidates)
self.choices = nn.ModuleList(op_candidates)
self.reduction = reduction
self.return_mask = return_mask
def forward(self, *inputs):
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
class InputChoice(Mutable):
"""
Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners,
use `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
know about `choose_from`.
The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
"""
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="sum", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates : int
Number of inputs to choose from.
choose_from : list of str
List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled.
If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
number of empty string.
n_chosen : int
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
`mean`, `concat`, `sum` or `none`.
return_mask : bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key : str
Key of the input choice.
"""
super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
"must be not None."
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
assert n_candidates > 0, "Number of candidates must be greater than 0."
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
"than number of candidates."
self.n_candidates = n_candidates
self.choose_from = choose_from
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def forward(self, optional_inputs):
"""
Forward method of LayerChoice.
Parameters
----------
optional_inputs : list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
`choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
`choose_from`.
Returns
-------
tuple of torch.Tensor and torch.Tensor or torch.Tensor
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), "Optional input list must be a list"
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
import logging
import torch
from nni.nas.pytorch.base_mutator import BaseMutator
logger = logging.getLogger(__name__)
class Mutator(BaseMutator):
def __init__(self, model):
super().__init__(model)
self._cache = dict()
def sample_search(self):
"""
Override to implement this method to iterate over mutables and make decisions.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def sample_final(self):
"""
Override to implement this method to iterate over mutables and make decisions that is final
for export and retraining.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def reset(self):
"""
Reset the mutator by call the `sample_search` to resample (for search).
Returns
-------
None
"""
self._cache = self.sample_search()
def export(self):
"""
Resample (for final) and return results.
Returns
-------
dict
"""
return self.sample_final()
def on_forward_layer_choice(self, mutable, *inputs):
"""
On default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
Parameters
----------
mutable : LayerChoice
inputs : list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
"""
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._get_decision(mutable)
assert len(mask) == len(mutable.choices)
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list):
"""
On default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
mask with corresponding `mutable.key`.
Parameters
----------
mutable : InputChoice
tensor_list : list of torch.Tensor
tags : list of string
Returns
-------
tuple of torch.Tensor and torch.Tensor
"""
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
else:
raise ValueError("Unrecognized mask")
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _get_decision(self, mutable):
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
and returns the result without double-check.
Parameters
----------
mutable : Mutable
Returns
-------
object
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
result = self._cache[mutable.key]
logger.debug("Decision %s: %s", mutable.key, result)
return result
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .trainer import PdartsTrainer
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import numpy as np
import torch.nn.functional as F
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator):
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches={}):
self.pdarts_epoch_index = pdarts_epoch_index
self.pdarts_num_to_drop = pdarts_num_to_drop
if switches is None:
self.switches = {}
else:
self.switches = switches
super(PdartsMutator, self).__init__(model)
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
switches = self.switches.get(mutable.key, [True for j in range(mutable.length)])
for index in range(len(switches)-1, -1, -1):
if switches[index] == False:
del(mutable.choices[index])
mutable.length -= 1
self.switches[mutable.key] = switches
def drop_paths(self):
for key in self.switches:
prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
switches = self.switches[key]
idxs = []
for j in range(len(switches)):
if switches[j]:
idxs.append(j)
if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
# for the last stage, drop all Zero operations
drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
else:
drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
for idx in drop:
switches[idxs[idx]] = False
return self.switches
def get_min_k(self, input_in, k):
index = []
for _ in range(k):
idx = np.argmin(input)
index.append(idx)
return index
def get_min_k_no_zero(self, w_in, idxs, k):
w = copy.deepcopy(w_in)
index = []
if 0 in idxs:
zf = True
else:
zf = False
if zf:
w = w[1:]
index.append(0)
k = k - 1
for _ in range(k):
idx = np.argmin(w)
w[idx] = 1
if zf:
idx = idx + 1
index.append(idx)
return index
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from nni.nas.pytorch.callbacks import LRSchedulerCallback
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import BaseTrainer
from .mutator import PdartsMutator
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class PdartsTrainer(BaseTrainer):
def __init__(self, model_creator, layers, metrics,
num_epochs, dataset_train, dataset_valid,
pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None):
super(PdartsTrainer, self).__init__()
self.model_creator = model_creator
self.layers = layers
self.pdarts_num_layers = pdarts_num_layers
self.pdarts_num_to_drop = pdarts_num_to_drop
self.pdarts_epoch = len(pdarts_num_to_drop)
self.darts_parameters = {
"metrics": metrics,
"num_epochs": num_epochs,
"dataset_train": dataset_train,
"dataset_valid": dataset_valid,
"batch_size": batch_size,
"workers": workers,
"device": device,
"log_frequency": log_frequency
}
self.callbacks = callbacks if callbacks is not None else []
def train(self):
layers = self.layers
switches = None
for epoch in range(self.pdarts_epoch):
layers = self.layers+self.pdarts_num_layers[epoch]
model, criterion, optim, lr_scheduler = self.model_creator(layers)
self.mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)
for callback in self.callbacks:
callback.build(model, self.mutator, self)
callback.on_epoch_begin(epoch)
darts_callbacks = []
if lr_scheduler is not None:
darts_callbacks.append(LRSchedulerCallback(lr_scheduler))
self.trainer = DartsTrainer(model, mutator=self.mutator, loss=criterion, optimizer=optim,
callbacks=darts_callbacks, **self.darts_parameters)
logger.info("start pdarts training %s...", epoch)
self.trainer.train()
switches = self.mutator.drop_paths()
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def validate(self):
self.model.validate()
def checkpoint(self):
raise NotImplementedError("Not implemented yet")
import json
import logging
from abc import abstractmethod
import torch
from .base_trainer import BaseTrainer
_logger = logging.getLogger(__name__)
_logger.setLevel(logging.INFO)
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
olist = o.tolist()
if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
_logger.warning("Every element in %s is either 0 or 1. "
"You might consider convert it into bool.", olist)
return olist
return super().default(o)
class Trainer(BaseTrainer):
def __init__(self, model, mutator, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, callbacks):
"""
Trainer initialization.
Parameters
----------
model : nn.Module
Model with mutables.
mutator : BaseMutator
A mutator object that has been initialized with the model.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : torch.utils.data.Dataset
Dataset of training.
dataset_valid : torch.utils.data.Dataset
Dataset of validation/testing.
batch_size : int
Batch size.
workers : int
Number of workers used in data preprocessing.
device : torch.device
Device object. Either `torch.device("cuda")` or torch.device("cpu")`. When `None`, trainer will
automatic detects GPU and selects GPU first.
log_frequency : int
Number of mini-batches to log metrics.
callbacks : list of Callback
Callbacks to plug into the trainer. See Callbacks.
"""
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device
self.model = model
self.mutator = mutator
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.model.to(self.device)
self.mutator.to(self.device)
self.loss.to(self.device)
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.batch_size = batch_size
self.workers = workers
self.log_frequency = log_frequency
self.callbacks = callbacks if callbacks is not None else []
for callback in self.callbacks:
callback.build(self.model, self.mutator, self)
@abstractmethod
def train_one_epoch(self, epoch):
pass
@abstractmethod
def validate_one_epoch(self, epoch):
pass
def train(self, validate=True):
for epoch in range(self.num_epochs):
for callback in self.callbacks:
callback.on_epoch_begin(epoch)
# training
_logger.info("Epoch %d Training", epoch)
self.train_one_epoch(epoch)
if validate:
# validation
_logger.info("Epoch %d Validating", epoch)
self.validate_one_epoch(epoch)
for callback in self.callbacks:
callback.on_epoch_end(epoch)
def validate(self):
self.validate_one_epoch(-1)
def export(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)
def checkpoint(self):
raise NotImplementedError("Not implemented yet")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment