Unverified commit bb797e10 authored by Yuge Zhang, committed by GitHub

Update APIs and add preliminary support for ENAS macro space (#1714)

* add enas macro

* refactor example directory structure

* update docstring
parent e238d34a
from argparse import ArgumentParser
import datasets
import image_ops as ops
import nni.nas.pytorch as nas
import torch
import torch.nn as nn
from nni.nas.pytorch.darts import DartsTrainer
import ops
from nni.nas import pytorch as nas
class SearchCell(nn.Module):
@@ -142,57 +139,3 @@ class SearchCNN(nn.Module):
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
model_optim=optim,
lr_scheduler=lr_scheduler,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.finalize()
# augment step
# ...
from argparse import ArgumentParser
import datasets
import torch
import torch.nn as nn
from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
model_optim=optim,
lr_scheduler=lr_scheduler,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.export()
# augment step
# ...
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
\ No newline at end of file
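A quick sanity check of the accuracy helper above, with made-up tensors (batch of 8, 10 classes; the exact numbers will vary):
import torch
logits = torch.randn(8, 10)            # hypothetical model outputs
labels = torch.randint(0, 10, (8,))    # integer class targets
print(accuracy(logits, labels, topk=(1, 5)))  # e.g. {'acc1': 0.125, 'acc5': 0.625}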
from torchvision import transforms
from torchvision.datasets import CIFAR10
def get_dataset(cls):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]
train_transform = transforms.Compose(transf + normalize)
valid_transform = transforms.Compose(normalize)
if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
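The loader above returns the standard torchvision CIFAR-10 splits with channel-wise normalization; a minimal sketch of what a caller sees (data is downloaded to ./data as in the code):
dataset_train, dataset_valid = get_dataset("cifar10")
print(len(dataset_train), len(dataset_valid))   # 50000 10000
image, label = dataset_train[0]
print(image.shape)                              # torch.Size([3, 32, 32]), already normalized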
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
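FactorizedReduce halves the spatial resolution by concatenating two offset stride-2 1x1 convolutions along the channel axis; a small shape check with illustrative sizes:
import torch
fr = FactorizedReduce(16, 16)
x = torch.randn(2, 16, 32, 32)
print(fr(x).shape)   # torch.Size([2, 16, 16, 16]): channels preserved, H and W halved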
from argparse import ArgumentParser
import torch
import torch.nn as nn
import datasets
from ops import FactorizedReduce, ConvBranch, PoolBranch
from nni.nas.pytorch import mutables, enas
class ENASLayer(nn.Module):
def __init__(self, layer_id, in_filters, out_filters):
super().__init__()
self.in_filters = in_filters
self.out_filters = out_filters
self.mutable = mutables.LayerChoice([
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if layer_id > 0:
self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id))
def forward(self, prev_layers):
with self.mutable_scope:
out = self.mutable(prev_layers[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1],
["layer_{}".format(i) for i in range(len(prev_layers) - 1)])
if connection is not None:
out += connection
return self.batch_norm(out)
class GeneralNetwork(nn.Module):
def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
dropout_rate=0.0):
super().__init__()
self.num_layers = num_layers
self.num_classes = num_classes
self.out_filters = out_filters
self.stem = nn.Sequential(
nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_filters)
)
pool_distance = self.num_layers // 3
self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
self.dropout_rate = dropout_rate
self.dropout = nn.Dropout(self.dropout_rate)
self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList()
for layer_id in range(self.num_layers):
if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
def forward(self, x):
bs = x.size(0)
cur = self.stem(x)
layers = [cur]
for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers)
layers.append(cur)
if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
cur = layers[-1]
cur = self.gap(cur).view(bs, -1)
cur = self.dropout(cur)
logits = self.dense(cur)
return logits
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
def reward_accuracy(output, target, topk=(1,)):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = GeneralNetwork()
criterion = nn.CrossEntropyLoss()
n_epochs = 310
optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optim,
lr_scheduler=lr_scheduler,
batch_size=args.batch_size,
num_epochs=n_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
trainer.train()
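Outside the trainer, a single architecture can be sampled and run by hand, following the same pattern EnasTrainer uses internally (a rough sketch; it assumes, as the trainer does, that the mutator wires itself to the model's mutables at construction):
import torch
model = GeneralNetwork(num_layers=3)
mutator = enas.EnasMutator(model)          # controller attaches to the model's LayerChoice/InputChoice
with mutator.forward_pass():               # one pass = one sampled architecture
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)                        # torch.Size([2, 10])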
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
@@ -14,6 +14,5 @@ class DartsMutator(PyTorchMutator):
def on_init_layer_choice(self, mutable: LayerChoice):
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))
def on_forward_layer_choice(self, mutable: LayerChoice, ops, *inputs):
weights = F.softmax(self.choices[mutable.key], dim=-1)
return sum(w * op(*inputs) for w, op in zip(weights, ops))
def on_calc_layer_choice_mask(self, mutable: LayerChoice):
return F.softmax(self.choices[mutable.key], dim=-1)
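The DARTS relaxation above replaces the discrete layer choice with a softmax-weighted mixture of all candidate outputs; in isolation, with made-up numbers:
import torch
import torch.nn.functional as F
alpha = torch.tensor([0.1, -0.2, 0.3])                    # architecture parameters of one LayerChoice
weights = F.softmax(alpha, dim=-1)                        # ≈ tensor([0.338, 0.250, 0.412])
outs = [torch.full((1, 4), float(i)) for i in range(3)]   # stand-ins for the three candidate op outputs
mixed = sum(w * o for w, o in zip(weights, outs))         # what on_forward_layer_choice returns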
@@ -61,6 +61,7 @@ class DartsTrainer(Trainer):
# phase 1. child network step
self.model_optim.zero_grad()
with self.mutator.forward_pass():
logits = self.model(trn_X)
loss = self.loss(logits, trn_y)
loss.backward()
@@ -117,6 +118,7 @@ class DartsTrainer(Trainer):
v_model: backup model before this step
lr: learning rate for virtual gradient step (same as net lr)
"""
with self.mutator.forward_pass():
loss = self.loss(self.model(val_X), val_y)
w_model = tuple(self.model.parameters())
w_ctrl = tuple(self.mutator.parameters())
@@ -148,7 +150,8 @@ class DartsTrainer(Trainer):
for p, d in zip(self.model.parameters(), dw):
p += eps * d
loss = self.loss(self.model(trn_X), trn_y) # TODO: should use model instead of self.model
with self.mutator.forward_pass():
loss = self.loss(self.model(trn_X), trn_y)
if e > 0:
dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters()) # dalpha { L_trn(w+) }
elif e < 0:
@@ -157,5 +160,5 @@ class DartsTrainer(Trainer):
hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
def finalize(self):
def export(self):
pass
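The forward_pass() context added throughout the trainer gives the mutator a fresh decision cache per forward, so the child-weight step and the architecture step each see one consistent set of choices; a sketch of the pattern (model, mutator, criterion and the data tensors are hypothetical stand-ins):
with mutator.forward_pass():
    loss_w = criterion(model(trn_X), trn_y)        # phase 1: child network step
    loss_w.backward()
with mutator.forward_pass():
    loss_alpha = criterion(model(val_X), val_y)    # phase 2: architecture step, new cache
    loss_alpha.backward()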
from .mutator import EnasMutator
from .trainer import EnasTrainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import PyTorchMutator
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
return next_c, next_h
class EnasMutator(PyTorchMutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, anchor_extra_step=False,
skip_target=0.4):
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.max_layer_choice = 0
self.anchor_extra_step = anchor_extra_step
self.skip_target = skip_target
super().__init__(model)
def before_build(self, model):
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
self.cross_entropy_loss = nn.CrossEntropyLoss()
def after_build(self, model):
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice)
def before_pass(self):
super().before_pass()
self._anchors_hid = dict()
self._selected_layers = []
self._selected_inputs = []
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def on_init_layer_choice(self, mutable):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choices to have the same number of candidates."
def on_calc_layer_choice_mask(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += log_prob
entropy = (log_prob * torch.exp(-log_prob)).detach()
self.sample_entropy += entropy
self._inputs = self.embedding(branch_id)
self._selected_layers.append(branch_id.item())
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)  # fixed-length one-hot switch mask
def on_calc_input_choice_mask(self, mutable, semantic_labels):
if mutable.n_selected is None:
query, anchors = [], []
for label in semantic_labels:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label)  # anchor not recorded yet: advance the LSTM one step and use its state as the anchor
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
logit = torch.cat([-query, query], 1)
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach()
self.sample_entropy += torch.sum(entropy)
self._inputs = torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))  # averaged anchors become the controller's next input
self._selected_inputs.append(skip)
return skip.bool()
else:
assert mutable.n_selected == 1, "ENAS input choice must select either exactly one input or any number of inputs (n_selected=None)."
raise NotImplementedError
def exit_mutable_scope(self, mutable_scope):
self._mark_anchor(mutable_scope.key)
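The skip-connection head above folds a single attention score into a two-way logit via torch.cat([-query, query], 1), so the softmax over the pair equals a sigmoid on twice the score; a quick numeric check with a made-up score:
import torch
import torch.nn.functional as F
score = torch.tensor([[0.7]])               # hypothetical v_attn output for one anchor
logit = torch.cat([-score, score], 1)       # two-class logit: [do not connect, connect]
print(F.softmax(logit, dim=-1))             # ≈ tensor([[0.1978, 0.8022]])
print(torch.sigmoid(2 * score))             # same connect probability, ≈ 0.8022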
import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup, auto_device
from .mutator import EnasMutator
class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid, lr_scheduler=None,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035):
self.model = model
self.loss = loss
self.metrics = metrics
self.reward_function = reward_function
self.mutator = mutator
if self.mutator is None:
self.mutator = EnasMutator(model)
self.optim = optimizer
self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.lr_scheduler = lr_scheduler
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.device = auto_device() if device is None else device
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.model.to(self.device)
self.loss.to(self.device)
self.mutator.to(self.device)
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_epoch(self, epoch):
self.model.train()
self.mutator.train()
for phase in ["model", "mutator"]:
if phase == "model":
self.model.train()
self.mutator.eval()
else:
self.model.eval()
self.mutator.train()
loader = self.train_loader if phase == "model" else self.valid_loader
meters = AverageMeterGroup()
for step, (x, y) in enumerate(loader):
x, y = x.to(self.device), y.to(self.device)
self.optim.zero_grad()
self.mut_optim.zero_grad()
with self.mutator.forward_pass():
logits = self.model(x)
metrics = self.metrics(logits, y)
if phase == "model":
loss = self.loss(logits, y)
loss.backward()
self.optim.step()
else:
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
loss.backward()
self.mut_optim.step()
metrics["reward"] = reward
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(), step,
len(loader), meters))
# print(self.mutator._selected_layers)
# print(self.mutator._selected_inputs)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
def validate_epoch(self, epoch):
pass
def train(self):
for epoch in range(self.num_epochs):
# training
print("Epoch {} Training".format(epoch))
self.train_epoch(epoch)
# validation
print("Epoch {} Validating".format(epoch))
self.validate_epoch(epoch)
def export(self):
pass
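With baseline_decay=0.999, the moving-average baseline above tracks the reward very slowly, which keeps the REINFORCE advantage reward - baseline informative early in training; a tiny worked example with invented rewards:
baseline, decay = 0.0, 0.999
for reward in [0.10, 0.12, 0.15]:                 # rewards from three sampled architectures
    baseline = baseline * decay + reward * (1 - decay)
print(round(baseline, 5))                         # 0.00037, still far below the raw rewards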
@@ -18,51 +18,101 @@ class PyTorchMutable(nn.Module):
def __init__(self, key=None):
super().__init__()
if key is not None:
self.key = key
if not isinstance(key, str):
key = str(key)
print("Warning: key \"{}\" is not string, converted to string.".format(key))
self._key = key
else:
self.key = self.__class__.__name__ + str(global_mutable_counting())
self._key = self.__class__.__name__ + str(global_mutable_counting())
self.name = self.key
def __deepcopy__(self, memodict=None):
raise NotImplementedError
raise NotImplementedError("Deep copy doesn't work for mutables.")
def __enter__(self):
self._check_built()
return super().__enter__()
def __call__(self, *args, **kwargs):
self._check_built()
return super().__call__(*args, **kwargs)
def set_mutator(self, mutator):
self.__dict__["mutator"] = mutator
def forward(self, *inputs):
raise NotImplementedError("Mutable forward must be implemented")
raise NotImplementedError("Mutable forward must be implemented.")
def __repr__(self):
return "{} ({})".format(self.name, self.key)
@property
def key(self):
return self._key
def similar(self, other):
return self == other
def _check_built(self):
if not hasattr(self, "mutator"):
raise ValueError(
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(PyTorchMutable):
"""
A mutable scope labels a subgraph to help mutators make better decisions. Mutators are notified whenever a mutable
scope is entered or exited, and can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch the
corresponding events and dump or update their state.
"""
def __init__(self, key):
super().__init__(key=key)
def __enter__(self):
self.mutator.enter_mutable_scope(self)
def __exit__(self, exc_type, exc_val, exc_tb):
self.mutator.exit_mutable_scope(self)
class LayerChoice(PyTorchMutable):
def __init__(self, ops, key=None):
def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
super().__init__(key=key)
self.length = len(ops)
self.choices = nn.ModuleList(ops)
self.length = len(op_candidates)
self.choices = nn.ModuleList(op_candidates)
self.reduction = reduction
self.return_mask = return_mask
def forward(self, *inputs):
return self.mutator.on_forward(self, self.choices, *inputs)
out, mask = self.mutator.on_forward(self, *inputs)
if self.return_mask:
return out, mask
return out
def similar(self, other):
return type(self) == type(other) and self.length == other.length
class InputChoice(PyTorchMutable):
def __init__(self, n_candidates, n_selected=None, reduction="mean", return_index=False, key=None):
def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
super().__init__(key=key)
assert n_candidates > 0, "Number of candidates must be greater than 0."
self.n_candidates = n_candidates
self.n_selected = n_selected
self.reduction = reduction
self.return_index = return_index
def forward(self, *inputs):
assert len(inputs) == self.n_candidates, "Length of the input list must be equal to number of candidates."
return self.mutator.on_forward(self, *inputs)
self.return_mask = return_mask
def forward(self, optional_inputs, semantic_labels=None):
assert len(optional_inputs) == self.n_candidates, \
"Length of the input list must be equal to number of candidates."
if semantic_labels is None:
semantic_labels = ["default_label"] * self.n_candidates
out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
if self.return_mask:
return out, mask
return out
def similar(self, other):
return type(self) == type(other) and \
......
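A sketch of how a model might declare choices under the updated API above (the module and channel sizes are made up; only the LayerChoice/InputChoice signatures come from the diff):
import torch.nn as nn
from nni.nas.pytorch import mutables
class ToyLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = mutables.LayerChoice([nn.Conv2d(16, 16, 3, padding=1),
                                        nn.MaxPool2d(3, stride=1, padding=1)],
                                       return_mask=True)                   # also return the decision mask
        self.skip = mutables.InputChoice(n_candidates=2, n_selected=None, reduction="sum")
    def forward(self, x, prev1, prev2):
        out, mask = self.op(x)                                             # return_mask=True -> (tensor, mask)
        residual = self.skip([prev1, prev2], ["prev1", "prev2"])           # semantic labels for the mutator
        return out if residual is None else out + residual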
import logging
from contextlib import contextmanager
from torch import nn as nn
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import PyTorchMutable
from nni.nas.utils import to_snake_case
@@ -28,8 +30,8 @@ class PyTorchMutator(nn.Module):
if isinstance(module, PyTorchMutable):
distinct = False
if module.key in key2module:
assert key2module[module.key].similar(module), "Mutable that share the same key must be similar " \
"to each other"
assert key2module[module.key].similar(module), \
"Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
else:
distinct = True
key2module[module.key] = module
@@ -56,11 +58,35 @@ class PyTorchMutator(nn.Module):
def on_init_general(self, mutable):
pass
def on_forward_general(self, mutable, *inputs):
raise NotImplementedError("Forward has to be implemented")
@contextmanager
def forward_pass(self):
self.before_pass()
try:
yield self
finally:
self.after_pass()
def before_pass(self):
self._in_forward_pass = True
self._cache = dict()
def after_pass(self):
self._in_forward_pass = False
def enter_mutable_scope(self, mutable_scope):
pass
def exit_mutable_scope(self, mutable_scope):
pass
def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able")
def on_forward(self, mutable, *inputs):
"""Callback on forwarding a mutable"""
if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
"super().before_pass() and after_pass() in your override method?")
forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
return getattr(self, forward_method_name)(mutable, *inputs)
@@ -68,5 +94,110 @@ class PyTorchMutator(nn.Module):
# fallback to general forward
return self.on_forward_general(mutable, *inputs)
def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able")
def on_forward_general(self, mutable, *inputs):
raise NotImplementedError("Forward has to be implemented")
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask describing how to choose between layers
(either as a boolean switch or as weights), then reduces the list of candidate outputs with the policy specified
in `mutable.reduction`. It also caches the mask under the corresponding `mutable.key`.
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
torch.Tensor
"""
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
"""
Callback of input choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
to get a mask describing how to choose between inputs (either as a boolean switch or as weights), then reduces
the list of candidate tensors with the policy specified in `mutable.reduction`. It also caches the
mask under the corresponding `mutable.key`.
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
semantic_labels: list of str
Returns
-------
torch.Tensor
"""
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_calc_layer_choice_mask(self, mutable):
"""
Recommended to override. Calculate a mask tensor for a layer choice.
Parameters
----------
mutable: LayerChoice
Corresponding layer choice object.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the entries are treated as weights; if bool,
they are treated as on/off switches.
"""
raise NotImplementedError("Layer choice mask calculation must be implemented")
def on_calc_input_choice_mask(self, mutable, semantic_labels):
"""
Recommended to override. Calculate a mask tensor for an input choice.
Parameters
----------
mutable: InputChoice
Corresponding input choice object.
semantic_labels: list of str
Labels of the input tensors, given by the user. Usually they are the keys of
:class:`~nni.nas.pytorch.mutables.MutableScope`s marked by the user.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the entries are treated as weights; if bool,
they are treated as on/off switches.
"""
raise NotImplementedError("Input choice mask calculation must be implemented")
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
# print(candidates[0], len(mask))
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
else:
raise ValueError("Unrecognized mask")
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":  # "none" means: return the raw list untouched
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
@@ -8,5 +8,5 @@ class Trainer(ABC):
raise NotImplementedError
@abstractmethod
def finalize(self):
def export(self):
raise NotImplementedError