Commit 77e91e8b authored by Yuge Zhang, committed by Chi Song

Extract controller from mutator to make offline decisions (#1758)

parent 9dda5370
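For illustration, here is a minimal sketch of the decision flow this commit introduces: mutators now expose `sample_search`/`sample_final`, and trainers call `reset()` before each forward pass and `export()` for the final architecture. The `RandomMutator` below is hypothetical and only meant to show the shape of the API; the concrete `DartsMutator`/`EnasMutator` implementations appear later in this diff.

import torch
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutator import Mutator


class RandomMutator(Mutator):
    """Hypothetical mutator: samples a uniformly random architecture on every reset()."""

    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # one-hot mask over the candidate operations
                index = torch.randint(mutable.length, (1,))
                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                # multi-hot mask over the candidate inputs
                n_chosen = mutable.n_chosen or 1
                chosen = torch.randperm(mutable.n_candidates)[:n_chosen]
                mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                mask[chosen] = True
                result[mutable.key] = mask
        return result

    def sample_final(self):
        return self.sample_search()

# Usage, mirroring the updated trainers below:
#   mutator.reset()           # sample a new architecture and cache the decisions
#   logits = model(x)         # forward passes consult the cached decisions
#   arch = mutator.export()   # final decisions: a dict mapping mutable keys to masks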
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import torch.nn as nn import torch.nn as nn
import ops import ops
from nni.nas.pytorch import mutables, darts from nni.nas.pytorch import mutables
class AuxiliaryHead(nn.Module): class AuxiliaryHead(nn.Module):
...@@ -31,12 +31,14 @@ class AuxiliaryHead(nn.Module): ...@@ -31,12 +31,14 @@ class AuxiliaryHead(nn.Module):
return logits return logits
class Node(darts.DartsNode): class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.): def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__(node_id, limitation=2) super().__init__()
self.ops = nn.ModuleList() self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes): for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1 stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append( self.ops.append(
mutables.LayerChoice( mutables.LayerChoice(
[ [
...@@ -48,18 +50,19 @@ class Node(darts.DartsNode): ...@@ -48,18 +50,19 @@ class Node(darts.DartsNode):
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False), ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False), ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
], ],
key="{}_p{}".format(node_id, i))) key=choice_keys[-1]))
self.drop_path = ops.DropPath_(drop_path_prob) self.drop_path = ops.DropPath_()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
def forward(self, prev_nodes): def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes) assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)] out = [op(node) for op, node in zip(self.ops, prev_nodes)]
return sum(self.drop_path(o) for o in out if o is not None) return self.input_switch(out)
class Cell(nn.Module): class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.): def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__() super().__init__()
self.reduction = reduction self.reduction = reduction
self.n_nodes = n_nodes self.n_nodes = n_nodes
...@@ -74,10 +77,9 @@ class Cell(nn.Module): ...@@ -74,10 +77,9 @@ class Cell(nn.Module):
# generate dag # generate dag
self.mutable_ops = nn.ModuleList() self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes): for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth), self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth + 2, channels, 2 if reduction else 0, depth, channels, 2 if reduction else 0))
drop_path_prob=drop_path_prob))
def forward(self, s0, s1): def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively. # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
...@@ -93,7 +95,7 @@ class Cell(nn.Module): ...@@ -93,7 +95,7 @@ class Cell(nn.Module):
class CNN(nn.Module): class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4, def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False, drop_path_prob=0.): stem_multiplier=3, auxiliary=False):
super().__init__() super().__init__()
self.in_channels = in_channels self.in_channels = in_channels
self.channels = channels self.channels = channels
...@@ -120,7 +122,7 @@ class CNN(nn.Module): ...@@ -120,7 +122,7 @@ class CNN(nn.Module):
c_cur *= 2 c_cur *= 2
reduction = True reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob) cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell) self.cells.append(cell)
c_cur_out = c_cur * n_nodes c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out channels_pp, channels_p = channels_p, c_cur_out
...@@ -147,3 +149,8 @@ class CNN(nn.Module): ...@@ -147,3 +149,8 @@ class CNN(nn.Module):
if aux_logits is not None: if aux_logits is not None:
return logits, aux_logits return logits, aux_logits
return logits return logits
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p
import logging
from argparse import ArgumentParser
import torch
import torch.nn as nn
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train(config, train_loader, model, optimizer, criterion, epoch):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")
cur_step = epoch * len(train_loader)
cur_lr = optimizer.param_groups[0]['lr']
logger.info("Epoch %d LR %.6f", epoch, cur_lr)
model.train()
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
bs = x.size(0)
optimizer.zero_grad()
logits, aux_logits = model(x)
loss = criterion(logits, y)
if config.aux_weight > 0.:
loss += config.aux_weight * criterion(aux_logits, y)
loss.backward()
# gradient clipping
nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
optimizer.step()
accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), bs)
top1.update(accuracy["acc1"], bs)
top5.update(accuracy["acc5"], bs)
if step % config.log_frequency == 0 or step == len(train_loader) - 1:
logger.info(
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses,
top1=top1, top5=top5))
cur_step += 1
logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
def validate(config, valid_loader, model, criterion, epoch, cur_step):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")
model.eval()
with torch.no_grad():
for step, (X, y) in enumerate(valid_loader):
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
N = X.size(0)
logits = model(X)
loss = criterion(logits, y)
accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), N)
top1.update(accuracy["acc1"], N)
top5.update(accuracy["acc5"], N)
if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
logger.info(
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
top1=top1, top5=top5))
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
return top1.avg
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=20, type=int)
parser.add_argument("--batch-size", default=96, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=600, type=int)
parser.add_argument("--aux-weight", default=0.4, type=float)
parser.add_argument("--drop-path-prob", default=0.2, type=float)
parser.add_argument("--workers", default=4)
parser.add_argument("--grad-clip", default=5., type=float)
parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint, device=device)
criterion = nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)
optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)
train_loader = torch.utils.data.DataLoader(dataset_train,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(dataset_valid,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True)
best_top1 = 0.
for epoch in range(args.epochs):
drop_prob = args.drop_path_prob * epoch / args.epochs
model.drop_path_prob(drop_prob)
# training
train(args, train_loader, model, optimizer, criterion, epoch)
# validation
cur_step = (epoch + 1) * len(train_loader)
top1 = validate(args, valid_loader, model, criterion, epoch, cur_step)
best_top1 = max(best_top1, top1)
lr_scheduler.step()
logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
...@@ -13,7 +13,7 @@ from utils import accuracy ...@@ -13,7 +13,7 @@ from utils import accuracy
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("darts") parser = ArgumentParser("darts")
parser.add_argument("--layers", default=8, type=int) parser.add_argument("--layers", default=8, type=int)
parser.add_argument("--batch-size", default=96, type=int) parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--log-frequency", default=10, type=int) parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=50, type=int) parser.add_argument("--epochs", default=50, type=int)
args = parser.parse_args() args = parser.parse_args()
...@@ -36,4 +36,4 @@ if __name__ == "__main__": ...@@ -36,4 +36,4 @@ if __name__ == "__main__":
batch_size=args.batch_size, batch_size=args.batch_size,
log_frequency=args.log_frequency, log_frequency=args.log_frequency,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")]) callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
trainer.train_and_validate() trainer.train()
...@@ -6,7 +6,7 @@ from ops import FactorizedReduce, ConvBranch, PoolBranch ...@@ -6,7 +6,7 @@ from ops import FactorizedReduce, ConvBranch, PoolBranch
class ENASLayer(mutables.MutableScope): class ENASLayer(mutables.MutableScope):
def __init__(self, key, num_prev_layers, in_filters, out_filters): def __init__(self, key, prev_labels, in_filters, out_filters):
super().__init__(key) super().__init__(key)
self.in_filters = in_filters self.in_filters = in_filters
self.out_filters = out_filters self.out_filters = out_filters
...@@ -18,16 +18,16 @@ class ENASLayer(mutables.MutableScope): ...@@ -18,16 +18,16 @@ class ENASLayer(mutables.MutableScope):
PoolBranch('avg', in_filters, out_filters, 3, 1, 1), PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1) PoolBranch('max', in_filters, out_filters, 3, 1, 1)
]) ])
if num_prev_layers > 0: if len(prev_labels) > 0:
self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum") self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
else: else:
self.skipconnect = None self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False) self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
def forward(self, prev_layers, prev_labels): def forward(self, prev_layers):
out = self.mutable(prev_layers[-1]) out = self.mutable(prev_layers[-1])
if self.skipconnect is not None: if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1], tags=prev_labels) connection = self.skipconnect(prev_layers[:-1])
if connection is not None: if connection is not None:
out += connection out += connection
return self.batch_norm(out) return self.batch_norm(out)
...@@ -53,11 +53,12 @@ class GeneralNetwork(nn.Module): ...@@ -53,11 +53,12 @@ class GeneralNetwork(nn.Module):
self.layers = nn.ModuleList() self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList() self.pool_layers = nn.ModuleList()
labels = []
for layer_id in range(self.num_layers): for layer_id in range(self.num_layers):
labels.append("layer_{}".format(layer_id))
if layer_id in self.pool_layers_idx: if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters)) self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id, self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1) self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes) self.dense = nn.Linear(self.out_filters, self.num_classes)
...@@ -66,12 +67,11 @@ class GeneralNetwork(nn.Module): ...@@ -66,12 +67,11 @@ class GeneralNetwork(nn.Module):
bs = x.size(0) bs = x.size(0)
cur = self.stem(x) cur = self.stem(x)
layers, labels = [cur], [] layers = [cur]
for layer_id in range(self.num_layers): for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers, labels) cur = self.layers[layer_id](layers)
layers.append(cur) layers.append(cur)
labels.append(self.layers[layer_id].key)
if layer_id in self.pool_layers_idx: if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers): for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer) layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
......
...@@ -32,9 +32,9 @@ class AuxiliaryHead(nn.Module): ...@@ -32,9 +32,9 @@ class AuxiliaryHead(nn.Module):
class Cell(nn.Module): class Cell(nn.Module):
def __init__(self, cell_name, num_prev_layers, channels): def __init__(self, cell_name, prev_labels, channels):
super().__init__() super().__init__()
self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True, self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input") key=cell_name + "_input")
self.op_choice = mutables.LayerChoice([ self.op_choice = mutables.LayerChoice([
SepConvBN(channels, channels, 3, 1), SepConvBN(channels, channels, 3, 1),
...@@ -44,21 +44,21 @@ class Cell(nn.Module): ...@@ -44,21 +44,21 @@ class Cell(nn.Module):
nn.Identity() nn.Identity()
], key=cell_name + "_op") ], key=cell_name + "_op")
def forward(self, prev_layers, prev_labels): def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels) chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input) cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask return cell_out, chosen_mask
class Node(mutables.MutableScope): class Node(mutables.MutableScope):
def __init__(self, node_name, num_prev_layers, channels): def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name) super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", num_prev_layers, channels) self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", num_prev_layers, channels) self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
def forward(self, prev_layers, prev_labels): def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers, prev_labels) out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers, prev_labels) out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y return out_x + out_y, mask_x | mask_y
...@@ -93,8 +93,11 @@ class ENASLayer(nn.Module): ...@@ -93,8 +93,11 @@ class ENASLayer(nn.Module):
self.num_nodes = num_nodes self.num_nodes = num_nodes
name_prefix = "reduce" if reduction else "normal" name_prefix = "reduce" if reduction else "normal"
self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i), self.nodes = nn.ModuleList()
i + 2, out_channels) for i in range(num_nodes)]) node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
for i in range(num_nodes):
node_labels.append("{}_node_{}".format(name_prefix, i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True) self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True)
self.bn = nn.BatchNorm2d(out_channels, affine=False) self.bn = nn.BatchNorm2d(out_channels, affine=False)
self.reset_parameters() self.reset_parameters()
...@@ -106,14 +109,12 @@ class ENASLayer(nn.Module): ...@@ -106,14 +109,12 @@ class ENASLayer(nn.Module):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev) pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
prev_nodes_out = [pprev_, prev_] prev_nodes_out = [pprev_, prev_]
prev_nodes_labels = ["prev1", "prev2"]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device) nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes): for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels) node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask nodes_used_mask[:mask.size(0)] |= mask
prev_nodes_out.append(node_out) prev_nodes_out.append(node_out)
prev_nodes_labels.append(self.nodes[i].key)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1) unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes) unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :] conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
......
...@@ -13,7 +13,7 @@ from utils import accuracy, reward_accuracy ...@@ -13,7 +13,7 @@ from utils import accuracy, reward_accuracy
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser("enas") parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int) parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int) parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro") parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
args = parser.parse_args() args = parser.parse_args()
...@@ -43,5 +43,6 @@ if __name__ == "__main__": ...@@ -43,5 +43,6 @@ if __name__ == "__main__":
num_epochs=num_epochs, num_epochs=num_epochs,
dataset_train=dataset_train, dataset_train=dataset_train,
dataset_valid=dataset_valid, dataset_valid=dataset_valid,
log_frequency=args.log_frequency) log_frequency=args.log_frequency,
trainer.train_and_validate() mutator=mutator)
trainer.train()
import logging import logging
import torch.nn as nn import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
from nni.nas.pytorch.mutables import Mutable from nni.nas.pytorch.utils import StructuredMutableTreeNode
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
class BaseMutator(nn.Module): class BaseMutator(nn.Module):
"""
A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
callbacks that are called in ``forward`` in Mutables.
"""
def __init__(self, model): def __init__(self, model):
super().__init__() super().__init__()
self.__dict__["model"] = model self.__dict__["model"] = model
self.before_parse_search_space() self._structured_mutables = self._parse_search_space(self.model)
self._parse_search_space()
self.after_parse_search_space()
def before_parse_search_space(self):
pass
def after_parse_search_space(self):
pass
def _parse_search_space(self):
for name, mutable, _ in self.named_mutables(distinct=False):
mutable.name = name
mutable.set_mutator(self)
def named_mutables(self, root=None, distinct=True): def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None: if root is None:
root = self.model root = StructuredMutableTreeNode(None)
# if distinct is true, the method will filter out those with duplicated keys if module not in memo:
key2module = dict() memo.add(module)
for name, module in root.named_modules():
if isinstance(module, Mutable): if isinstance(module, Mutable):
module_distinct = False if nested_detection is not None:
if module.key in key2module: raise RuntimeError("Cannot have nested search space. Error at {} in {}"
assert key2module[module.key].similar(module), \ .format(module, nested_detection))
"Mutable \"{}\" that share the same key must be similar to each other".format(module.key) module.name = prefix
else: module.set_mutator(self)
module_distinct = True root = root.add_child(module)
key2module[module.key] = module if not isinstance(module, MutableScope):
if distinct: nested_detection = module
if module_distinct: if isinstance(module, InputChoice):
yield name, module for k in module.choose_from:
else: if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
yield name, module, module_distinct raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
def __setattr__(self, key, value): for name, submodule in module._modules.items():
if key in ["model", "net", "network"]: if submodule is None:
logger.warning("Think twice if you are including the network into mutator.") continue
return super().__setattr__(key, value) submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
@property
def forward(self, *inputs): def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able") raise RuntimeError("Forward is undefined for mutators.")
def enter_mutable_scope(self, mutable_scope): def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
"""
pass pass
def exit_mutable_scope(self, mutable_scope): def exit_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is exited.
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
"""
pass pass
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *inputs):
"""
Callbacks of forward in LayerChoice.
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list, tags): def on_forward_input_choice(self, mutable, tensor_list):
"""
Callbacks of forward in InputChoice.
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError raise NotImplementedError
def export(self): def export(self):
"""
Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
network can be fully determined with these decisions for further training from scratch.
Returns
-------
dict
"""
raise NotImplementedError raise NotImplementedError
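As a concrete illustration of the `choose_from` contract enforced by `_parse_search_space` above (an `InputChoice` may only reference keys of mutables created before it), here is a toy node in the style of the DARTS example; the module and keys are made up for this sketch.

import torch.nn as nn
from nni.nas.pytorch import mutables


class TinyNode(nn.Module):
    """Hypothetical two-edge node: two LayerChoices feed one InputChoice by key."""

    def __init__(self):
        super().__init__()
        self.op0 = mutables.LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                                         nn.MaxPool2d(3, stride=1, padding=1)], key="edge0")
        self.op1 = mutables.LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                                         nn.MaxPool2d(3, stride=1, padding=1)], key="edge1")
        # "edge0" and "edge1" are declared above, so parsing succeeds;
        # referencing an unknown key here would trigger the RuntimeError above.
        self.switch = mutables.InputChoice(choose_from=["edge0", "edge1"], n_chosen=1, key="switch")

    def forward(self, x):
        return self.switch([self.op0(x), self.op1(x)])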
...@@ -12,5 +12,9 @@ class BaseTrainer(ABC): ...@@ -12,5 +12,9 @@ class BaseTrainer(ABC):
raise NotImplementedError raise NotImplementedError
@abstractmethod @abstractmethod
def train_and_validate(self): def export(self, file):
raise NotImplementedError
@abstractmethod
def checkpoint(self):
raise NotImplementedError raise NotImplementedError
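Since ArchitectureCheckpoint below now delegates to `trainer.export(file)`, a concrete trainer has to implement it. A minimal sketch (shown as a bare method body, assuming the trainer keeps a `self.mutator` reference), reusing the tensor-to-JSON encoding that used to live in the callback:

import json

import torch


class TorchTensorEncoder(json.JSONEncoder):
    """Serialize the boolean/float masks returned by mutator.export()."""

    def default(self, o):  # pylint: disable=method-hidden
        if isinstance(o, torch.Tensor):
            return o.tolist()
        return super().default(o)


def export(self, file):
    # dump {mutable key: decision tensor} as a JSON checkpoint
    mutator_export = self.mutator.export()
    with open(file, "w") as f:
        json.dump(mutator_export, f, indent=2, sort_keys=True, cls=TorchTensorEncoder)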
import json
import logging import logging
import os import os
import torch
_logger = logging.getLogger(__name__) _logger = logging.getLogger(__name__)
...@@ -44,26 +41,11 @@ class LearningRateScheduler(Callback): ...@@ -44,26 +41,11 @@ class LearningRateScheduler(Callback):
class ArchitectureCheckpoint(Callback): class ArchitectureCheckpoint(Callback):
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
olist = o.tolist()
if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
_logger.warning("Every element in %s is either 0 or 1. "
"You might consider convert it into bool.", olist)
return olist
return super().default(o)
def __init__(self, checkpoint_dir, every="epoch"): def __init__(self, checkpoint_dir, every="epoch"):
super().__init__() super().__init__()
assert every == "epoch" assert every == "epoch"
self.checkpoint_dir = checkpoint_dir self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True) os.makedirs(self.checkpoint_dir, exist_ok=True)
def _export_to_file(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder)
def on_epoch_end(self, epoch): def on_epoch_end(self, epoch):
self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch))) self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
from .mutator import DartsMutator from .mutator import DartsMutator
from .trainer import DartsTrainer from .trainer import DartsTrainer
from .scope import DartsNode \ No newline at end of file
\ No newline at end of file
...@@ -2,35 +2,47 @@ import torch ...@@ -2,35 +2,47 @@ import torch
from torch import nn as nn from torch import nn as nn
from torch.nn import functional as F from torch.nn import functional as F
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator from nni.nas.pytorch.mutator import Mutator
from .scope import DartsNode from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class DartsMutator(Mutator): class DartsMutator(Mutator):
def __init__(self, model):
def after_parse_search_space(self): super().__init__(model)
self.choices = nn.ParameterDict() self.choices = nn.ParameterDict()
for _, mutable in self.named_mutables(): for mutable in self.mutables:
if isinstance(mutable, LayerChoice): if isinstance(mutable, LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1)) self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
def on_calc_layer_choice_mask(self, mutable: LayerChoice): def device(self):
return F.softmax(self.choices[mutable.key], dim=-1)[:-1] for v in self.choices.values():
return v.device
def export(self): def sample_search(self):
result = super().export() result = dict()
for _, darts_node in self.named_mutables(): for mutable in self.mutables:
if isinstance(darts_node, DartsNode): if isinstance(mutable, LayerChoice):
keys, edges_max = [], [] # key of all the layer choices in current node, and their best edge weight result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
for _, choice in self.named_mutables(darts_node): elif isinstance(mutable, InputChoice):
if isinstance(choice, LayerChoice): result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
keys.append(choice.key) return result
max_val, index = torch.max(result[choice.key], 0)
edges_max.append(max_val) def sample_final(self):
result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool() result = dict()
_, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation) # pylint: disable=not-callable edges_max = dict()
for i, key in enumerate(keys): for mutable in self.mutables:
if i not in topk_edge_indices: if isinstance(mutable, LayerChoice):
result[key] = torch.zeros_like(result[key]) max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice):
weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result return result
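A toy walk-through of the edge-selection rule in `sample_final` above: each incoming edge keeps the weight of its strongest operation, then the node's InputChoice keeps the top `n_chosen` edges and the LayerChoice decisions of the dropped edges are zeroed out. The keys and weights below are made up for illustration.

import torch

# hypothetical per-edge best operation weights for one node with three incoming edges
edges_max = {"normal_n2_p0": 0.30, "normal_n2_p1": 0.55, "normal_n2_p2": 0.41}
weights = torch.tensor(list(edges_max.values()))
_, topk_edge_indices = torch.topk(weights, 2)        # n_chosen = 2
print(sorted(topk_edge_indices.tolist()))            # [1, 2]: edges _p1 and _p2 are kept, _p0 is dropped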
from nni.nas.pytorch.mutables import MutableScope
class DartsNode(MutableScope):
"""
At most `limitation` choice is activated in a `DartsNode` when exporting.
"""
def __init__(self, key, limitation):
super().__init__(key)
self.limitation = limitation
...@@ -4,7 +4,7 @@ import torch ...@@ -4,7 +4,7 @@ import torch
from torch import nn as nn from torch import nn as nn
from nni.nas.pytorch.trainer import Trainer from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator from .mutator import DartsMutator
...@@ -13,9 +13,9 @@ class DartsTrainer(Trainer): ...@@ -13,9 +13,9 @@ class DartsTrainer(Trainer):
optimizer, num_epochs, dataset_train, dataset_valid, optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None): callbacks=None):
super().__init__(model, loss, metrics, optimizer, num_epochs, super().__init__(model, mutator if mutator is not None else DartsMutator(model),
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
mutator if mutator is not None else DartsMutator(model), callbacks) batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999), self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
weight_decay=1.0E-3) weight_decay=1.0E-3)
n_train = len(self.dataset_train) n_train = len(self.dataset_train)
...@@ -31,6 +31,9 @@ class DartsTrainer(Trainer): ...@@ -31,6 +31,9 @@ class DartsTrainer(Trainer):
batch_size=batch_size, batch_size=batch_size,
sampler=valid_sampler, sampler=valid_sampler,
num_workers=workers) num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch): def train_one_epoch(self, epoch):
self.model.train() self.model.train()
...@@ -47,8 +50,8 @@ class DartsTrainer(Trainer): ...@@ -47,8 +50,8 @@ class DartsTrainer(Trainer):
# phase 1. child network step # phase 1. child network step
self.optimizer.zero_grad() self.optimizer.zero_grad()
with self.mutator.forward_pass(): self.mutator.reset()
logits = self.model(trn_X) logits = self.model(trn_X)
loss = self.loss(logits, trn_y) loss = self.loss(logits, trn_y)
loss.backward() loss.backward()
# gradient clipping # gradient clipping
...@@ -76,10 +79,10 @@ class DartsTrainer(Trainer): ...@@ -76,10 +79,10 @@ class DartsTrainer(Trainer):
self.mutator.eval() self.mutator.eval()
meters = AverageMeterGroup() meters = AverageMeterGroup()
with torch.no_grad(): with torch.no_grad():
for step, (X, y) in enumerate(self.valid_loader): self.mutator.reset()
for step, (X, y) in enumerate(self.test_loader):
X, y = X.to(self.device), y.to(self.device) X, y = X.to(self.device), y.to(self.device)
with self.mutator.forward_pass(): logits = self.model(X)
logits = self.model(X)
metrics = self.metrics(logits, y) metrics = self.metrics(logits, y)
meters.update(metrics) meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0: if self.log_frequency is not None and step % self.log_frequency == 0:
...@@ -93,8 +96,8 @@ class DartsTrainer(Trainer): ...@@ -93,8 +96,8 @@ class DartsTrainer(Trainer):
v_model: backup model before this step v_model: backup model before this step
lr: learning rate for virtual gradient step (same as net lr) lr: learning rate for virtual gradient step (same as net lr)
""" """
with self.mutator.forward_pass(): self.mutator.reset()
loss = self.loss(self.model(val_X), val_y) loss = self.loss(self.model(val_X), val_y)
w_model = tuple(self.model.parameters()) w_model = tuple(self.model.parameters())
w_ctrl = tuple(self.mutator.parameters()) w_ctrl = tuple(self.mutator.parameters())
w_grads = torch.autograd.grad(loss, w_model + w_ctrl) w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
...@@ -125,8 +128,8 @@ class DartsTrainer(Trainer): ...@@ -125,8 +128,8 @@ class DartsTrainer(Trainer):
for p, d in zip(self.model.parameters(), dw): for p, d in zip(self.model.parameters(), dw):
p += eps * d p += eps * d
with self.mutator.forward_pass(): self.mutator.reset()
loss = self.loss(self.model(trn_X), trn_y) loss = self.loss(self.model(trn_X), trn_y)
if e > 0: if e > 0:
dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) } dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) }
elif e < 0: elif e < 0:
......
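Condensed from the DartsTrainer changes above, the per-batch pattern for the child network is now "reset, forward, step". The helper below is a hypothetical stand-alone rendering of that phase, not part of the commit:

def child_step(model, mutator, optimizer, loss_fn, trn_X, trn_y):
    """One weight update under a freshly sampled architecture."""
    optimizer.zero_grad()
    mutator.reset()                        # sample_search() and cache the decisions
    loss = loss_fn(model(trn_X), trn_y)    # forward consults the cached decisions
    loss.backward()
    optimizer.step()
    return loss.item()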
...@@ -2,8 +2,8 @@ import torch ...@@ -2,8 +2,8 @@ import torch
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module): class StackedLSTMCell(nn.Module):
...@@ -27,15 +27,14 @@ class StackedLSTMCell(nn.Module): ...@@ -27,15 +27,14 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator): class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False, def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25): skip_target=0.4, branch_bias=0.25):
super().__init__(model)
self.lstm_size = lstm_size self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant self.tanh_constant = tanh_constant
self.cell_exit_extra_step = cell_exit_extra_step self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target self.skip_target = skip_target
self.branch_bias = branch_bias self.branch_bias = branch_bias
super().__init__(model)
def before_parse_search_space(self):
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False) self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False) self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
...@@ -45,9 +44,8 @@ class EnasMutator(Mutator): ...@@ -45,9 +44,8 @@ class EnasMutator(Mutator):
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none") self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict() self.bias_dict = nn.ParameterDict()
def after_parse_search_space(self):
self.max_layer_choice = 0 self.max_layer_choice = 0
for _, mutable in self.named_mutables(): for mutable in self.mutables:
if isinstance(mutable, LayerChoice): if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0: if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length self.max_layer_choice = mutable.length
...@@ -64,8 +62,29 @@ class EnasMutator(Mutator): ...@@ -64,8 +62,29 @@ class EnasMutator(Mutator):
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size) self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False) self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def before_pass(self): def sample_search(self):
super().before_pass() self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict() self._anchors_hid = dict()
self._inputs = self.g_emb.data self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size), self._c = [torch.zeros((1, self.lstm_size),
...@@ -84,7 +103,7 @@ class EnasMutator(Mutator): ...@@ -84,7 +103,7 @@ class EnasMutator(Mutator):
def _mark_anchor(self, key): def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1] self._anchors_hid[key] = self._h[-1]
def on_calc_layer_choice_mask(self, mutable): def _sample_layer_choice(self, mutable):
self._lstm_next_step() self._lstm_next_step()
logit = self.soft(self._h[-1]) logit = self.soft(self._h[-1])
if self.tanh_constant is not None: if self.tanh_constant is not None:
...@@ -94,14 +113,14 @@ class EnasMutator(Mutator): ...@@ -94,14 +113,14 @@ class EnasMutator(Mutator):
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id) log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob) self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy) self.sample_entropy += torch.sum(entropy)
self._inputs = self.embedding(branch_id) self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1) return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def on_calc_input_choice_mask(self, mutable, tags): def _sample_input_choice(self, mutable):
query, anchors = [], [] query, anchors = [], []
for label in tags: for label in mutable.choose_from:
if label not in self._anchors_hid: if label not in self._anchors_hid:
self._lstm_next_step() self._lstm_next_step()
self._mark_anchor(label) # empty loop, fill not found self._mark_anchor(label) # empty loop, fill not found
...@@ -113,8 +132,8 @@ class EnasMutator(Mutator): ...@@ -113,8 +132,8 @@ class EnasMutator(Mutator):
if self.tanh_constant is not None: if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query) query = self.tanh_constant * torch.tanh(query)
if mutable.n_selected is None: if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit) skip_prob = torch.sigmoid(logit)
...@@ -123,19 +142,14 @@ class EnasMutator(Mutator): ...@@ -123,19 +142,14 @@ class EnasMutator(Mutator):
log_prob = self.cross_entropy_loss(logit, skip) log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0) self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else: else:
assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS." assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = query.view(1, -1) logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1) index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index).view(-1) skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index) log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()] self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob) self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy) self.sample_entropy += torch.sum(entropy)
return skip.bool() return skip.bool()
def exit_mutable_scope(self, mutable_scope):
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable_scope.key)
...@@ -2,7 +2,7 @@ import torch ...@@ -2,7 +2,7 @@ import torch
import torch.optim as optim import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import EnasMutator from .mutator import EnasMutator
...@@ -12,9 +12,9 @@ class EnasTrainer(Trainer): ...@@ -12,9 +12,9 @@ class EnasTrainer(Trainer):
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None, mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999, entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4): mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
super().__init__(model, loss, metrics, optimizer, num_epochs, super().__init__(model, mutator if mutator is not None else EnasMutator(model),
dataset_train, dataset_valid, batch_size, workers, device, log_frequency, loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
mutator if mutator is not None else EnasMutator(model), callbacks) batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr) self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
...@@ -52,8 +52,9 @@ class EnasTrainer(Trainer): ...@@ -52,8 +52,9 @@ class EnasTrainer(Trainer):
x, y = x.to(self.device), y.to(self.device) x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad() self.optimizer.zero_grad()
with self.mutator.forward_pass(): with torch.no_grad():
logits = self.model(x) self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple): if isinstance(logits, tuple):
logits, aux_logits = logits logits, aux_logits = logits
...@@ -81,7 +82,8 @@ class EnasTrainer(Trainer): ...@@ -81,7 +82,8 @@ class EnasTrainer(Trainer):
for step, (x, y) in enumerate(self.valid_loader): for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device) x, y = x.to(self.device), y.to(self.device)
with self.mutator.forward_pass(): self.mutator.reset()
with torch.no_grad():
logits = self.model(x) logits = self.model(x)
metrics = self.metrics(logits, y) metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y) reward = self.reward_function(logits, y)
...@@ -107,9 +109,9 @@ class EnasTrainer(Trainer): ...@@ -107,9 +109,9 @@ class EnasTrainer(Trainer):
self.mutator_optim.zero_grad() self.mutator_optim.zero_grad()
if self.log_frequency is not None and step % self.log_frequency == 0: if self.log_frequency is not None and step % self.log_frequency == 0:
print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs,
mutator_step // self.mutator_steps_aggregate, mutator_step // self.mutator_steps_aggregate,
self.mutator_steps, meters)) self.mutator_steps, meters))
mutator_step += 1 mutator_step += 1
if mutator_step >= total_mutator_steps: if mutator_step >= total_mutator_steps:
break break
......
...@@ -2,10 +2,12 @@ import json ...@@ -2,10 +2,12 @@ import json
import torch import torch
from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator): class FixedArchitecture(Mutator):
def __init__(self, model, fixed_arc, strict=True): def __init__(self, model, fixed_arc, strict=True):
""" """
Initialize a fixed architecture mutator. Initialize a fixed architecture mutator.
...@@ -20,39 +22,57 @@ class FixedArchitecture(Mutator): ...@@ -20,39 +22,57 @@ class FixedArchitecture(Mutator):
Force everything that appears in `fixed_arc` to be used at least once. Force everything that appears in `fixed_arc` to be used at least once.
""" """
super().__init__(model) super().__init__(model)
if isinstance(fixed_arc, str):
with open(fixed_arc, "r") as f:
fixed_arc = json.load(f.read())
self._fixed_arc = fixed_arc self._fixed_arc = fixed_arc
self._strict = strict
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
def _encode_tensor(self, data): fixed_arc_keys = set(self._fixed_arc.keys())
if isinstance(data, list): if fixed_arc_keys - mutable_keys:
if all(map(lambda o: isinstance(o, bool), data)): raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable if mutable_keys - fixed_arc_keys:
else: raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
if isinstance(data, dict): def sample_search(self):
return {k: self._encode_tensor(v) for k, v in data.items()} return self._fixed_arc
return data
def sample_final(self):
def before_pass(self): return self._fixed_arc
self._unused_key = set(self._fixed_arc.keys())
def after_pass(self): def _encode_tensor(data, device):
if self._strict: if isinstance(data, list):
if self._unused_key: if all(map(lambda o: isinstance(o, bool), data)):
raise ValueError("{} are never used by the network. " return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
"Set strict=False if you want to disable this check.".format(self._unused_key)) else:
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
def _check_key(self, key): if isinstance(data, dict):
if key not in self._fixed_arc: return {k: _encode_tensor(v, device) for k, v in data.items()}
raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key)) return data
def on_calc_layer_choice_mask(self, mutable):
self._check_key(mutable.key) def apply_fixed_architecture(model, fixed_arc_path, device=None):
return self._fixed_arc[mutable.key] """
Load architecture from `fixed_arc_path` and apply to model.
def on_calc_input_choice_mask(self, mutable, tags):
self._check_key(mutable.key) Parameters
return self._fixed_arc[mutable.key] ----------
model: torch.nn.Module
Model with mutables.
fixed_arc_path: str
Path to the JSON that stores the architecture.
device: torch.device
Architecture weights will be transferred to `device`.
Returns
-------
FixedArchitecture
"""
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc, device)
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
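Typical retraining usage, mirroring retrain.py earlier in this commit (paths and hyper-parameters taken from that script):

import torch

from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = CNN(32, 3, 36, 10, 20, auxiliary=True).to(device)
# freeze every LayerChoice / InputChoice according to the exported JSON
apply_fixed_architecture(model, "./checkpoints/epoch_0.json", device=device)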
import torch.nn as nn import torch.nn as nn
from nni.nas.utils import global_mutable_counting from nni.nas.pytorch.utils import global_mutable_counting
class Mutable(nn.Module): class Mutable(nn.Module):
...@@ -37,7 +37,7 @@ class Mutable(nn.Module): ...@@ -37,7 +37,7 @@ class Mutable(nn.Module):
self.__dict__["mutator"] = mutator self.__dict__["mutator"] = mutator
def forward(self, *inputs): def forward(self, *inputs):
raise NotImplementedError("Mutable forward must be implemented.") raise NotImplementedError
@property @property
def key(self): def key(self):
...@@ -51,9 +51,6 @@ class Mutable(nn.Module): ...@@ -51,9 +51,6 @@ class Mutable(nn.Module):
def name(self, name): def name(self, name):
self._name = name self._name = name
def similar(self, other):
return type(self) == type(other)
def _check_built(self): def _check_built(self):
if not hasattr(self, "mutator"): if not hasattr(self, "mutator"):
raise ValueError( raise ValueError(
...@@ -66,19 +63,17 @@ class Mutable(nn.Module): ...@@ -66,19 +63,17 @@ class Mutable(nn.Module):
class MutableScope(Mutable): class MutableScope(Mutable):
""" """
Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope Mutable scope labels a subgraph/submodule to help mutators make better decisions.
is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
corresponding events, and do status dump or update. and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
""" """
def __init__(self, key): def __init__(self, key):
super().__init__(key=key) super().__init__(key=key)
def build(self):
self.mutator.on_init_mutable_scope(self)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
try: try:
self._check_built()
self.mutator.enter_mutable_scope(self) self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs) return super().__call__(*args, **kwargs)
finally: finally:
...@@ -93,43 +88,92 @@ class LayerChoice(Mutable): ...@@ -93,43 +88,92 @@ class LayerChoice(Mutable):
self.reduction = reduction self.reduction = reduction
self.return_mask = return_mask self.return_mask = return_mask
def __len__(self):
return len(self.choices)
def forward(self, *inputs): def forward(self, *inputs):
out, mask = self.mutator.on_forward_layer_choice(self, *inputs) out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask: if self.return_mask:
return out, mask return out, mask
return out return out
def similar(self, other):
return type(self) == type(other) and self.length == other.length
class InputChoice(Mutable): class InputChoice(Mutable):
def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None): """
Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners,
using `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
know about `choose_from`.
The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from are produced. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
"""
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="mean", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates: int
Number of inputs to choose from.
choose_from: list of str
List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be provided.
If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
empty strings.
n_chosen: int
Number of inputs recommended to choose. If None, the mutator is instructed to select any number.
reduction: str
`mean`, `concat`, `sum` or `none`.
return_mask: bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key: str
Key of the input choice.
"""
super().__init__(key=key) super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
"must be not None."
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
assert n_candidates > 0, "Number of candidates must be greater than 0." assert n_candidates > 0, "Number of candidates must be greater than 0."
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
"than number of candidates."
self.n_candidates = n_candidates self.n_candidates = n_candidates
self.n_selected = n_selected self.choose_from = choose_from
self.n_chosen = n_chosen
self.reduction = reduction self.reduction = reduction
self.return_mask = return_mask self.return_mask = return_mask
def build(self): def forward(self, optional_inputs):
self.mutator.on_init_input_choice(self) """
Forward method of InputChoice.
def forward(self, optional_inputs, tags=None):
Parameters
----------
optional_inputs: list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
`choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
`choose_from`.
Returns
-------
tuple of torch.Tensor and torch.Tensor or torch.Tensor
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), "Optional input list must be a list"
assert len(optional_inputs) == self.n_candidates, \ assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates." "Length of the input list must be equal to number of candidates."
if tags is None: out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
tags = [""] * self.n_candidates
else:
assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
if self.return_mask: if self.return_mask:
return out, mask return out, mask
return out return out
def similar(self, other):
return type(self) == type(other) and \
self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
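For completeness, the simpler `n_candidates` form described in the docstring (candidates are anonymous NO_KEY slots, so no earlier keys are required); the module below is hypothetical:

import torch.nn as nn
from nni.nas.pytorch import mutables


class TwoBranch(nn.Module):
    """Hypothetical module: let the mutator pick one of two stem branches."""

    def __init__(self):
        super().__init__()
        self.branch_a = nn.Conv2d(3, 16, 3, padding=1)
        self.branch_b = nn.Conv2d(3, 16, 5, padding=2)
        # two anonymous candidates; the mutator decides which one flows on
        self.input_switch = mutables.InputChoice(n_candidates=2, n_chosen=1)

    def forward(self, x):
        return self.input_switch([self.branch_a(x), self.branch_b(x)])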
from contextlib import contextmanager
import torch import torch
import torch.nn as nn
from nni.nas.pytorch.base_mutator import BaseMutator from nni.nas.pytorch.base_mutator import BaseMutator
class Mutator(BaseMutator, nn.Module): class Mutator(BaseMutator):
def export(self): def __init__(self, model):
if self._in_forward_pass: super().__init__(model)
raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
if not self._cache:
raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
"You might also want to check if there are no valid mutables in your model.")
return self._cache
@contextmanager
def forward_pass(self):
self._in_forward_pass = True
self._cache = dict() self._cache = dict()
self.before_pass()
try:
yield self
finally:
self.after_pass()
self._in_forward_pass = False
def before_pass(self): def sample_search(self):
pass """
Override to implement this method to iterate over mutables and make decisions.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def sample_final(self):
"""
Override this method to iterate over mutables and make final decisions
for export and retraining.
def after_pass(self): Returns
pass -------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
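As a concrete reading of the `sample_search` / `sample_final` contract, here is a rough sketch of a random-search mutator. It assumes the base class exposes the registered mutables through an iterable attribute (written as `self.mutables` below; the exact traversal helper may differ) and that the `Mutator` base shown in this diff lives at `nni.nas.pytorch.mutator`.

import torch
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutator import Mutator   # assumed module path for this sketch

class RandomMutator(Mutator):
    def sample_search(self):
        result = {}
        for mutable in self.mutables:          # assumed iterable of registered mutables
            if isinstance(mutable, LayerChoice):
                mask = torch.zeros(len(mutable.choices), dtype=torch.bool)
                mask[torch.randint(len(mutable.choices), (1,))] = True
            elif isinstance(mutable, InputChoice):
                n_chosen = mutable.n_chosen or mutable.n_candidates
                mask = torch.zeros(mutable.n_candidates, dtype=torch.bool)
                mask[torch.randperm(mutable.n_candidates)[:n_chosen]] = True
            else:
                continue
            result[mutable.key] = mask         # one boolean "switch" mask per key
        return result

    def sample_final(self):
        # for random search, the exported decision is simply another sample
        return self.sample_search()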
def _check_in_forward_pass(self): def reset(self):
if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass: """
raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call " Reset the mutator by call the `sample_search` to resample (for search).
"super().before_pass() and after_pass() in your override method?")
Returns
-------
None
"""
self._cache = self.sample_search()
def export(self):
"""
Resample (for the final export) and return the results.
Returns
-------
dict
"""
return self.sample_final()
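A short sketch of the calling convention implied by `reset` and `export`; the search space below is invented for illustration and `RandomMutator` refers to the sketch above.

import torch
import torch.nn as nn
from nni.nas.pytorch import mutables

class TinySpace(nn.Module):
    def __init__(self):
        super().__init__()
        # the key "conv_or_pool" is invented for this sketch
        self.op = mutables.LayerChoice([
            nn.Conv2d(3, 3, 3, padding=1),
            nn.MaxPool2d(3, stride=1, padding=1),
        ], key="conv_or_pool")

    def forward(self, x):
        return self.op(x)

model = TinySpace()
mutator = RandomMutator(model)            # binds the mutator to the model's mutables
mutator.reset()                           # caches a fresh sample_search() decision
output = model(torch.randn(1, 3, 8, 8))   # forward pass uses the cached decision
final_decisions = mutator.export()        # decisions produced by sample_final()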
def on_forward_layer_choice(self, mutable, *inputs): def on_forward_layer_choice(self, mutable, *inputs):
""" """
Callback of layer choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified (either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`. in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
...@@ -54,18 +67,17 @@ class Mutator(BaseMutator, nn.Module): ...@@ -54,18 +67,17 @@ class Mutator(BaseMutator, nn.Module):
------- -------
tuple of torch.Tensor and torch.Tensor tuple of torch.Tensor and torch.Tensor
""" """
self._check_in_forward_pass()
def _map_fn(op, *inputs): def _map_fn(op, *inputs):
return op(*inputs) return op(*inputs)
mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable)) mask = self._get_decision(mutable)
assert len(mask) == len(mutable.choices)
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask) out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask return self._tensor_reduction(mutable.reduction, out), mask
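To make the mask-and-reduce behaviour concrete, here is a stand-alone sketch of what a boolean "switch" mask does to the candidate outputs before the "sum" reduction; it mirrors `_select_with_mask` plus `_tensor_reduction` without being that exact code, and the same mechanism is reused for input choice below.

import torch

candidate_outputs = [torch.full((2, 4), float(i)) for i in range(3)]  # fake op outputs
mask = torch.tensor([False, True, False])     # boolean mask: keep only candidate 1

selected = [out for out, keep in zip(candidate_outputs, mask) if keep]
reduced = torch.sum(torch.stack(selected), dim=0)   # "sum" reduction over the survivors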
def on_forward_input_choice(self, mutable, tensor_list, tags): def on_forward_input_choice(self, mutable, tensor_list):
""" """
Callback of input choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_input_choice_mask` with `tags` By default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
...@@ -81,48 +93,11 @@ class Mutator(BaseMutator, nn.Module): ...@@ -81,48 +93,11 @@ class Mutator(BaseMutator, nn.Module):
------- -------
tuple of torch.Tensor and torch.Tensor tuple of torch.Tensor and torch.Tensor
""" """
self._check_in_forward_pass() mask = self._get_decision(mutable)
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags)) assert len(mask) == mutable.n_candidates
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask) out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask return self._tensor_reduction(mutable.reduction, out), mask
def on_calc_layer_choice_mask(self, mutable):
"""
Recommended to override. Calculate a mask tensor for a layer choice.
Parameters
----------
mutable: LayerChoice
Corresponding layer choice object.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switches.
"""
raise NotImplementedError("Layer choice mask calculation must be implemented")
def on_calc_input_choice_mask(self, mutable, tags):
"""
Recommended to override. Calculate a mask tensor for an input choice.
Parameters
----------
mutable: InputChoice
Corresponding input choice object.
tags: list of string
The labels of the input tensors given by the user. Usually it's a
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by the user.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switches.
"""
raise NotImplementedError("Input choice mask calculation must be implemented")
def _select_with_mask(self, map_fn, candidates, mask): def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type(): if "BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m] out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
...@@ -146,3 +121,20 @@ class Mutator(BaseMutator, nn.Module): ...@@ -146,3 +121,20 @@ class Mutator(BaseMutator, nn.Module):
if reduction_type == "concat": if reduction_type == "concat":
return torch.cat(tensor_list, dim=1) return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type)) raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _get_decision(self, mutable):
"""
By default, this method looks up `mutable.key` in the decision cache and returns the cached
decision without re-validating it; a ValueError is raised if no decision has been cached.
Parameters
----------
mutable: Mutable
Returns
-------
any
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
return self._cache[mutable.key]
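One practical consequence of this cache-only lookup, sketched here with the illustrative `TinySpace` and `RandomMutator` from above: running the model before `reset()` leaves the decision cache empty, so the forward callbacks raise the ValueError defined in `_get_decision`.

import torch

model = TinySpace()
mutator = RandomMutator(model)

try:
    model(torch.randn(1, 3, 8, 8))    # no reset() yet, so the decision cache is empty
except ValueError as err:
    print(err)                        # e.g. '"conv_or_pool" not found in decision cache.'

mutator.reset()                       # populates the cache via sample_search()
model(torch.randn(1, 3, 8, 8))        # now the cached decision is found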
...@@ -11,14 +11,14 @@ from nni.nas.pytorch.mutables import LayerChoice ...@@ -11,14 +11,14 @@ from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator): class PdartsMutator(DartsMutator):
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None): def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
self.pdarts_epoch_index = pdarts_epoch_index self.pdarts_epoch_index = pdarts_epoch_index
self.pdarts_num_to_drop = pdarts_num_to_drop self.pdarts_num_to_drop = pdarts_num_to_drop
self.switches = switches self.switches = switches
super(PdartsMutator, self).__init__(model) super(PdartsMutator, self).__init__()
def before_build(self, model): def before_build(self):
self.choices = nn.ParameterDict() self.choices = nn.ParameterDict()
if self.switches is None: if self.switches is None:
self.switches = {} self.switches = {}
......