"docs/git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "5f1366cef0b8d82269f762ada3d23a67205077b5"
Unverified commit bb797e10, authored by Yuge Zhang and committed via GitHub.

Update APIs and add preliminary support for ENAS macro space (#1714)

* Add ENAS macro space

* Refactor example directory structure

* Update docstrings
Parent commit: e238d34a
from argparse import ArgumentParser
import datasets
import image_ops as ops
import nni.nas.pytorch as nas
import torch
import torch.nn as nn
from nni.nas.pytorch.darts import DartsTrainer
import ops
from nni.nas import pytorch as nas
class SearchCell(nn.Module):
@@ -142,57 +139,3 @@ class SearchCNN(nn.Module):
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
model_optim=optim,
lr_scheduler=lr_scheduler,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.finalize()
# augment step
# ...
from argparse import ArgumentParser
import datasets
import torch
import torch.nn as nn
from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)
trainer = DartsTrainer(model,
loss=criterion,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
model_optim=optim,
lr_scheduler=lr_scheduler,
num_epochs=50,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.export()
# augment step
# ...
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
\ No newline at end of file
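As a quick sanity check for `accuracy` above, here is a minimal sketch with dummy logits and labels (hypothetical, not part of the commit); the import path is an assumption:

# --- usage sketch (hypothetical) ---
import torch
from utils import accuracy  # assumed path of the helper defined above

logits = torch.zeros(4, 10)
labels = torch.tensor([0, 1, 2, 3])
logits[torch.arange(4), torch.tensor([0, 1, 2, 9])] = 5.0  # last prediction deliberately wrong

print(accuracy(logits, labels, topk=(1,)))  # {'acc1': 0.75} -> 3 of 4 top-1 predictions correct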
from torchvision import transforms
from torchvision.datasets import CIFAR10
def get_dataset(cls):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]
train_transform = transforms.Compose(transf + normalize)
valid_transform = transforms.Compose(normalize)
if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
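A short sketch of consuming `get_dataset` directly (the trainers in this commit build their own loaders internally; the batch size and import path here are arbitrary assumptions):

# --- usage sketch (hypothetical) ---
import torch
from datasets import get_dataset  # assumed path of the module above

train_set, valid_set = get_dataset("cifar10")
loader = torch.utils.data.DataLoader(train_set, batch_size=64, shuffle=True, num_workers=2)
images, labels = next(iter(loader))
print(images.shape, labels.shape)  # torch.Size([64, 3, 32, 32]) torch.Size([64])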
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
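A quick shape check for the branches above on CIFAR-sized inputs (a sketch, not part of the example files): both branch types keep the spatial size, while `FactorizedReduce` halves it.

# --- shape-check sketch (hypothetical) ---
import torch
from ops import ConvBranch, PoolBranch, FactorizedReduce  # assumed path of the ops module above

x = torch.randn(2, 16, 32, 32)
print(ConvBranch(16, 32, 3, 1, 1, separable=True)(x).shape)  # torch.Size([2, 32, 32, 32])
print(PoolBranch('avg', 16, 32, 3, 1, 1)(x).shape)           # torch.Size([2, 32, 32, 32])
print(FactorizedReduce(16, 32)(x).shape)                     # torch.Size([2, 32, 16, 16])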
from argparse import ArgumentParser
import torch
import torch.nn as nn
import datasets
from ops import FactorizedReduce, ConvBranch, PoolBranch
from nni.nas.pytorch import mutables, enas
class ENASLayer(nn.Module):
def __init__(self, layer_id, in_filters, out_filters):
super().__init__()
self.in_filters = in_filters
self.out_filters = out_filters
self.mutable = mutables.LayerChoice([
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
PoolBranch('max', in_filters, out_filters, 3, 1, 1)
])
if layer_id > 0:
self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
else:
self.skipconnect = None
self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id))
def forward(self, prev_layers):
with self.mutable_scope:
out = self.mutable(prev_layers[-1])
if self.skipconnect is not None:
connection = self.skipconnect(prev_layers[:-1],
["layer_{}".format(i) for i in range(len(prev_layers) - 1)])
if connection is not None:
out += connection
return self.batch_norm(out)
class GeneralNetwork(nn.Module):
def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
dropout_rate=0.0):
super().__init__()
self.num_layers = num_layers
self.num_classes = num_classes
self.out_filters = out_filters
self.stem = nn.Sequential(
nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_filters)
)
pool_distance = self.num_layers // 3
self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
self.dropout_rate = dropout_rate
self.dropout = nn.Dropout(self.dropout_rate)
self.layers = nn.ModuleList()
self.pool_layers = nn.ModuleList()
for layer_id in range(self.num_layers):
if layer_id in self.pool_layers_idx:
self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters))
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(self.out_filters, self.num_classes)
def forward(self, x):
bs = x.size(0)
cur = self.stem(x)
layers = [cur]
for layer_id in range(self.num_layers):
cur = self.layers[layer_id](layers)
layers.append(cur)
if layer_id in self.pool_layers_idx:
for i, layer in enumerate(layers):
layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
cur = layers[-1]
cur = self.gap(cur).view(bs, -1)
cur = self.dropout(cur)
logits = self.dense(cur)
return logits
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
def reward_accuracy(output, target, topk=(1,)):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = GeneralNetwork()
criterion = nn.CrossEntropyLoss()
n_epochs = 310
optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optim,
lr_scheduler=lr_scheduler,
batch_size=args.batch_size,
num_epochs=n_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
trainer.train()
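For the default `num_layers=12`, the two `FactorizedReduce` layers are inserted after layers 3 and 7; a tiny sketch of the arithmetic the constructor uses:

# --- sketch: placement of the reduction layers (hypothetical) ---
num_layers = 12
pool_distance = num_layers // 3                    # 4
print([pool_distance - 1, 2 * pool_distance - 1])  # [3, 7]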
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
@@ -14,6 +14,5 @@ class DartsMutator(PyTorchMutator):
    def on_init_layer_choice(self, mutable: LayerChoice):
        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))
-    def on_forward_layer_choice(self, mutable: LayerChoice, ops, *inputs):
-        weights = F.softmax(self.choices[mutable.key], dim=-1)
-        return sum(w * op(*inputs) for w, op in zip(weights, ops))
+    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
+        return F.softmax(self.choices[mutable.key], dim=-1)
@@ -61,7 +61,8 @@ class DartsTrainer(Trainer):
        # phase 1. child network step
        self.model_optim.zero_grad()
-        logits = self.model(trn_X)
+        with self.mutator.forward_pass():
+            logits = self.model(trn_X)
        loss = self.loss(logits, trn_y)
        loss.backward()
        # gradient clipping
@@ -117,7 +118,8 @@ class DartsTrainer(Trainer):
        v_model: backup model before this step
        lr: learning rate for virtual gradient step (same as net lr)
        """
-        loss = self.loss(self.model(val_X), val_y)
+        with self.mutator.forward_pass():
+            loss = self.loss(self.model(val_X), val_y)
        w_model = tuple(self.model.parameters())
        w_ctrl = tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
@@ -148,7 +150,8 @@ class DartsTrainer(Trainer):
        for p, d in zip(self.model.parameters(), dw):
            p += eps * d
-        loss = self.loss(self.model(trn_X), trn_y)  # TODO: should use model instead of self.model
+        with self.mutator.forward_pass():
+            loss = self.loss(self.model(trn_X), trn_y)
        if e > 0:
            dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
        elif e < 0:
@@ -157,5 +160,5 @@ class DartsTrainer(Trainer):
        hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
-    def finalize(self):
+    def export(self):
        pass
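The DARTS refactor above changes the mutator contract: instead of doing the weighted sum itself, `DartsMutator` now only returns a softmax mask from `on_calc_layer_choice_mask`, and the base mutator applies it to the candidate outputs. A standalone sketch of that contract with toy ops (not NNI code):

# --- sketch of the mask-then-reduce contract (hypothetical) ---
import torch
import torch.nn as nn
import torch.nn.functional as F

ops = [nn.Identity(), nn.ReLU(), nn.Tanh()]           # toy candidate ops
alpha = nn.Parameter(1.0E-3 * torch.randn(len(ops)))  # architecture parameters
x = torch.randn(4, 8)

mask = F.softmax(alpha, dim=-1)                       # what the mutator now returns
out = sum(w * op(x) for w, op in zip(mask, ops))      # what the base class now does with it
print(out.shape)                                      # torch.Size([4, 8])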
from .mutator import EnasMutator
from .trainer import EnasTrainer
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch.mutator import PyTorchMutator
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_c, prev_h = hidden
next_c, next_h = [], []
for i, m in enumerate(self.lstm_modules):
curr_c, curr_h = m(inputs, (prev_c[i], prev_h[i]))
next_c.append(curr_c)
next_h.append(curr_h)
inputs = curr_h[-1]
return next_c, next_h
class EnasMutator(PyTorchMutator):
def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, anchor_extra_step=False,
skip_target=0.4):
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.max_layer_choice = 0
self.anchor_extra_step = anchor_extra_step
self.skip_target = skip_target
super().__init__(model)
def before_build(self, model):
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
self.cross_entropy_loss = nn.CrossEntropyLoss()
def after_build(self, model):
self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
self.soft = nn.Linear(self.lstm_size, self.max_layer_choice)
def before_pass(self):
super().before_pass()
self._anchors_hid = dict()
self._selected_layers = []
self._selected_inputs = []
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._c, self._h = self.lstm(self._inputs, (self._c, self._h))
def _mark_anchor(self, key):
self._anchors_hid[key] = self._h[-1]
def on_init_layer_choice(self, mutable):
if self.max_layer_choice == 0:
self.max_layer_choice = mutable.length
assert self.max_layer_choice == mutable.length, \
"ENAS mutator requires all layer choice have the same number of candidates."
def on_calc_layer_choice_mask(self, mutable):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += log_prob
entropy = (log_prob * torch.exp(-log_prob)).detach()
self.sample_entropy += entropy
self._inputs = self.embedding(branch_id)
self._selected_layers.append(branch_id.item())
return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)
def on_calc_input_choice_mask(self, mutable, semantic_labels):
if mutable.n_selected is None:
query, anchors = [], []
for label in semantic_labels:
if label not in self._anchors_hid:
self._lstm_next_step()
self._mark_anchor(label)  # anchor not recorded yet: step the LSTM once and register this label
query.append(self.attn_anchor(self._anchors_hid[label]))
anchors.append(self._anchors_hid[label])
query = torch.cat(query, 0)
query = torch.tanh(query + self.attn_query(self._h[-1]))
query = self.v_attn(query)
logit = torch.cat([-query, query], 1)
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, skip)
self.sample_log_prob += torch.sum(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach()
self.sample_entropy += torch.sum(entropy)
self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)  # averaged anchors feed the next LSTM step; keep shape (1, lstm_size)
self._selected_inputs.append(skip)
return skip.bool()
else:
assert mutable.n_selected == 1, "ENAS only supports input choices that select any number of inputs (n_selected=None) or exactly one."
raise NotImplementedError
def exit_mutable_scope(self, mutable_scope):
self._mark_anchor(mutable_scope.key)
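`EnasMutator` accumulates `sample_log_prob` by running the sampled index back through a cross-entropy loss; for a single sample that is exactly the negative log-probability of the chosen branch. A small check of the identity (independent of NNI):

# --- sketch: cross entropy of the sampled index equals -log p(sample) (hypothetical) ---
import torch
import torch.nn.functional as F

logit = torch.randn(1, 6)                         # 6 candidate branches, controller batch of 1
probs = F.softmax(logit, dim=-1)
branch_id = torch.multinomial(probs, 1).view(-1)  # sample one branch

neg_log_prob = F.cross_entropy(logit, branch_id)
assert torch.allclose(neg_log_prob, -torch.log(probs[0, branch_id]))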
import torch
import torch.optim as optim
from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup, auto_device
from .mutator import EnasMutator
class EnasTrainer(Trainer):
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset_train, dataset_valid, lr_scheduler=None,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
mutator_lr=0.00035):
self.model = model
self.loss = loss
self.metrics = metrics
self.reward_function = reward_function
self.mutator = mutator
if self.mutator is None:
self.mutator = EnasMutator(model)
self.optim = optimizer
self.mut_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
self.lr_scheduler = lr_scheduler
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.device = auto_device() if device is None else device
self.log_frequency = log_frequency
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.model.to(self.device)
self.loss.to(self.device)
self.mutator.to(self.device)
n_train = len(self.dataset_train)
split = n_train // 10
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=train_sampler,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
sampler=valid_sampler,
num_workers=workers)
self.test_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def train_epoch(self, epoch):
self.model.train()
self.mutator.train()
for phase in ["model", "mutator"]:
if phase == "model":
self.model.train()
self.mutator.eval()
else:
self.model.eval()
self.mutator.train()
loader = self.train_loader if phase == "model" else self.valid_loader
meters = AverageMeterGroup()
for step, (x, y) in enumerate(loader):
x, y = x.to(self.device), y.to(self.device)
self.optim.zero_grad()
self.mut_optim.zero_grad()
with self.mutator.forward_pass():
logits = self.model(x)
metrics = self.metrics(logits, y)
if phase == "model":
loss = self.loss(logits, y)
loss.backward()
self.optim.step()
else:
reward = self.reward_function(logits, y)
if self.entropy_weight is not None:
reward += self.entropy_weight * self.mutator.sample_entropy
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
self.baseline = self.baseline.detach().item()
loss = self.mutator.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.mutator.sample_skip_penalty
loss.backward()
self.mut_optim.step()
metrics["reward"] = reward
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
print("Epoch {} {} Step [{}/{}] {}".format(epoch, phase.capitalize(), step,
len(loader), meters))
# print(self.mutator._selected_layers)
# print(self.mutator._selected_inputs)
if self.lr_scheduler is not None:
self.lr_scheduler.step()
def validate_epoch(self, epoch):
pass
def train(self):
for epoch in range(self.num_epochs):
# training
print("Epoch {} Training".format(epoch))
self.train_epoch(epoch)
# validation
print("Epoch {} Validating".format(epoch))
self.validate_epoch(epoch)
def export(self):
pass
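The mutator phase of `train_epoch` is a REINFORCE step with an exponential-moving-average baseline. A minimal numeric sketch of that bookkeeping (toy numbers, not the trainer itself):

# --- sketch of the controller objective with an EMA baseline (hypothetical) ---
import torch

sample_log_prob = torch.tensor(2.3, requires_grad=True)  # accumulated negative log prob of the sampled architecture
reward = 0.62                                            # e.g. reward_accuracy on a validation batch
baseline, decay = 0.5, 0.999

baseline = baseline * decay + reward * (1 - decay)       # EMA update, as in the trainer
loss = sample_log_prob * (reward - baseline)             # REINFORCE surrogate loss
loss.backward()
print(float(sample_log_prob.grad))                       # equals reward - baseline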
@@ -18,51 +18,101 @@ class PyTorchMutable(nn.Module):
    def __init__(self, key=None):
        super().__init__()
        if key is not None:
-            self.key = key
+            if not isinstance(key, str):
+                key = str(key)
+                print("Warning: key \"{}\" is not string, converted to string.".format(key))
+            self._key = key
        else:
-            self.key = self.__class__.__name__ + str(global_mutable_counting())
+            self._key = self.__class__.__name__ + str(global_mutable_counting())
        self.name = self.key
    def __deepcopy__(self, memodict=None):
-        raise NotImplementedError
+        raise NotImplementedError("Deep copy doesn't work for mutables.")
+    def __enter__(self):
+        self._check_built()
+        return super().__enter__()
+    def __call__(self, *args, **kwargs):
+        self._check_built()
+        return super().__call__(*args, **kwargs)
    def set_mutator(self, mutator):
        self.__dict__["mutator"] = mutator
    def forward(self, *inputs):
-        raise NotImplementedError("Mutable forward must be implemented")
-    def __repr__(self):
-        return "{} ({})".format(self.name, self.key)
+        raise NotImplementedError("Mutable forward must be implemented.")
+    @property
+    def key(self):
+        return self._key
    def similar(self, other):
        return self == other
+    def _check_built(self):
+        if not hasattr(self, "mutator"):
+            raise ValueError(
+                "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
+                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
+    def __repr__(self):
+        return "{} ({})".format(self.name, self.key)
+class MutableScope(PyTorchMutable):
+    """
+    Mutable scope labels a subgraph to help mutators make better decisions. Mutators get notified when a mutable scope
+    is entered and exited. Mutators can override ``enter_mutable_scope`` and ``exit_mutable_scope`` to catch
+    corresponding events, and do status dump or update.
+    """
+    def __init__(self, key):
+        super().__init__(key=key)
+    def __enter__(self):
+        self.mutator.enter_mutable_scope(self)
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.mutator.exit_mutable_scope(self)
class LayerChoice(PyTorchMutable):
-    def __init__(self, ops, key=None):
+    def __init__(self, op_candidates, reduction="mean", return_mask=False, key=None):
        super().__init__(key=key)
-        self.length = len(ops)
-        self.choices = nn.ModuleList(ops)
+        self.length = len(op_candidates)
+        self.choices = nn.ModuleList(op_candidates)
+        self.reduction = reduction
+        self.return_mask = return_mask
    def forward(self, *inputs):
-        return self.mutator.on_forward(self, self.choices, *inputs)
+        out, mask = self.mutator.on_forward(self, *inputs)
+        if self.return_mask:
+            return out, mask
+        return out
    def similar(self, other):
        return type(self) == type(other) and self.length == other.length
class InputChoice(PyTorchMutable):
-    def __init__(self, n_candidates, n_selected=None, reduction="mean", return_index=False, key=None):
+    def __init__(self, n_candidates, n_selected=None, reduction="mean", return_mask=False, key=None):
        super().__init__(key=key)
+        assert n_candidates > 0, "Number of candidates must be greater than 0."
        self.n_candidates = n_candidates
        self.n_selected = n_selected
        self.reduction = reduction
-        self.return_index = return_index
+        self.return_mask = return_mask
-    def forward(self, *inputs):
-        assert len(inputs) == self.n_candidates, "Length of the input list must be equal to number of candidates."
-        return self.mutator.on_forward(self, *inputs)
+    def forward(self, optional_inputs, semantic_labels=None):
+        assert len(optional_inputs) == self.n_candidates, \
+            "Length of the input list must be equal to number of candidates."
+        if semantic_labels is None:
+            semantic_labels = ["default_label"] * self.n_candidates
+        out, mask = self.mutator.on_forward(self, optional_inputs, semantic_labels)
+        if self.return_mask:
+            return out, mask
+        return out
    def similar(self, other):
        return type(self) == type(other) and \
...
 import logging
+from contextlib import contextmanager
-from torch import nn as nn
+import torch
+import torch.nn as nn
 from nni.nas.pytorch.mutables import PyTorchMutable
 from nni.nas.utils import to_snake_case
@@ -28,8 +30,8 @@ class PyTorchMutator(nn.Module):
        if isinstance(module, PyTorchMutable):
            distinct = False
            if module.key in key2module:
-                assert key2module[module.key].similar(module), "Mutable that share the same key must be similar " \
-                                                               "to each other"
+                assert key2module[module.key].similar(module), \
+                    "Mutable \"{}\" that share the same key must be similar to each other".format(module.key)
            else:
                distinct = True
                key2module[module.key] = module
@@ -56,11 +58,35 @@ class PyTorchMutator(nn.Module):
    def on_init_general(self, mutable):
        pass
-    def on_forward_general(self, mutable, *inputs):
-        raise NotImplementedError("Forward has to be implemented")
+    @contextmanager
+    def forward_pass(self):
+        self.before_pass()
+        try:
+            yield self
+        finally:
+            self.after_pass()
+    def before_pass(self):
+        self._in_forward_pass = True
+        self._cache = dict()
+    def after_pass(self):
+        self._in_forward_pass = False
+    def enter_mutable_scope(self, mutable_scope):
+        pass
+    def exit_mutable_scope(self, mutable_scope):
+        pass
+    def forward(self, *inputs):
+        raise NotImplementedError("Mutator is not forward-able")
    def on_forward(self, mutable, *inputs):
        """Callback on forwarding a mutable"""
+        if not hasattr(self, "_in_forward_pass") or not self._in_forward_pass:
+            raise ValueError("Not in forward pass. Did you forget to call mutator.forward_pass(), or forget to call "
+                             "super().before_pass() and after_pass() in your override method?")
        forward_method_name = "on_forward_{}".format(to_snake_case(mutable.__class__.__name__))
        if hasattr(self, forward_method_name) and callable(getattr(self, forward_method_name)):
            return getattr(self, forward_method_name)(mutable, *inputs)
@@ -68,5 +94,110 @@ class PyTorchMutator(nn.Module):
        # fallback to general forward
        return self.on_forward_general(mutable, *inputs)
-    def forward(self, *inputs):
-        raise NotImplementedError("Mutator is not forward-able")
+    def on_forward_general(self, mutable, *inputs):
+        raise NotImplementedError("Forward has to be implemented")
def on_forward_layer_choice(self, mutable, *inputs):
"""
Callback of layer choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_layer_choice_mask` to get a mask that decides how to choose between layers
(either by switch or by weights), then reduces the list of all tensor outputs with the policy specified
in `mutable.reduction`. It also caches the mask under the corresponding `mutable.key`.
Parameters
----------
mutable: LayerChoice
inputs: list of torch.Tensor
Returns
-------
torch.Tensor
"""
def _map_fn(op, *inputs):
return op(*inputs)
mask = self._cache.setdefault(mutable.key, self.on_calc_layer_choice_mask(mutable))
out = self._select_with_mask(_map_fn, [(choice, *inputs) for choice in mutable.choices], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_forward_input_choice(self, mutable, tensor_list, semantic_labels):
"""
Callback of input choice forward. Override if you are an advanced user.
By default, this method calls :meth:`on_calc_input_choice_mask` with `semantic_labels`
to get a mask that decides how to choose between inputs (either by switch or by weights), then reduces
the list of all tensor outputs with the policy specified in `mutable.reduction`. It also caches the
mask under the corresponding `mutable.key`.
Parameters
----------
mutable: InputChoice
tensor_list: list of torch.Tensor
semantic_labels: list of str
Returns
-------
torch.Tensor
"""
mask = self._cache.setdefault(mutable.key, self.on_calc_input_choice_mask(mutable, semantic_labels))
out = self._select_with_mask(lambda x: x, [(t, ) for t in tensor_list], mask)
return self._tensor_reduction(mutable.reduction, out), mask
def on_calc_layer_choice_mask(self, mutable):
"""
Recommended to override. Calculate a mask tensor for a layer choice.
Parameters
----------
mutable: LayerChoice
Corresponding layer choice object.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switch.
"""
raise NotImplementedError("Layer choice mask calculation must be implemented")
def on_calc_input_choice_mask(self, mutable, semantic_labels):
"""
Recommended to override. Calculate a mask tensor for an input choice.
Parameters
----------
mutable: InputChoice
Corresponding input choice object.
semantic_labels: list of string
The name of labels of input tensors given by user. Usually it's a
:class:`~nni.nas.pytorch.mutables.MutableScope` marked by user.
Returns
-------
torch.Tensor
Should be a 1D tensor, either float or bool. If float, the numbers are treated as weights. If bool,
the numbers are treated as switch.
"""
raise NotImplementedError("Input choice mask calculation must be implemented")
def _select_with_mask(self, map_fn, candidates, mask):
if "BoolTensor" in mask.type():
# print(candidates[0], len(mask))
out = [map_fn(*cand) for cand, m in zip(candidates, mask) if m]
elif "FloatTensor" in mask.type():
out = [map_fn(*cand) * m for cand, m in zip(candidates, mask)]
else:
raise ValueError("Unrecognized mask")
return out
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == "none":
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == "sum":
return sum(tensor_list)
if reduction_type == "mean":
return sum(tensor_list) / len(tensor_list)
if reduction_type == "concat":
return torch.cat(tensor_list, dim=1)
raise ValueError("Unrecognized reduction policy: \"{}\"".format(reduction_type))
@@ -8,5 +8,5 @@ class Trainer(ABC):
        raise NotImplementedError
    @abstractmethod
-    def finalize(self):
+    def export(self):
        raise NotImplementedError
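Putting the API change together: search code is now expected to sample and run the model inside `mutator.forward_pass()`, and trainers expose `export()` instead of `finalize()`. A condensed sketch of the intended pattern, wiring the ENAS macro classes from this commit (it assumes the base mutator registers the model's mutables on construction, as the trainers rely on):

# --- sketch of the forward_pass contract (hypothetical wiring) ---
import torch
# assumes GeneralNetwork from the ENAS macro example above is in scope
from nni.nas.pytorch.enas import EnasMutator  # added in this commit

model = GeneralNetwork(num_layers=3)  # ENAS macro space defined in the example above
mutator = EnasMutator(model)          # samples one architecture per forward pass

x = torch.randn(2, 3, 32, 32)
with mutator.forward_pass():          # before_pass(): reset the decision cache, start a fresh sample
    logits = model(x)                 # each LayerChoice/InputChoice consults the mutator's cached mask
print(logits.shape)                   # expected: torch.Size([2, 10])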