Commit 77e91e8b authored by Yuge Zhang, committed by Chi Song

Extract controller from mutator to make offline decisions (#1758)

parent 9dda5370
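In short, this change moves architecture decisions out of the per-forward callbacks (on_calc_layer_choice_mask / on_calc_input_choice_mask) and into sample_search / sample_final on the mutator, so decisions can be made offline, cached via reset(), and exported for retraining. A minimal sketch of the new contract, based only on the classes shown in this diff (the RandomMutator name and the random sampling policy are illustrative, not part of the commit):

import torch
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
from nni.nas.pytorch.mutator import Mutator

class RandomMutator(Mutator):
    """Hypothetical mutator following the new offline-decision contract."""
    def sample_search(self):
        result = dict()
        for mutable in self.mutables:
            if isinstance(mutable, LayerChoice):
                # pick one candidate op uniformly at random
                index = torch.randint(mutable.length, (1,))
                result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
            elif isinstance(mutable, InputChoice):
                # keep all inputs during search, as DartsMutator does
                result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool)
        return result

    def sample_final(self):
        return self.sample_search()

# Usage sketch: decisions are sampled before the forward pass, not during it.
# mutator = RandomMutator(model)
# mutator.reset()           # calls sample_search() and caches the decisions
# logits = model(x)         # mutables look their decisions up in the cache
# arch = mutator.export()   # calls sample_final() for the final architecture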
......@@ -2,7 +2,7 @@ import torch
import torch.nn as nn
import ops
from nni.nas.pytorch import mutables, darts
from nni.nas.pytorch import mutables
class AuxiliaryHead(nn.Module):
......@@ -31,12 +31,14 @@ class AuxiliaryHead(nn.Module):
return logits
class Node(darts.DartsNode):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.):
super().__init__(node_id, limitation=2)
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
mutables.LayerChoice(
[
......@@ -48,18 +50,19 @@ class Node(darts.DartsNode):
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
],
key="{}_p{}".format(node_id, i)))
self.drop_path = ops.DropPath_(drop_path_prob)
key=choice_keys[-1]))
self.drop_path = ops.DropPath_()
self.input_switch = mutables.InputChoice(choose_from=choice_keys, n_chosen=2, key="{}_switch".format(node_id))
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
return sum(self.drop_path(o) for o in out if o is not None)
return self.input_switch(out)
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
......@@ -74,10 +77,9 @@ class Cell(nn.Module):
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth),
depth + 2, channels, 2 if reduction else 0,
drop_path_prob=drop_path_prob))
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, s0, s1):
# s0 and s1 are the outputs of the previous-previous cell and the previous cell, respectively.
......@@ -93,7 +95,7 @@ class Cell(nn.Module):
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False, drop_path_prob=0.):
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
......@@ -120,7 +122,7 @@ class CNN(nn.Module):
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob)
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
......@@ -147,3 +149,8 @@ class CNN(nn.Module):
if aux_logits is not None:
return logits, aux_logits
return logits
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath_):
module.p = p
import logging
from argparse import ArgumentParser
import torch
import torch.nn as nn
import datasets
import utils
from model import CNN
from nni.nas.pytorch.fixed import apply_fixed_architecture
from nni.nas.pytorch.utils import AverageMeter
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def train(config, train_loader, model, optimizer, criterion, epoch):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")
cur_step = epoch * len(train_loader)
cur_lr = optimizer.param_groups[0]['lr']
logger.info("Epoch %d LR %.6f", epoch, cur_lr)
model.train()
for step, (x, y) in enumerate(train_loader):
x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
bs = x.size(0)
optimizer.zero_grad()
logits, aux_logits = model(x)
loss = criterion(logits, y)
if config.aux_weight > 0.:
loss += config.aux_weight * criterion(aux_logits, y)
loss.backward()
# gradient clipping
nn.utils.clip_grad_norm_(model.parameters(), config.grad_clip)
optimizer.step()
accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), bs)
top1.update(accuracy["acc1"], bs)
top5.update(accuracy["acc5"], bs)
if step % config.log_frequency == 0 or step == len(train_loader) - 1:
logger.info(
"Train: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(train_loader) - 1, losses=losses,
top1=top1, top5=top5))
cur_step += 1
logger.info("Train: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
def validate(config, valid_loader, model, criterion, epoch, cur_step):
top1 = AverageMeter("top1")
top5 = AverageMeter("top5")
losses = AverageMeter("losses")
model.eval()
with torch.no_grad():
for step, (X, y) in enumerate(valid_loader):
X, y = X.to(device, non_blocking=True), y.to(device, non_blocking=True)
N = X.size(0)
logits = model(X)
loss = criterion(logits, y)
accuracy = utils.accuracy(logits, y, topk=(1, 5))
losses.update(loss.item(), N)
top1.update(accuracy["acc1"], N)
top5.update(accuracy["acc5"], N)
if step % config.log_frequency == 0 or step == len(valid_loader) - 1:
logger.info(
"Valid: [{:3d}/{}] Step {:03d}/{:03d} Loss {losses.avg:.3f} "
"Prec@(1,5) ({top1.avg:.1%}, {top5.avg:.1%})".format(
epoch + 1, config.epochs, step, len(valid_loader) - 1, losses=losses,
top1=top1, top5=top5))
logger.info("Valid: [{:3d}/{}] Final Prec@1 {:.4%}".format(epoch + 1, config.epochs, top1.avg))
return top1.avg
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=20, type=int)
parser.add_argument("--batch-size", default=96, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=600, type=int)
parser.add_argument("--aux-weight", default=0.4, type=float)
parser.add_argument("--drop-path-prob", default=0.2, type=float)
parser.add_argument("--workers", default=4)
parser.add_argument("--grad-clip", default=5., type=float)
parser.add_argument("--arc-checkpoint", default="./checkpoints/epoch_0.json")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
model = CNN(32, 3, 36, 10, args.layers, auxiliary=True)
apply_fixed_architecture(model, args.arc_checkpoint, device=device)
criterion = nn.CrossEntropyLoss()
model.to(device)
criterion.to(device)
optimizer = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.epochs, eta_min=1E-6)
train_loader = torch.utils.data.DataLoader(dataset_train,
batch_size=args.batch_size,
shuffle=True,
num_workers=args.workers,
pin_memory=True)
valid_loader = torch.utils.data.DataLoader(dataset_valid,
batch_size=args.batch_size,
shuffle=False,
num_workers=args.workers,
pin_memory=True)
best_top1 = 0.
for epoch in range(args.epochs):
drop_prob = args.drop_path_prob * epoch / args.epochs
model.drop_path_prob(drop_prob)
# training
train(args, train_loader, model, optimizer, criterion, epoch)
# validation
cur_step = (epoch + 1) * len(train_loader)
top1 = validate(args, valid_loader, model, criterion, epoch, cur_step)
best_top1 = max(best_top1, top1)
lr_scheduler.step()
logger.info("Final best Prec@1 = {:.4%}".format(best_top1))
......@@ -13,7 +13,7 @@ from utils import accuracy
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=8, type=int)
parser.add_argument("--batch-size", default=96, type=int)
parser.add_argument("--batch-size", default=64, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--epochs", default=50, type=int)
args = parser.parse_args()
......@@ -36,4 +36,4 @@ if __name__ == "__main__":
batch_size=args.batch_size,
log_frequency=args.log_frequency,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
trainer.train_and_validate()
trainer.train()
......@@ -6,7 +6,7 @@ from ops import FactorizedReduce, ConvBranch, PoolBranch
class ENASLayer(mutables.MutableScope):
def __init__(self, key, num_prev_layers, in_filters, out_filters):
def __init__(self, key, prev_labels, in_filters, out_filters):
super().__init__(key)
self.in_filters = in_filters
self.out_filters = out_filters
......@@ -18,16 +18,16 @@ class ENASLayer(mutables.MutableScope):
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if num_prev_layers > 0:
self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
if len(prev_labels) > 0:
self.skipconnect = mutables.InputChoice(choose_from=prev_labels, n_chosen=None, reduction="sum")
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
def forward(self, prev_layers, prev_labels):
def forward(self, prev_layers):
out = self.mutable(prev_layers[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1], tags=prev_labels)
connection = self.skipconnect(prev_layers[:-1])
if connection is not None:
out += connection
return self.batch_norm(out)
......@@ -53,11 +53,12 @@ class GeneralNetwork(nn.Module):
self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList()
labels = []
for layer_id in range(self.num_layers):
labels.append("layer_{}".format(layer_id))
if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id,
self.out_filters, self.out_filters))
self.layers.append(ENASLayer(labels[-1], labels[:-1], self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
......@@ -66,12 +67,11 @@ class GeneralNetwork(nn.Module):
bs = x.size(0)
cur = self.stem(x)
layers, labels = [cur], []
layers = [cur]
for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers, labels)
cur = self.layers[layer_id](layers)
layers.append(cur)
labels.append(self.layers[layer_id].key)
if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
......
......@@ -32,9 +32,9 @@ class AuxiliaryHead(nn.Module):
class Cell(nn.Module):
def __init__(self, cell_name, num_prev_layers, channels):
def __init__(self, cell_name, prev_labels, channels):
super().__init__()
self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True,
self.input_choice = mutables.InputChoice(choose_from=prev_labels, n_chosen=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = mutables.LayerChoice([
SepConvBN(channels, channels, 3, 1),
......@@ -44,21 +44,21 @@ class Cell(nn.Module):
nn.Identity()
], key=cell_name + "_op")
def forward(self, prev_layers, prev_labels):
chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels)
def forward(self, prev_layers):
chosen_input, chosen_mask = self.input_choice(prev_layers)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(mutables.MutableScope):
def __init__(self, node_name, num_prev_layers, channels):
def __init__(self, node_name, prev_node_names, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", num_prev_layers, channels)
self.cell_y = Cell(node_name + "_y", num_prev_layers, channels)
self.cell_x = Cell(node_name + "_x", prev_node_names, channels)
self.cell_y = Cell(node_name + "_y", prev_node_names, channels)
def forward(self, prev_layers, prev_labels):
out_x, mask_x = self.cell_x(prev_layers, prev_labels)
out_y, mask_y = self.cell_y(prev_layers, prev_labels)
def forward(self, prev_layers):
out_x, mask_x = self.cell_x(prev_layers)
out_y, mask_y = self.cell_y(prev_layers)
return out_x + out_y, mask_x | mask_y
......@@ -93,8 +93,11 @@ class ENASLayer(nn.Module):
self.num_nodes = num_nodes
name_prefix = "reduce" if reduction else "normal"
self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i),
i + 2, out_channels) for i in range(num_nodes)])
self.nodes = nn.ModuleList()
node_labels = [mutables.InputChoice.NO_KEY, mutables.InputChoice.NO_KEY]
for i in range(num_nodes):
node_labels.append("{}_node_{}".format(name_prefix, i))
self.nodes.append(Node(node_labels[-1], node_labels[:-1], out_channels))
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True)
self.bn = nn.BatchNorm2d(out_channels, affine=False)
self.reset_parameters()
......@@ -106,13 +109,11 @@ class ENASLayer(nn.Module):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
prev_nodes_out = [pprev_, prev_]
prev_nodes_labels = ["prev1", "prev2"]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels)
node_out, mask = self.nodes[i](prev_nodes_out)
nodes_used_mask[:mask.size(0)] |= mask
prev_nodes_out.append(node_out)
prev_nodes_labels.append(self.nodes[i].key)
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
......
......@@ -13,7 +13,7 @@ from utils import accuracy, reward_accuracy
if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
parser.add_argument("--log-frequency", default=10, type=int)
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
args = parser.parse_args()
......@@ -43,5 +43,6 @@ if __name__ == "__main__":
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
trainer.train_and_validate()
log_frequency=args.log_frequency,
mutator=mutator)
trainer.train()
import logging
import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable
from nni.nas.pytorch.mutables import Mutable, MutableScope, InputChoice
from nni.nas.pytorch.utils import StructuredMutableTreeNode
logger = logging.getLogger(__name__)
class BaseMutator(nn.Module):
"""
A mutator is responsible for mutating a graph by obtaining the search space from the network and implementing
callbacks that are called in the ``forward`` of mutables.
"""
def __init__(self, model):
super().__init__()
self.__dict__["model"] = model
self.before_parse_search_space()
self._parse_search_space()
self.after_parse_search_space()
def before_parse_search_space(self):
pass
def after_parse_search_space(self):
pass
def _parse_search_space(self):
for name, mutable, _ in self.named_mutables(distinct=False):
mutable.name = name
mutable.set_mutator(self)
self._structured_mutables = self._parse_search_space(self.model)
def named_mutables(self, root=None, distinct=True):
def _parse_search_space(self, module, root=None, prefix="", memo=None, nested_detection=None):
if memo is None:
memo = set()
if root is None:
root = self.model
# if distinct is true, the method will filter out those with duplicated keys
key2module = dict()
for name, module in root.named_modules():
root = StructuredMutableTreeNode(None)
if module not in memo:
memo.add(module)
if isinstance(module, Mutable):
module_distinct = False
if module.key in key2module:
assert key2module[module.key].similar(module), \
"Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
else:
module_distinct = True
key2module[module.key] = module
if distinct:
if module_distinct:
yield name, module
else:
yield name, module, module_distinct
def __setattr__(self, key, value):
if key in ["model", "net", "network"]:
logger.warning("Think twice if you are including the network into mutator.")
return super().__setattr__(key, value)
if nested_detection is not None:
raise RuntimeError("Cannot have nested search space. Error at {} in {}"
.format(module, nested_detection))
module.name = prefix
module.set_mutator(self)
root = root.add_child(module)
if not isinstance(module, MutableScope):
nested_detection = module
if isinstance(module, InputChoice):
for k in module.choose_from:
if k != InputChoice.NO_KEY and k not in [m.key for m in memo if isinstance(m, Mutable)]:
raise RuntimeError("'{}' required by '{}' not found in keys that appeared before, and is not NO_KEY."
.format(k, module.key))
for name, submodule in module._modules.items():
if submodule is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
self._parse_search_space(submodule, root, submodule_prefix, memo=memo,
nested_detection=nested_detection)
return root
@property
def mutables(self):
return self._structured_mutables
@property
def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able")
raise RuntimeError("Forward is undefined for mutators.")
def enter_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is entered.
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
"""
pass
def exit_mutable_scope(self, mutable_scope):
"""
Callback when forward of a MutableScope is exited.
Parameters
----------
mutable_scope: MutableScope
Returns
-------
None
"""
pass
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callbacks of forward in LayerChoice.
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list, tags):
def on_forward_input_choice(self, mutable, tensor_list):
"""
Callbacks of forward in InputChoice.
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
Returns
-------
tuple of torch.Tensor and torch.Tensor
output tensor and mask
"""
raise NotImplementedError
def export(self):
"""
Export the data of all decisions. This should output the decisions of all the mutables, so that the whole
network can be fully determined with these decisions for further training from scratch.
Returns
-------
dict
"""
raise NotImplementedError
......@@ -12,5 +12,9 @@ class BaseTrainer(ABC):
raise NotImplementedError
@abstractmethod
def train_and_validate(self):
def export(self, file):
raise NotImplementedError
@abstractmethod
def checkpoint(self):
raise NotImplementedError
import json
import logging
import os
import torch
_logger = logging.getLogger(__name__)
......@@ -44,26 +41,11 @@ class LearningRateScheduler(Callback):
class ArchitectureCheckpoint(Callback):
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
olist = o.tolist()
if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
_logger.warning("Every element in %s is either 0 or 1. "
"You might consider convert it into bool.", olist)
return olist
return super().default(o)
def __init__(self, checkpoint_dir, every="epoch"):
super().__init__()
assert every == "epoch"
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def _export_to_file(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder)
def on_epoch_end(self, epoch):
self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
self.trainer.export(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
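ArchitectureCheckpoint now delegates serialization to the trainer. The generic trainer's export is not shown in this diff; a hedged sketch of what such an export could look like, reusing the TorchTensorEncoder idea from above (the helper name is illustrative only):

import json
import torch

def export_architecture(mutator, file):
    # Illustrative only: dump the mutator's final decisions to JSON in the
    # same format apply_fixed_architecture expects (bool/float lists per key).
    class _TensorEncoder(json.JSONEncoder):
        def default(self, o):
            return o.tolist() if isinstance(o, torch.Tensor) else super().default(o)
    with open(file, "w") as f:
        json.dump(mutator.export(), f, indent=2, sort_keys=True, cls=_TensorEncoder)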
from .mutator import DartsMutator
from .trainer import DartsTrainer
\ No newline at end of file
from .scope import DartsNode
\ No newline at end of file
......@@ -2,35 +2,47 @@ import torch
from torch import nn as nn
from torch.nn import functional as F
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator
from .scope import DartsNode
from nni.nas.pytorch.mutables import LayerChoice, InputChoice
class DartsMutator(Mutator):
def after_parse_search_space(self):
def __init__(self, model):
super().__init__(model)
self.choices = nn.ParameterDict()
for _, mutable in self.named_mutables():
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1))
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length + 1))
def on_calc_layer_choice_mask(self, mutable: LayerChoice):
return F.softmax(self.choices[mutable.key], dim=-1)[:-1]
def device(self):
for v in self.choices.values():
return v.device
def export(self):
result = super().export()
for _, darts_node in self.named_mutables():
if isinstance(darts_node, DartsNode):
keys, edges_max = [], [] # key of all the layer choices in current node, and their best edge weight
for _, choice in self.named_mutables(darts_node):
if isinstance(choice, LayerChoice):
keys.append(choice.key)
max_val, index = torch.max(result[choice.key], 0)
edges_max.append(max_val)
result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool()
_, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation) # pylint: disable=not-callable
for i, key in enumerate(keys):
if i not in topk_edge_indices:
result[key] = torch.zeros_like(result[key])
def sample_search(self):
result = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
result[mutable.key] = F.softmax(self.choices[mutable.key], dim=-1)[:-1]
elif isinstance(mutable, InputChoice):
result[mutable.key] = torch.ones(mutable.n_candidates, dtype=torch.bool, device=self.device())
return result
def sample_final(self):
result = dict()
edges_max = dict()
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
max_val, index = torch.max(F.softmax(self.choices[mutable.key], dim=-1)[:-1], 0)
edges_max[mutable.key] = max_val
result[mutable.key] = F.one_hot(index, num_classes=mutable.length).view(-1).bool()
for mutable in self.mutables:
if isinstance(mutable, InputChoice):
weights = torch.tensor([edges_max.get(src_key, 0.) for src_key in mutable.choose_from]) # pylint: disable=not-callable
_, topk_edge_indices = torch.topk(weights, mutable.n_chosen or mutable.n_candidates)
selected_multihot = []
for i, src_key in enumerate(mutable.choose_from):
if i not in topk_edge_indices and src_key in result:
result[src_key] = torch.zeros_like(result[src_key]) # clear this choice to optimize calc graph
selected_multihot.append(i in topk_edge_indices)
result[mutable.key] = torch.tensor(selected_multihot, dtype=torch.bool, device=self.device()) # pylint: disable=not-callable
return result
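The edge-pruning step in sample_final keeps, for each input switch, the n_chosen incoming edges whose best op weight is largest, and zeroes out the layer-choice masks of the pruned edges. A standalone illustration of that top-k selection with made-up weights:

import torch

# Best op weight per incoming edge of one node (hypothetical values).
edges_max = {"normal_n4_p0": 0.31, "normal_n4_p1": 0.22,
             "normal_n4_p2": 0.45, "normal_n4_p3": 0.18}
choose_from = list(edges_max)
weights = torch.tensor([edges_max[k] for k in choose_from])
_, topk_edge_indices = torch.topk(weights, 2)        # n_chosen = 2
selected_multihot = [i in topk_edge_indices for i in range(len(choose_from))]
print(selected_multihot)  # [True, False, True, False] -> keep p0 and p2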
from nni.nas.pytorch.mutables import MutableScope
class DartsNode(MutableScope):
"""
At most `limitation` choices are activated in a `DartsNode` when exporting.
"""
def __init__(self, key, limitation):
super().__init__(key)
self.limitation = limitation
......@@ -4,7 +4,7 @@ import torch
from torch import nn as nn
from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import DartsMutator
......@@ -13,9 +13,9 @@ class DartsTrainer(Trainer):
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
callbacks=None):
super().__init__(model, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
mutator if mutator is not None else DartsMutator(model), callbacks)
super().__init__(model, mutator if mutator is not None else DartsMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
weight_decay=1.0E-3)
n_train = len(self.dataset_train)
......@@ -31,6 +31,9 @@ class DartsTrainer(Trainer):
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_one_epoch(self, epoch):
self.model.train()
......@@ -47,7 +50,7 @@ class DartsTrainer(Trainer):
# phase 1. child network step
self.optimizer.zero_grad()
with self.mutator.forward_pass():
self.mutator.reset()
logits = self.model(trn_X)
loss = self.loss(logits, trn_y)
loss.backward()
......@@ -76,9 +79,9 @@ class DartsTrainer(Trainer):
self.mutator.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (X, y) in enumerate(self.valid_loader):
self.mutator.reset()
for step, (X, y) in enumerate(self.test_loader):
X, y = X.to(self.device), y.to(self.device)
with self.mutator.forward_pass():
logits = self.model(X)
metrics = self.metrics(logits, y)
meters.update(metrics)
......@@ -93,7 +96,7 @@ class DartsTrainer(Trainer):
v_model: backup model before this step
lr: learning rate for virtual gradient step (same as net lr)
"""
with self.mutator.forward_pass():
self.mutator.reset()
loss = self.loss(self.model(val_X), val_y)
w_model = tuple(self.model.parameters())
w_ctrl = tuple(self.mutator.parameters())
......@@ -125,7 +128,7 @@ class DartsTrainer(Trainer):
for p, d in zip(self.model.parameters(), dw):
p += eps * d
with self.mutator.forward_pass():
self.mutator.reset()
loss = self.loss(self.model(trn_X), trn_y)
if e > 0:
dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) }
......
......@@ -2,8 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator
from nni.nas.pytorch.mutables import LayerChoice, InputChoice, MutableScope
class StackedLSTMCell(nn.Module):
......@@ -27,15 +27,14 @@ class StackedLSTMCell(nn.Module):
class EnasMutator(Mutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
skip_target=0.4, branch_bias=0.25):
super().__init__(model)
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.cell_exit_extra_step = cell_exit_extra_step
self.skip_target = skip_target
self.branch_bias = branch_bias
super().__init__(model)
def before_parse_search_space(self):
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
......@@ -45,9 +44,8 @@ class EnasMutator(Mutator):
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
self.bias_dict = nn.ParameterDict()
def after_parse_search_space(self):
self.max_layer_choice = 0
for _, mutable in self.named_mutables():
for mutable in self.mutables:
if isinstance(mutable, LayerChoice):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
......@@ -64,8 +62,29 @@ class EnasMutator(Mutator):
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)
def before_pass(self):
super().before_pass()
def sample_search(self):
self._initialize()
self._sample(self.mutables)
return self._choices
def sample_final(self):
return self.sample_search()
def _sample(self, tree):
mutable = tree.mutable
if isinstance(mutable, LayerChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_layer_choice(mutable)
elif isinstance(mutable, InputChoice) and mutable.key not in self._choices:
self._choices[mutable.key] = self._sample_input_choice(mutable)
for child in tree.children:
self._sample(child)
if isinstance(mutable, MutableScope) and mutable.key not in self._anchors_hid:
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable.key)
def _initialize(self):
self._choices = dict()
self._anchors_hid = dict()
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
......@@ -84,7 +103,7 @@ class EnasMutator(Mutator):
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def on_calc_layer_choice_mask(self, mutable):
def _sample_layer_choice(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
......@@ -94,14 +113,14 @@ class EnasMutator(Mutator):
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach()
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
self._inputs = self.embedding(branch_id)
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def on_calc_input_choice_mask(self, mutable, tags):
def _sample_input_choice(self, mutable):
query, anchors = [], []
for label in tags:
for label in mutable.choose_from:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label) # empty loop, fill not found
......@@ -113,8 +132,8 @@ class EnasMutator(Mutator):
if self.tanh_constant is not None:
query = self.tanh_constant * torch.tanh(query)
if mutable.n_selected is None:
logit = torch.cat([-query, query], 1)
if mutable.n_chosen is None:
logit = torch.cat([-query, query], 1) # pylint: disable=invalid-unary-operand-type
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
......@@ -123,19 +142,14 @@ class EnasMutator(Mutator):
log_prob = self.cross_entropy_loss(logit, skip)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
else:
assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS."
assert mutable.n_chosen == 1, "Input choice must select exactly one or any in ENAS."
logit = query.view(1, -1)
index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip = F.one_hot(index).view(-1)
skip = F.one_hot(index, num_classes=mutable.n_candidates).view(-1)
log_prob = self.cross_entropy_loss(logit, index)
self._inputs = anchors[index.item()]
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach()
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += torch.sum(entropy)
return skip.bool()
def exit_mutable_scope(self, mutable_scope):
if self.cell_exit_extra_step:
self._lstm_next_step()
self._mark_anchor(mutable_scope.key)
......@@ -2,7 +2,7 @@ import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup
from nni.nas.pytorch.utils import AverageMeterGroup
from .mutator import EnasMutator
......@@ -12,9 +12,9 @@ class EnasTrainer(Trainer):
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
super().__init__(model, loss, metrics, optimizer, num_epochs,
dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
mutator if mutator is not None else EnasMutator(model), callbacks)
super().__init__(model, mutator if mutator is not None else EnasMutator(model),
loss, metrics, optimizer, num_epochs, dataset_train, dataset_valid,
batch_size, workers, device, log_frequency, callbacks)
self.reward_function = reward_function
self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
......@@ -52,7 +52,8 @@ class EnasTrainer(Trainer):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
with self.mutator.forward_pass():
with torch.no_grad():
self.mutator.reset()
logits = self.model(x)
if isinstance(logits, tuple):
......@@ -81,7 +82,8 @@ class EnasTrainer(Trainer):
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
with self.mutator.forward_pass():
self.mutator.reset()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
......@@ -107,7 +109,7 @@ class EnasTrainer(Trainer):
self.mutator_optim.zero_grad()
if self.log_frequency is not None and step % self.log_frequency == 0:
print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs,
print("RL Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs,
mutator_step // self.mutator_steps_aggregate,
self.mutator_steps, meters))
mutator_step += 1
......
......@@ -2,10 +2,12 @@ import json
import torch
from nni.nas.pytorch.mutables import MutableScope
from nni.nas.pytorch.mutator import Mutator
class FixedArchitecture(Mutator):
def __init__(self, model, fixed_arc, strict=True):
"""
Initialize a fixed architecture mutator.
......@@ -20,39 +22,57 @@ class FixedArchitecture(Mutator):
Force everything that appears in `fixed_arc` to be used at least once.
"""
super().__init__(model)
if isinstance(fixed_arc, str):
with open(fixed_arc, "r") as f:
fixed_arc = json.load(f.read())
self._fixed_arc = fixed_arc
self._strict = strict
def _encode_tensor(self, data):
mutable_keys = set([mutable.key for mutable in self.mutables if not isinstance(mutable, MutableScope)])
fixed_arc_keys = set(self._fixed_arc.keys())
if fixed_arc_keys - mutable_keys:
raise RuntimeError("Unexpected keys found in fixed architecture: {}.".format(fixed_arc_keys - mutable_keys))
if mutable_keys - fixed_arc_keys:
raise RuntimeError("Missing keys in fixed architecture: {}.".format(mutable_keys - fixed_arc_keys))
def sample_search(self):
return self._fixed_arc
def sample_final(self):
return self._fixed_arc
def _encode_tensor(data, device):
if isinstance(data, list):
if all(map(lambda o: isinstance(o, bool), data)):
return torch.tensor(data, dtype=torch.bool) # pylint: disable=not-callable
return torch.tensor(data, dtype=torch.bool, device=device) # pylint: disable=not-callable
else:
return torch.tensor(data, dtype=torch.float) # pylint: disable=not-callable
return torch.tensor(data, dtype=torch.float, device=device) # pylint: disable=not-callable
if isinstance(data, dict):
return {k: self._encode_tensor(v) for k, v in data.items()}
return {k: _encode_tensor(v, device) for k, v in data.items()}
return data
def before_pass(self):
self._unused_key = set(self._fixed_arc.keys())
def after_pass(self):
if self._strict:
if self._unused_key:
raise ValueError("{} are never used by the network. "
"Set strict=False if you want to disable this check.".format(self._unused_key))
def apply_fixed_architecture(model, fixed_arc_path, device=None):
"""
Load architecture from `fixed_arc_path` and apply to model.
def _check_key(self, key):
if key not in self._fixed_arc:
raise ValueError("\"{}\" is demanded by the network, but not found in saved architecture.".format(key))
Parameters
----------
model: torch.nn.Module
Model with mutables.
fixed_arc_path: str
Path to the JSON that stores the architecture.
device: torch.device
Architecture weights will be transferred to `device`.
def on_calc_layer_choice_mask(self, mutable):
self._check_key(mutable.key)
return self._fixed_arc[mutable.key]
Returns
-------
FixedArchitecture
"""
def on_calc_input_choice_mask(self, mutable, tags):
self._check_key(mutable.key)
return self._fixed_arc[mutable.key]
if device is None:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if isinstance(fixed_arc_path, str):
with open(fixed_arc_path, "r") as f:
fixed_arc = json.load(f)
fixed_arc = _encode_tensor(fixed_arc, device)
architecture = FixedArchitecture(model, fixed_arc)
architecture.to(device)
architecture.reset()
return architecture
import torch.nn as nn
from nni.nas.utils import global_mutable_counting
from nni.nas.pytorch.utils import global_mutable_counting
class Mutable(nn.Module):
......@@ -37,7 +37,7 @@ class Mutable(nn.Module):
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
raise NotImplementedError("Mutable forward must be implemented.")
raise NotImplementedError
@property
def key(self):
......@@ -51,9 +51,6 @@ class Mutable(nn.Module):
def name(self, name):
self._name = name
def similar(self, other):
return type(self) == type(other)
def _check_built(self):
if not hasattr(self, "mutator"):
raise ValueError(
......@@ -66,19 +63,17 @@ class Mutable(nn.Module):
class MutableScope(Mutable):
"""
Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
corresponding events, and do status dump or update.
Mutable scope labels a subgraph/submodule to help mutators make better decisions.
Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
"""
def __init__(self, key):
super().__init__(key=key)
def build(self):
self.mutator.on_init_mutable_scope(self)
def __call__(self, *args, **kwargs):
try:
self._check_built()
self.mutator.enter_mutable_scope(self)
return super().__call__(*args, **kwargs)
finally:
......@@ -93,43 +88,92 @@ class LayerChoice(Mutable):
self.reduction = reduction
self.return_mask = return_mask
def __len__(self):
return len(self.choices)
def forward(self, *inputs):
out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
if self.return_mask:
return out, mask
return out
def similar(self, other):
return type(self) == type(other) and self.length == other.length
class InputChoice(Mutable):
def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
"""
Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners,
using `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
know about `choose_from`.
The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
The keys are designed to be the keys of the sources. To help mutators make better decisions,
mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
"""
NO_KEY = ""
def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
reduction="mean", return_mask=False, key=None):
"""
Initialization.
Parameters
----------
n_candidates: int
Number of inputs to choose from.
choose_from: list of str
List of source keys to choose from. At least one of `choose_from` and `n_candidates` must be provided.
If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
empty strings (i.e. ``NO_KEY``).
n_chosen: int
Number of inputs to choose. If None, the mutator is instructed to select any number of them.
reduction: str
`mean`, `concat`, `sum` or `none`.
return_mask: bool
If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
key: str
Key of the input choice.
"""
super().__init__(key=key)
# precondition check
assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from` " \
"must not be None."
if choose_from is not None and n_candidates is None:
n_candidates = len(choose_from)
elif choose_from is None and n_candidates is not None:
choose_from = [self.NO_KEY] * n_candidates
assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
assert n_candidates > 0, "Number of candidates must be greater than 0."
assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected number of chosen inputs must be None or no more " \
"than the number of candidates."
self.n_candidates = n_candidates
self.n_selected = n_selected
self.choose_from = choose_from
self.n_chosen = n_chosen
self.reduction = reduction
self.return_mask = return_mask
def build(self):
self.mutator.on_init_input_choice(self)
def forward(self, optional_inputs, tags=None):
def forward(self, optional_inputs):
"""
Forward method of InputChoice.
Parameters
----------
optional_inputs: list or dict
Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
`choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
`choose_from`.
Returns
-------
tuple of torch.Tensor and torch.Tensor or torch.Tensor
"""
optional_input_list = optional_inputs
if isinstance(optional_inputs, dict):
optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
assert isinstance(optional_input_list, list), "Optional input list must be a list"
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
if tags is None:
tags = [""] * self.n_candidates
else:
assert len(tags) == self.n_candidates, "Length of tags must be equal to number of candidates."
out, mask = self.mutator.on_forward_input_choice(self, optional_inputs, tags)
out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
if self.return_mask:
return out, mask
return out
def similar(self, other):
return type(self) == type(other) and \
self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
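To make the choose_from semantics above concrete, here is a hedged usage sketch (the module and keys are made up, and a mutator such as the one created by apply_fixed_architecture must be applied before forward works):

import torch.nn as nn
from nni.nas.pytorch import mutables

class ToyBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # two candidate branches, each a LayerChoice with an explicit key
        self.branch_a = mutables.LayerChoice([nn.Conv2d(channels, channels, 3, padding=1),
                                              nn.MaxPool2d(3, stride=1, padding=1)], key="toy_a")
        self.branch_b = mutables.LayerChoice([nn.Conv2d(channels, channels, 5, padding=2),
                                              nn.AvgPool2d(3, stride=1, padding=1)], key="toy_b")
        # choose_from lists the keys of the mutables that produce the candidate inputs
        self.switch = mutables.InputChoice(choose_from=["toy_a", "toy_b"], n_chosen=1, key="toy_switch")

    def forward(self, x):
        out_a, out_b = self.branch_a(x), self.branch_b(x)
        # passing a dict keyed like choose_from lets InputChoice order the inputs itself
        return self.switch({"toy_a": out_a, "toy_b": out_b})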
from contextlib import contextmanager
import torch
import torch.nn as nn
from nni.nas.pytorch.base_mutator import BaseMutator
class Mutator(BaseMutator, nn.Module):
class Mutator(BaseMutator):
def export(self):
if self._in_forward_pass:
raise RuntimeError("Still in forward pass. Exporting might induce incompleteness.")
if not self._cache:
raise RuntimeError("No running history found. You need to call your model at least once before exporting. "
"You might also want to check if there are no valid mutables in your model.")
return self._cache
@contextmanager
def forward_pass(self):
self._in_forward_pass = True
def __init__(self, model):
super().__init__(model)
self._cache = dict()
self.before_pass()
try:
yield self
finally:
self.after_pass()
self._in_forward_pass = False
def before_pass(self):
pass
def sample_search(self):
"""
Override to implement this method to iterate over mutables and make decisions.
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def sample_final(self):
"""
Override to implement this method to iterate over mutables and make decisions that are final
for export and retraining.
def after_pass(self):
pass
Returns
-------
dict
A mapping from key of mutables to decisions.
"""
raise NotImplementedError
def _check_in_forward_pass(self):
if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
"super().before_pass() and after_pass() in your override method?")
def reset(self):
"""
Reset the mutator by calling `sample_search` to resample (for search).
Returns
-------
None
"""
self._cache = self.sample_search()
def export(self):
"""
Resample (for final) and return results.
Returns
-------
dict
"""
return self.sample_final()
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask on how to choose between layers
(either by switch or by weights), then it will reduce the list of all tensor outputs with the policy specified
in `mutable.reduction`. It will also cache the mask with corresponding `mutable.key`.
......@@ -54,18 +67,17 @@ class Mutator(BaseMutator, nn.Module):
-------
tuple of torch.Tensor and torch.Tensor
"""
self._check_in_forward_pass()
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
mask = self._get_decision(mutable)
assert len(mask) == len(mutable.choices)
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list, tags):
def on_forward_input_choice(self, mutable, tensor_list):
"""
Callback of input choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_input_choice_mask` with `tags`
to get a mask on how to choose between inputs (either by switch or by weights), then it will reduce
the list of all tensor outputs with the policy specified in `mutable.reduction`. It will also cache the
......@@ -81,48 +93,11 @@ class Mutator(BaseMutator, nn.Module):
-------
tuple of torch.Tensor and torch.Tensor
"""
self._check_in_forward_pass()
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, tags))
mask = self._get_decision(mutable)
assert len(mask) == mutable.n_candidates
out = self._select_with_mask(lambda x: x, [(t,) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_calc_layer_choice_mask(self, mutable):
"""
Recommended to override. Calculate a mask tensor for a layer choice.
Parameters
----------
mutable: LayerChoice
Corresponding layer choice object.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switch.
"""
raise NotImplementedError("Layer choice mask calculation must be implemented")
def on_calc_input_choice_mask(self, mutable, tags):
"""
Recommended to override. Calculate a mask tensor for an input choice.
Parameters
----------
mutable: InputChoice
Corresponding input choice object.
tags: list of string
The labels of the input tensors given by the user. Usually each is the key of a
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by the user.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switch.
"""
raise NotImplementedError("Input choice mask calculation must be implemented")
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
......@@ -146,3 +121,20 @@ class Mutator(BaseMutator, nn.Module):
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
def _get_decision(self, mutable):
"""
By default, this method checks whether `mutable.key` is already in the decision cache,
and returns the cached result without double-checking.
Parameters
----------
mutable: Mutable
Returns
-------
any
"""
if mutable.key not in self._cache:
raise ValueError("\"{}\" not found in decision cache.".format(mutable.key))
return self._cache[mutable.key]
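For reference, the cached decisions are interpreted by _select_with_mask: a bool mask acts as a switch over candidates, while a float mask weights every candidate before reduction, as the removed docstrings above noted. A standalone illustration with made-up values:

import torch

candidates = [torch.tensor([1.0]), torch.tensor([2.0]), torch.tensor([4.0])]

# Bool mask: a switch, only selected candidates survive.
bool_mask = torch.tensor([False, True, True])
switched = [c for c, m in zip(candidates, bool_mask) if m]

# Float mask: weights applied to every candidate.
float_mask = torch.tensor([0.2, 0.3, 0.5])
weighted = [c * w for c, w in zip(candidates, float_mask)]

# With the "sum" reduction these give 6.0 and 2.8 respectively.
print(sum(switched).item(), sum(weighted).item())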
......@@ -11,14 +11,14 @@ from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator):
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
def __init__(self, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
self.pdarts_epoch_index = pdarts_epoch_index
self.pdarts_num_to_drop = pdarts_num_to_drop
self.switches = switches
super(PdartsMutator, self).__init__(model)
super(PdartsMutator, self).__init__()
def before_build(self, model):
def before_build(self):
self.choices = nn.ParameterDict()
if self.switches is None:
self.switches = {}
......