Commit 1cada380 authored by Yuge Zhang, committed by QuanluZhang

Extract base mutator/trainer and support ENAS micro search space (#1739)

parent 3ddab980
import numpy as np
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10


class Cutout(object):
def __init__(self, length):
self.length = length
def __call__(self, img):
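        # img is a CHW tensor (Cutout is applied after ToTensor); zero out one random square of side `length`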
h, w = img.size(1), img.size(2)
mask = np.ones((h, w), np.float32)
y = np.random.randint(h)
x = np.random.randint(w)
y1 = np.clip(y - self.length // 2, 0, h)
y2 = np.clip(y + self.length // 2, 0, h)
x1 = np.clip(x - self.length // 2, 0, w)
x2 = np.clip(x + self.length // 2, 0, w)
mask[y1: y2, x1: x2] = 0.
mask = torch.from_numpy(mask)
mask = mask.expand_as(img)
img *= mask
return img
def get_dataset(cls, cutout_length=0):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
...@@ -13,8 +38,11 @@ def get_dataset(cls):
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    cutout = []
    if cutout_length > 0:
        cutout.append(Cutout(cutout_length))

    train_transform = transforms.Compose(transf + normalize + cutout)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
...
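For orientation, a minimal usage sketch (not part of the commit) of the updated get_dataset; the cutout length of 16 is an assumption, and the call mirrors how the search scripts below consume the datasets.

# Hypothetical call site; cutout_length=16 is an assumed value, not taken from the commit.
import datasets

dataset_train, dataset_valid = datasets.get_dataset("cifar10", cutout_length=16)
print(len(dataset_train), len(dataset_valid))  # 50000 / 10000 for CIFAR-10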
import torch
import torch.nn as nn
import ops
from nni.nas.pytorch import mutables, darts
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out
nn.Conv2d(C, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(768, n_classes)
def forward(self, x):
out = self.net(x)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits
class Node(darts.DartsNode):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect, drop_path_prob=0.):
super().__init__(node_id, limitation=2)
self.ops = nn.ModuleList()
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
self.ops.append(
mutables.LayerChoice(
[
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
],
key="{}_p{}".format(node_id, i)))
self.drop_path = ops.DropPath_(drop_path_prob)
def forward(self, prev_nodes):
assert len(self.ops) == len(prev_nodes)
out = [op(node) for op, node in zip(self.ops, prev_nodes)]
return sum(self.drop_path(o) for o in out if o is not None)
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction, drop_path_prob=0.):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so the output of cell[k-2] has to be downsampled during preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(Node("r{:d}_n{}".format(reduction, depth),
depth + 2, channels, 2 if reduction else 0,
drop_path_prob=drop_path_prob))
def forward(self, s0, s1):
# s0, s1 are the outputs of previous previous cell and previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for node in self.mutable_ops:
cur_tensor = node(tensors)
tensors.append(cur_tensor)
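        # concatenate the outputs of all intermediate nodes (excluding the two preprocessed inputs) along channels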
output = torch.cat(tensors[2:], dim=1)
return output
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False, drop_path_prob=0.):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)
# for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p are output channel sizes, while c_cur is the input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels
self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
            # Reduce feature map size and double the channels at 1/3 and 2/3 of the network depth.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction, drop_path_prob=drop_path_prob)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
if i == self.aux_pos:
self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)
def forward(self, x):
s0 = s1 = self.stem(x)
aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
if i == self.aux_pos and self.training:
aux_logits = self.aux_head(s1)
out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
if aux_logits is not None:
return logits, aux_logits
return logits
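As a quick sanity check (a sketch, not part of the commit), the model above is only meant to be forwarded under a mutator, which is exactly what DartsTrainer does internally; the constructor arguments follow search.py below, and the batch size of 2 is an assumption.

import torch
from model import CNN                        # the file listed above
from nni.nas.pytorch.darts import DartsMutator

model = CNN(32, 3, 16, 10, 8)                # input_size, in_channels, channels, n_classes, n_layers
mutator = DartsMutator(model)                # registers every LayerChoice in the model
with mutator.forward_pass():                 # weights the candidate ops for this pass
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)                          # expected: torch.Size([2, 10])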
import torch
import torch.nn as nn
class DropPath_(nn.Module):
def __init__(self, p=0.):
""" [!] DropPath is inplace module
Args:
p: probability of an path to be zeroed.
"""
super().__init__()
self.p = p
def extra_repr(self):
return 'p={}, inplace'.format(self.p)
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
x.div_(keep_prob).mul_(mask)
return x
class PoolBN(nn.Module):
"""
AvgPool or MaxPool - BN
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
"""
Args:
pool_type: 'max' or 'avg'
"""
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C, affine=affine)
def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out
class StdConv(nn.Module):
""" Standard conv
ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class FacConv(nn.Module):
""" Factorized conv
ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class DilConv(nn.Module):
""" (Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN
If dilation == 2, 3x3 conv => 5x5 receptive field
5x5 conv => 9x9 receptive field
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
class SepConv(nn.Module):
""" Depthwise separable conv
DilConv(dilation=1) * 2
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise(stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
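A small shape sketch (not from the commit) for FactorizedReduce above: the two stride-2 1x1 convolutions each produce half of the output channels, one of them on the input shifted by one pixel, so the spatial resolution is halved. The tensor sizes are assumptions.

import torch
from ops import FactorizedReduce   # the DARTS ops module listed above, assumed to be on the path

x = torch.randn(2, 16, 32, 32)
reduce = FactorizedReduce(16, 32)
print(reduce(x).shape)             # torch.Size([2, 32, 16, 16])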
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn
from model import CNN
from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=8, type=int)
    parser.add_argument("--batch-size", default=96, type=int)
    parser.add_argument("--log-frequency", default=10, type=int)
    parser.add_argument("--epochs", default=50, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = CNN(32, 3, 16, 10, args.layers)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           optimizer=optim,
                           num_epochs=args.epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency,
                           callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
    trainer.train_and_validate()

    # augment step
    # ...
import torch
import torch.nn as nn
class StdConv(nn.Module):
def __init__(self, C_in, C_out):
super(StdConv, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
return self.conv(x)
class PoolBranch(nn.Module):
def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
super().__init__()
self.preproc = StdConv(C_in, C_out)
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = self.preproc(x)
out = self.pool(out)
out = self.bn(out)
return out
class SeparableConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding):
super(SeparableConv, self).__init__()
self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
groups=C_in, bias=False)
self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)
def forward(self, x):
out = self.depthwise(x)
out = self.pointwise(out)
return out
class ConvBranch(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
super(ConvBranch, self).__init__()
self.preproc = StdConv(C_in, C_out)
if separable:
self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
else:
self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
self.postproc = nn.Sequential(
nn.BatchNorm2d(C_out, affine=False),
nn.ReLU()
)
def forward(self, x):
out = self.preproc(x)
out = self.conv(out)
out = self.postproc(out)
return out
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=False):
super().__init__()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
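A shape sketch (not part of the commit) for the two ENAS macro branches above; both keep the spatial resolution because of the 3x3 kernel with stride 1 and padding 1, and the channel counts are assumptions.

import torch
from ops import ConvBranch, PoolBranch   # the ENAS ops module listed above, assumed to be on the path

x = torch.randn(2, 16, 32, 32)
print(ConvBranch(16, 24, 3, 1, 1, separable=True)(x).shape)   # torch.Size([2, 24, 32, 32])
print(PoolBranch('avg', 16, 24, 3, 1, 1)(x).shape)            # torch.Size([2, 24, 32, 32])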
import torch.nn as nn

from nni.nas.pytorch import mutables
from ops import FactorizedReduce, ConvBranch, PoolBranch


class ENASLayer(mutables.MutableScope):
    def __init__(self, key, num_prev_layers, in_filters, out_filters):
        super().__init__(key)
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.mutable = mutables.LayerChoice([
...@@ -21,22 +18,19 @@ class ENASLayer(nn.Module):
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        if num_prev_layers > 0:
            self.skipconnect = mutables.InputChoice(num_prev_layers, n_selected=None, reduction="sum")
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)

    def forward(self, prev_layers, prev_labels):
        out = self.mutable(prev_layers[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_layers[:-1], tags=prev_labels)
            if connection is not None:
                out += connection
        return self.batch_norm(out)


class GeneralNetwork(nn.Module):
...@@ -62,7 +56,8 @@ class GeneralNetwork(nn.Module):
        for layer_id in range(self.num_layers):
            if layer_id in self.pool_layers_idx:
                self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
            self.layers.append(ENASLayer("layer_{}".format(layer_id), layer_id,
                                         self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)
...@@ -71,11 +66,12 @@ class GeneralNetwork(nn.Module):
        bs = x.size(0)
        cur = self.stem(x)

        layers, labels = [cur], []
        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers, labels)
            layers.append(cur)
            labels.append(self.layers[layer_id].key)
            if layer_id in self.pool_layers_idx:
                for i, layer in enumerate(layers):
                    layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
...@@ -85,58 +81,3 @@ class GeneralNetwork(nn.Module):
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
def reward_accuracy(output, target, topk=(1,)):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=3, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = GeneralNetwork()
criterion = nn.CrossEntropyLoss()
n_epochs = 310
optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optim,
lr_scheduler=lr_scheduler,
batch_size=args.batch_size,
num_epochs=n_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
trainer.train()
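A minimal sketch (not in the commit) of forwarding the macro search space under the ENAS controller, mirroring what EnasTrainer does; the default GeneralNetwork() configuration and the batch size are assumptions.

import torch
from macro import GeneralNetwork
from nni.nas.pytorch import enas

model = GeneralNetwork()                   # defaults, as constructed in search.py below
mutator = enas.EnasMutator(model)          # controller RNN that samples one op per layer
with mutator.forward_pass():               # sample an architecture for this forward pass
    logits = model(torch.randn(2, 3, 32, 32))
print(logits.shape)                        # expected: torch.Size([2, 10]) for the CIFAR-10 defaults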
import torch
import torch.nn as nn
import torch.nn.functional as F
from nni.nas.pytorch import mutables
from ops import FactorizedReduce, StdConv, SepConvBN, Pool
class AuxiliaryHead(nn.Module):
def __init__(self, in_channels, num_classes):
super().__init__()
self.in_channels = in_channels
self.num_classes = num_classes
self.pooling = nn.Sequential(
nn.ReLU(),
nn.AvgPool2d(5, 3, 2)
)
self.proj = nn.Sequential(
StdConv(in_channels, 128),
StdConv(128, 768)
)
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Linear(768, 10, bias=False)
def forward(self, x):
bs = x.size(0)
x = self.pooling(x)
x = self.proj(x)
x = self.avg_pool(x).view(bs, -1)
x = self.fc(x)
return x
class Cell(nn.Module):
def __init__(self, cell_name, num_prev_layers, channels):
super().__init__()
self.input_choice = mutables.InputChoice(num_prev_layers, n_selected=1, return_mask=True,
key=cell_name + "_input")
self.op_choice = mutables.LayerChoice([
SepConvBN(channels, channels, 3, 1),
SepConvBN(channels, channels, 5, 2),
Pool("avg", 3, 1, 1),
Pool("max", 3, 1, 1),
nn.Identity()
], key=cell_name + "_op")
def forward(self, prev_layers, prev_labels):
chosen_input, chosen_mask = self.input_choice(prev_layers, tags=prev_labels)
cell_out = self.op_choice(chosen_input)
return cell_out, chosen_mask
class Node(mutables.MutableScope):
def __init__(self, node_name, num_prev_layers, channels):
super().__init__(node_name)
self.cell_x = Cell(node_name + "_x", num_prev_layers, channels)
self.cell_y = Cell(node_name + "_y", num_prev_layers, channels)
def forward(self, prev_layers, prev_labels):
out_x, mask_x = self.cell_x(prev_layers, prev_labels)
out_y, mask_y = self.cell_y(prev_layers, prev_labels)
return out_x + out_y, mask_x | mask_y
class Calibration(nn.Module):
def __init__(self, in_channels, out_channels):
super().__init__()
self.process = None
if in_channels != out_channels:
self.process = StdConv(in_channels, out_channels)
def forward(self, x):
if self.process is None:
return x
return self.process(x)
class ReductionLayer(nn.Module):
def __init__(self, in_channels_pp, in_channels_p, out_channels):
super().__init__()
self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)
def forward(self, pprev, prev):
return self.reduce0(pprev), self.reduce1(prev)
class ENASLayer(nn.Module):
def __init__(self, num_nodes, in_channels_pp, in_channels_p, out_channels, reduction):
super().__init__()
self.preproc0 = Calibration(in_channels_pp, out_channels)
self.preproc1 = Calibration(in_channels_p, out_channels)
self.num_nodes = num_nodes
name_prefix = "reduce" if reduction else "normal"
self.nodes = nn.ModuleList([Node("{}_node_{}".format(name_prefix, i),
i + 2, out_channels) for i in range(num_nodes)])
self.final_conv_w = nn.Parameter(torch.zeros(out_channels, self.num_nodes + 2, out_channels, 1, 1), requires_grad=True)
self.bn = nn.BatchNorm2d(out_channels, affine=False)
self.reset_parameters()
def reset_parameters(self):
nn.init.kaiming_normal_(self.final_conv_w)
def forward(self, pprev, prev):
pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)
prev_nodes_out = [pprev_, prev_]
prev_nodes_labels = ["prev1", "prev2"]
nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
for i in range(self.num_nodes):
node_out, mask = self.nodes[i](prev_nodes_out, prev_nodes_labels)
nodes_used_mask[:mask.size(0)] |= mask
prev_nodes_out.append(node_out)
prev_nodes_labels.append(self.nodes[i].key)
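        # nodes that were never chosen as inputs ("loose ends") are concatenated below and projected back to out_channels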
unused_nodes = torch.cat([out for used, out in zip(nodes_used_mask, prev_nodes_out) if not used], 1)
unused_nodes = F.relu(unused_nodes)
conv_weight = self.final_conv_w[:, ~nodes_used_mask, :, :, :]
conv_weight = conv_weight.view(conv_weight.size(0), -1, 1, 1)
out = F.conv2d(unused_nodes, conv_weight)
return prev, self.bn(out)
class MicroNetwork(nn.Module):
def __init__(self, num_layers=2, num_nodes=5, out_channels=24, in_channels=3, num_classes=10,
dropout_rate=0.0, use_aux_heads=False):
super().__init__()
self.num_layers = num_layers
self.use_aux_heads = use_aux_heads
self.stem = nn.Sequential(
nn.Conv2d(in_channels, out_channels * 3, 3, 1, 1, bias=False),
nn.BatchNorm2d(out_channels * 3)
)
pool_distance = self.num_layers // 3
pool_layers = [pool_distance, 2 * pool_distance + 1]
self.dropout = nn.Dropout(dropout_rate)
self.layers = nn.ModuleList()
c_pp = c_p = out_channels * 3
c_cur = out_channels
for layer_id in range(self.num_layers + 2):
reduction = False
if layer_id in pool_layers:
c_cur, reduction = c_p * 2, True
self.layers.append(ReductionLayer(c_pp, c_p, c_cur))
c_pp = c_p = c_cur
self.layers.append(ENASLayer(num_nodes, c_pp, c_p, c_cur, reduction))
if self.use_aux_heads and layer_id == pool_layers[-1] + 1:
self.layers.append(AuxiliaryHead(c_cur, num_classes))
c_pp, c_p = c_p, c_cur
self.gap = nn.AdaptiveAvgPool2d(1)
self.dense = nn.Linear(c_cur, num_classes)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
def forward(self, x):
bs = x.size(0)
prev = cur = self.stem(x)
aux_logits = None
for layer in self.layers:
if isinstance(layer, AuxiliaryHead):
if self.training:
aux_logits = layer(cur)
else:
prev, cur = layer(prev, cur)
cur = self.gap(F.relu(cur)).view(bs, -1)
cur = self.dropout(cur)
logits = self.dense(cur)
if aux_logits is not None:
return logits, aux_logits
return logits
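Similarly, a sketch (not part of the commit) of sampling from the micro (cell-based) search space; the constructor arguments mirror search.py below, and the auxiliary-head tuple is handled because the model is in training mode by default.

import torch
from micro import MicroNetwork
from nni.nas.pytorch import enas

model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
with mutator.forward_pass():
    out = model(torch.randn(2, 3, 32, 32))
logits, aux_logits = out if isinstance(out, tuple) else (out, None)
print(logits.shape)                        # expected: torch.Size([2, 10])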
...@@ -19,12 +19,7 @@ class PoolBranch(nn.Module):
    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        self.pool = Pool(pool_type, kernel_size, stride, padding)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
...@@ -78,3 +73,31 @@ class FactorizedReduce(nn.Module):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
class Pool(nn.Module):
def __init__(self, pool_type, kernel_size, stride, padding):
super().__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError()
def forward(self, x):
return self.pool(x)
class SepConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, padding):
super().__init__()
self.relu = nn.ReLU()
self.conv = SeparableConv(C_in, C_out, kernel_size, 1, padding)
self.bn = nn.BatchNorm2d(C_out, affine=True)
def forward(self, x):
x = self.relu(x)
x = self.conv(x)
x = self.bn(x)
return x
from argparse import ArgumentParser
import torch
import torch.nn as nn
import datasets
from macro import GeneralNetwork
from micro import MicroNetwork
from nni.nas.pytorch import enas
from nni.nas.pytorch.callbacks import LearningRateScheduler, ArchitectureCheckpoint
from utils import accuracy, reward_accuracy
if __name__ == "__main__":
parser = ArgumentParser("enas")
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
if args.search_for == "macro":
model = GeneralNetwork()
num_epochs = 310
mutator = None
elif args.search_for == "micro":
model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
num_epochs = 150
mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
else:
raise AssertionError
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)
trainer = enas.EnasTrainer(model,
loss=criterion,
metrics=accuracy,
reward_function=reward_accuracy,
optimizer=optimizer,
callbacks=[LearningRateScheduler(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
batch_size=args.batch_size,
num_epochs=num_epochs,
dataset_train=dataset_train,
dataset_valid=dataset_valid,
log_frequency=args.log_frequency)
trainer.train_and_validate()
import torch
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
def reward_accuracy(output, target, topk=(1,)):
batch_size = target.size(0)
_, predicted = torch.max(output.data, 1)
return (predicted == target).sum().item() / batch_size
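A tiny worked example (not in the commit) of the two metrics above on a batch of two samples:

import torch
from utils import accuracy, reward_accuracy

logits = torch.tensor([[2.0, 1.0, 0.1],    # sample 0: top-1 is class 0 (correct)
                       [0.2, 3.0, 0.5]])   # sample 1: top-1 is class 1 (wrong), top-2 contains class 2
target = torch.tensor([0, 2])
print(accuracy(logits, target, topk=(1, 2)))   # {'acc1': 0.5, 'acc2': 1.0}
print(reward_accuracy(logits, target))         # 0.5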
import logging
import torch.nn as nn
from nni.nas.pytorch.mutables import Mutable
logger = logging.getLogger(__name__)
class BaseMutator(nn.Module):
def __init__(self, model):
super().__init__()
self.__dict__["model"] = model
self.before_parse_search_space()
self._parse_search_space()
self.after_parse_search_space()
def before_parse_search_space(self):
pass
def after_parse_search_space(self):
pass
def _parse_search_space(self):
for name, mutable, _ in self.named_mutables(distinct=False):
mutable.name = name
mutable.set_mutator(self)
def named_mutables(self, root=None, distinct=True):
if root is None:
root = self.model
# if distinct is true, the method will filter out those with duplicated keys
key2module = dict()
for name, module in root.named_modules():
if isinstance(module, Mutable):
module_distinct = False
if module.key in key2module:
                    assert key2module[module.key].similar(module), \
                        "Mutables that share the same key \"{}\" must be similar to each other".format(module.key)
else:
module_distinct = True
key2module[module.key] = module
if distinct:
if module_distinct:
yield name, module
else:
yield name, module, module_distinct
def __setattr__(self, key, value):
if key in ["model", "net", "network"]:
logger.warning("Think twice if you are including the network into mutator.")
return super().__setattr__(key, value)
def forward(self, *inputs):
raise NotImplementedError("Mutator is not forward-able")
def enter_mutable_scope(self, mutable_scope):
pass
def exit_mutable_scope(self, mutable_scope):
pass
def on_forward_layer_choice(self, mutable, *inputs):
raise NotImplementedError
def on_forward_input_choice(self, mutable, tensor_list, tags):
raise NotImplementedError
def export(self):
raise NotImplementedError
from abc import ABC, abstractmethod
class BaseTrainer(ABC):
@abstractmethod
def train(self):
raise NotImplementedError
@abstractmethod
def validate(self):
raise NotImplementedError
@abstractmethod
def train_and_validate(self):
raise NotImplementedError
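For illustration only, a stub trainer that satisfies the abstract interface above; the import path is an assumption based on the package layout in this commit.

from nni.nas.pytorch.base_trainer import BaseTrainer   # path is an assumption

class DummyTrainer(BaseTrainer):
    """Fills the three required hooks with no-op behavior."""
    def train(self):
        print("train shared weights / controller here")

    def validate(self):
        print("evaluate the currently sampled architecture here")

    def train_and_validate(self):
        self.train()
        self.validate()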
import json
import logging
import os
import torch
_logger = logging.getLogger(__name__)
class Callback:
def __init__(self):
self.model = None
self.mutator = None
self.trainer = None
def build(self, model, mutator, trainer):
self.model = model
self.mutator = mutator
self.trainer = trainer
def on_epoch_begin(self, epoch):
pass
def on_epoch_end(self, epoch):
pass
def on_batch_begin(self, epoch):
pass
def on_batch_end(self, epoch):
pass
class LearningRateScheduler(Callback):
def __init__(self, scheduler, mode="epoch"):
super().__init__()
assert mode == "epoch"
self.scheduler = scheduler
self.mode = mode
def on_epoch_end(self, epoch):
self.scheduler.step()
class ArchitectureCheckpoint(Callback):
class TorchTensorEncoder(json.JSONEncoder):
def default(self, o): # pylint: disable=method-hidden
if isinstance(o, torch.Tensor):
olist = o.tolist()
if "bool" not in o.type().lower() and all(map(lambda d: d == 0 or d == 1, olist)):
_logger.warning("Every element in %s is either 0 or 1. "
"You might consider convert it into bool.", olist)
return olist
return super().default(o)
def __init__(self, checkpoint_dir, every="epoch"):
super().__init__()
assert every == "epoch"
self.checkpoint_dir = checkpoint_dir
os.makedirs(self.checkpoint_dir, exist_ok=True)
def _export_to_file(self, file):
mutator_export = self.mutator.export()
with open(file, "w") as f:
json.dump(mutator_export, f, indent=2, sort_keys=True, cls=self.TorchTensorEncoder)
def on_epoch_end(self, epoch):
self._export_to_file(os.path.join(self.checkpoint_dir, "epoch_{}.json".format(epoch)))
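A small illustration (not in the commit) of a user-defined callback built on the hooks above; Callback refers to the class defined in this file.

class EpochTracker(Callback):
    """Example callback: records which epoch finished most recently."""
    def __init__(self):
        super().__init__()
        self.last_finished = None

    def on_epoch_end(self, epoch):
        self.last_finished = epoch
        print("finished epoch", epoch)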
from .mutator import DartsMutator
from .trainer import DartsTrainer
from .scope import DartsNode
...@@ -3,16 +3,34 @@ from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator
from .scope import DartsNode


class DartsMutator(Mutator):
    def after_parse_search_space(self):
        self.choices = nn.ParameterDict()
        for _, mutable in self.named_mutables():
            if isinstance(mutable, LayerChoice):
                self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(len(mutable) + 1))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)[:-1]

    def export(self):
        result = super().export()
        for _, darts_node in self.named_mutables():
            if isinstance(darts_node, DartsNode):
                keys, edges_max = [], []  # key of all the layer choices in current node, and their best edge weight
                for _, choice in self.named_mutables(darts_node):
                    if isinstance(choice, LayerChoice):
                        keys.append(choice.key)
                        max_val, index = torch.max(result[choice.key], 0)
                        edges_max.append(max_val)
                        result[choice.key] = F.one_hot(index, num_classes=len(result[choice.key])).view(-1).bool()
                _, topk_edge_indices = torch.topk(torch.tensor(edges_max).view(-1), darts_node.limitation)  # pylint: disable=not-callable
                for i, key in enumerate(keys):
                    if i not in topk_edge_indices:
                        result[key] = torch.zeros_like(result[key])
        return result
from nni.nas.pytorch.mutables import MutableScope
class DartsNode(MutableScope):
"""
    At most `limitation` choices are activated in a `DartsNode` when exporting.
"""
def __init__(self, key, limitation):
super().__init__(key)
self.limitation = limitation
...@@ -4,32 +4,18 @@ import torch
from torch import nn as nn

from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup
from .mutator import DartsMutator


class DartsTrainer(Trainer):
    def __init__(self, model, loss, metrics,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None,
                 callbacks=None):
        super().__init__(model, loss, metrics, optimizer, num_epochs,
                         dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
                         mutator if mutator is not None else DartsMutator(model), callbacks)
        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
                                           weight_decay=1.0E-3)
        n_train = len(self.dataset_train)
...@@ -46,10 +32,10 @@ class DartsTrainer(Trainer):
                                                        sampler=valid_sampler,
                                                        num_workers=workers)

    def train_one_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        lr = self.optimizer.param_groups[0]["lr"]
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
...@@ -60,14 +46,14 @@ class DartsTrainer(Trainer):
            # cannot deepcopy model because it will break the reference

            # phase 1. child network step
            self.optimizer.zero_grad()
            with self.mutator.forward_pass():
                logits = self.model(trn_X)
            loss = self.loss(logits, trn_y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.optimizer.step()

            new_model = copy.deepcopy(self.model.state_dict())
...@@ -83,11 +69,9 @@ class DartsTrainer(Trainer):
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.train_loader), meters))

    def validate_one_epoch(self, epoch):
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
...@@ -99,17 +83,7 @@ class DartsTrainer(Trainer):
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    print("Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs, step, len(self.valid_loader), meters))

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr):
        """
...@@ -160,6 +134,3 @@ class DartsTrainer(Trainer):
        hessian = [(p - n) / 2. * eps for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian
...@@ -2,7 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F

from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import Mutator


class StackedLSTMCell(nn.Module):
...@@ -23,35 +24,49 @@ class StackedLSTMCell(nn.Module):
        return next_c, next_h


class EnasMutator(Mutator):
    def __init__(self, model, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5, cell_exit_extra_step=False,
                 skip_target=0.4, branch_bias=0.25):
        self.lstm_size = lstm_size
        self.lstm_num_layers = lstm_num_layers
        self.tanh_constant = tanh_constant
        self.cell_exit_extra_step = cell_exit_extra_step
        self.skip_target = skip_target
        self.branch_bias = branch_bias
        super().__init__(model)

    def before_parse_search_space(self):
        self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
        self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)  # pylint: disable=not-callable
        self.cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
        self.bias_dict = nn.ParameterDict()

    def after_parse_search_space(self):
        self.max_layer_choice = 0
        for _, mutable in self.named_mutables():
            if isinstance(mutable, LayerChoice):
                if self.max_layer_choice == 0:
                    self.max_layer_choice = mutable.length
                assert self.max_layer_choice == mutable.length, \
                    "ENAS mutator requires all layer choice have the same number of candidates."
                # NOTE(yuge): We might implement an interface later. Judging by key now.
                if "reduce" in mutable.key:
                    def is_conv(choice):
                        return "conv" in str(type(choice)).lower()
                    bias = torch.tensor([self.branch_bias if is_conv(choice) else -self.branch_bias  # pylint: disable=not-callable
                                         for choice in mutable.choices])
                    self.bias_dict[mutable.key] = nn.Parameter(bias, requires_grad=False)
        self.embedding = nn.Embedding(self.max_layer_choice + 1, self.lstm_size)
        self.soft = nn.Linear(self.lstm_size, self.max_layer_choice, bias=False)

    def before_pass(self):
        super().before_pass()
        self._anchors_hid = dict()
        self._inputs = self.g_emb.data
        self._c = [torch.zeros((1, self.lstm_size),
                               dtype=self._inputs.dtype,
...@@ -69,58 +84,58 @@ class EnasMutator(PyTorchMutator):
    def _mark_anchor(self, key):
        self._anchors_hid[key] = self._h[-1]

    def on_calc_layer_choice_mask(self, mutable):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        if mutable.key in self.bias_dict:
            logit += self.bias_dict[mutable.key]
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += torch.sum(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()
        self.sample_entropy += torch.sum(entropy)
        self._inputs = self.embedding(branch_id)
        return F.one_hot(branch_id, num_classes=self.max_layer_choice).bool().view(-1)

    def on_calc_input_choice_mask(self, mutable, semantic_labels):
        query, anchors = [], []
        for label in semantic_labels:
            if label not in self._anchors_hid:
                self._lstm_next_step()
                self._mark_anchor(label)  # empty loop, fill not found
            query.append(self.attn_anchor(self._anchors_hid[label]))
            anchors.append(self._anchors_hid[label])
        query = torch.cat(query, 0)
        query = torch.tanh(query + self.attn_query(self._h[-1]))
        query = self.v_attn(query)
        if mutable.n_selected is None:
            logit = torch.cat([-query, query], 1)
            if self.tanh_constant is not None:
                logit = self.tanh_constant * torch.tanh(logit)
            skip = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip_prob = torch.sigmoid(logit)
            kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
            self.sample_skip_penalty += kl
            log_prob = self.cross_entropy_loss(logit, skip)
            self._inputs = (torch.matmul(skip.float(), torch.cat(anchors, 0)) / (1. + torch.sum(skip))).unsqueeze(0)
        else:
            assert mutable.n_selected == 1, "Input choice must select exactly one or any in ENAS."
            logit = query.view(1, -1)
            index = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
            skip = F.one_hot(index).view(-1)
            log_prob = self.cross_entropy_loss(logit, index)
            self._inputs = anchors[index.item()]
        self.sample_log_prob += torch.sum(log_prob)
        entropy = (log_prob * torch.exp(-log_prob)).detach()
        self.sample_entropy += torch.sum(entropy)
        return skip.bool()

    def exit_mutable_scope(self, mutable_scope):
        if self.cell_exit_extra_step:
            self._lstm_next_step()
        self._mark_anchor(mutable_scope.key)
...@@ -2,39 +2,29 @@ import torch
import torch.optim as optim

from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup
from .mutator import EnasMutator


class EnasTrainer(Trainer):
    def __init__(self, model, loss, metrics, reward_function,
                 optimizer, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None, callbacks=None,
                 entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
                 mutator_lr=0.00035, mutator_steps_aggregate=20, mutator_steps=50, aux_weight=0.4):
        super().__init__(model, loss, metrics, optimizer, num_epochs,
                         dataset_train, dataset_valid, batch_size, workers, device, log_frequency,
                         mutator if mutator is not None else EnasMutator(model), callbacks)
        self.reward_function = reward_function
        self.mutator_optim = optim.Adam(self.mutator.parameters(), lr=mutator_lr)
        self.entropy_weight = entropy_weight
        self.skip_weight = skip_weight
        self.baseline_decay = baseline_decay
        self.baseline = 0.
        self.mutator_steps_aggregate = mutator_steps_aggregate
        self.mutator_steps = mutator_steps
        self.aux_weight = aux_weight

        n_train = len(self.dataset_train)
        split = n_train // 10
...@@ -53,68 +43,76 @@ class EnasTrainer(Trainer):
                                                        batch_size=batch_size,
                                                        num_workers=workers)

    def train_one_epoch(self, epoch):
        # Sample model and train
        self.model.train()
        self.mutator.eval()
        meters = AverageMeterGroup()
        for step, (x, y) in enumerate(self.train_loader):
            x, y = x.to(self.device), y.to(self.device)
            self.optimizer.zero_grad()
            with self.mutator.forward_pass():
                logits = self.model(x)
            if isinstance(logits, tuple):
                logits, aux_logits = logits
                aux_loss = self.loss(aux_logits, y)
            else:
                aux_loss = 0.
            metrics = self.metrics(logits, y)
            loss = self.loss(logits, y)
            loss = loss + self.aux_weight * aux_loss
            loss.backward()
            self.optimizer.step()
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                print("Model Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs,
                                                                   step, len(self.train_loader), meters))

        # Train sampler (mutator)
        self.model.eval()
        self.mutator.train()
        meters = AverageMeterGroup()
        mutator_step, total_mutator_steps = 0, self.mutator_steps * self.mutator_steps_aggregate
        while mutator_step < total_mutator_steps:
            for step, (x, y) in enumerate(self.valid_loader):
                x, y = x.to(self.device), y.to(self.device)
                with self.mutator.forward_pass():
                    logits = self.model(x)
                metrics = self.metrics(logits, y)
                reward = self.reward_function(logits, y)
                if self.entropy_weight is not None:
                    reward += self.entropy_weight * self.mutator.sample_entropy
                self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
                self.baseline = self.baseline.detach().item()
                loss = self.mutator.sample_log_prob * (reward - self.baseline)
                if self.skip_weight:
                    loss += self.skip_weight * self.mutator.sample_skip_penalty
                metrics["reward"] = reward
                metrics["loss"] = loss.item()
                metrics["ent"] = self.mutator.sample_entropy.item()
                metrics["baseline"] = self.baseline
                metrics["skip"] = self.mutator.sample_skip_penalty

                loss = loss / self.mutator_steps_aggregate
                loss.backward()
                meters.update(metrics)

                if mutator_step % self.mutator_steps_aggregate == 0:
                    self.mutator_optim.step()
                    self.mutator_optim.zero_grad()

                if self.log_frequency is not None and step % self.log_frequency == 0:
                    print("Mutator Epoch [{}/{}] Step [{}/{}] {}".format(epoch, self.num_epochs,
                                                                         mutator_step // self.mutator_steps_aggregate,
                                                                         self.mutator_steps, meters))
                mutator_step += 1
                if mutator_step >= total_mutator_steps:
                    break

    def validate_one_epoch(self, epoch):
        pass