Commit d43fbe82 authored by quzha

Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

parents 0e3906aa bb797e10
# NNI Programming Interface for Neural Architecture Search (NAS)
*This is an experimental feature. The programming APIs are almost done, while the NAS trainers are still under intensive development. ([NAS annotation](../AdvancedFeature/GeneralNasInterfaces.md) will be deprecated in the future.)*
Automatic neural architecture search is playing an increasingly important role in finding better models. Recent research has proved the feasibility of automatic NAS and has produced models that beat manually designed and tuned ones. Representative works include [NASNet][2], [ENAS][1], [DARTS][3], [Network Morphism][4], and [Evolution][5], and new innovations keep emerging. However, it takes great effort to implement these algorithms, and it is hard to reuse the code base of one algorithm to implement another.

To facilitate NAS innovations (e.g., designing and implementing new NAS models, comparing different NAS models side-by-side), an easy-to-use and flexible programming interface is crucial.
## Programming interface
A new programming interface for designing and searching for a model is often needed in two scenarios.

1. When designing a neural network, the designer may have several candidates for a layer, sub-model, or connection, and may not be sure which one, or which combination, performs best. It would be appealing to have an easy way to express the candidate layers/sub-models they want to try.
2. Researchers working on automatic NAS want a unified way to express the search space of neural architectures, so that unchanged trial code can be adapted to different search algorithms.

To express a neural architecture search space, we provide two APIs:
```python
# choose one ``op`` from ``ops``; for PyTorch this is a module.
# ops: for PyTorch ``ops`` is a list of modules, for TensorFlow it is a list of Keras layers. An example in PyTorch:
# ops = [PoolBN('max', channels, 3, stride, 1, affine=False),
#        PoolBN('avg', channels, 3, stride, 1, affine=False),
#        FactorizedReduce(channels, channels, affine=False),
#        SepConv(channels, channels, 3, stride, 1, affine=False),
#        DilConv(channels, channels, 3, stride, 2, 2, affine=False)]
# key: the name of this ``LayerChoice`` instance
nni.nas.LayerChoice(ops, key)
# choose ``n_selected`` from ``n_candidates`` inputs.
# n_candidates: the number of candidate inputs
# n_selected: the number of chosen inputs
# reduction: reduction operation for the chosen inputs
# key: the name of this ``InputChoice`` instance
nni.nas.InputChoice(n_candidates, n_selected, reduction, key)
```
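For instance, a candidate operation inside a PyTorch module can be declared with `LayerChoice`, and a skip connection over earlier layers' outputs with `InputChoice`. The sketch below is illustrative only: the `Block` module, its keys, and the candidate ops are invented for the example, while the way `nni.nas.pytorch.mutables` is called mirrors the DARTS and ENAS examples included in this commit.

```python
import torch.nn as nn
from nni.nas.pytorch import mutables


class Block(nn.Module):
    """A toy block whose operation and skip connections are left to the search algorithm."""

    def __init__(self, layer_id, channels):
        super().__init__()
        # exactly one of these candidate ops will be chosen for this block
        self.conv = mutables.LayerChoice([
            nn.Conv2d(channels, channels, 3, padding=1),
            nn.Conv2d(channels, channels, 5, padding=2),
            nn.MaxPool2d(3, stride=1, padding=1)
        ], key="block_{}_op".format(layer_id))
        # optionally sum in outputs of earlier blocks (skip connections)
        if layer_id > 0:
            self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
        else:
            self.skipconnect = None

    def forward(self, prev_outputs):
        # prev_outputs: list of outputs from all earlier blocks (plus the stem)
        out = self.conv(prev_outputs[-1])
        if self.skipconnect is not None:
            connection = self.skipconnect(prev_outputs[:-1],
                                          ["block_{}".format(i) for i in range(len(prev_outputs) - 1)])
            if connection is not None:
                out = out + connection
        return out
```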
After writing your model with the search space embedded in it using the above two APIs, the next step is finding the best model from the search space. Similar to optimizers of deep learning models, the procedure of finding the best model from a search space can be viewed as a kind of optimization process; we call it an `NAS trainer`. Several NAS trainers already exist, for example `DartsTrainer`, which uses SGD to train architecture weights and model weights iteratively, and `ENASTrainer`, which uses a controller to train the model. New and more efficient NAS trainers keep emerging in the research community.

NNI provides some popular NAS trainers. To use one, users initialize a trainer after the model is defined:
```python
# create a DartsTrainer
trainer = DartsTrainer(model,
                       loss=criterion,
                       metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                       model_optim=optim,
                       lr_scheduler=lr_scheduler,
                       num_epochs=50,
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       batch_size=args.batch_size,
                       log_frequency=args.log_frequency)
# finding the best model from search space
trainer.train()
# export the best found model
trainer.export_model()
```
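`EnasTrainer` (in `nni.nas.pytorch.enas`) follows the same pattern. The snippet below is adapted from the ENAS example included in this commit; `model`, `criterion`, `optim`, `lr_scheduler`, `accuracy`, and `reward_accuracy` are defined in that example's own files, and `reward_function` is what the controller maximizes (here, plain accuracy).

```python
# adapted from the ENAS example in this commit
from nni.nas.pytorch import enas

trainer = enas.EnasTrainer(model,
                           loss=criterion,
                           metrics=accuracy,
                           reward_function=reward_accuracy,
                           optimizer=optim,
                           lr_scheduler=lr_scheduler,
                           batch_size=args.batch_size,
                           num_epochs=n_epochs,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           log_frequency=args.log_frequency)
trainer.train()
```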
Different trainers may have different input arguments depending on their algorithms. After training, users can export the best of the found models through `trainer.export_model()`.
[Here](https://github.com/microsoft/nni/blob/dev-nas-refactor/examples/nas/darts/main.py) is a trial example using DartsTrainer.
[1]: https://arxiv.org/abs/1802.03268
[2]: https://arxiv.org/abs/1707.07012
[3]: https://arxiv.org/abs/1806.09055
[4]: https://arxiv.org/abs/1806.10282
[5]: https://arxiv.org/abs/1703.01041
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
import torch
import torch.nn as nn

import ops
from nni.nas import pytorch as nas


class SearchCell(nn.Module):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialize a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is a reduction cell or not.
        reduction: bool
            Flag for whether the current cell is a reduction cell or not.
        """
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is a reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                                               ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                                               ops.Identity() if stride == 1 else
                                               ops.FactorizedReduce(channels, channels, affine=False),
                                               ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                                               ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                                               ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                                               ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
                                               ops.Zero(stride)],
                                              key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output


class SearchCNN(nn.Module):
    """
    Search CNN model
    """

    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
        """
        Initialize a SearchCNN.

        Parameters
        ----------
        in_channels: int
            Number of channels in images.
        channels: int
            Number of channels used in the network.
        n_classes: int
            Number of classes.
        n_layers: int
            Number of cells in the whole network.
        n_nodes: int
            Number of nodes in a cell.
        stem_multiplier: int
            Multiplier of channels in STEM.
        """
        super().__init__()
        self.in_channels = in_channels
        self.channels = channels
        self.n_classes = n_classes
        self.n_layers = n_layers

        c_cur = stem_multiplier * self.channels
        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
            nn.BatchNorm2d(c_cur)
        )

        # for the first cell, stem is used for both s0 and s1
        # [!] channels_pp and channels_p is output channel size, but c_cur is input channel size.
        channels_pp, channels_p, c_cur = c_cur, c_cur, channels

        self.cells = nn.ModuleList()
        reduction_p, reduction = False, False
        for i in range(n_layers):
            reduction_p, reduction = reduction, False
            # Reduce featuremap size and double channels in 1/3 and 2/3 layer.
            if i in [n_layers // 3, 2 * n_layers // 3]:
                c_cur *= 2
                reduction = True

            cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.linear = nn.Linear(channels_p, n_classes)

    def forward(self, x):
        s0 = s1 = self.stem(x)
        for cell in self.cells:
            s0, s1 = s1, cell(s0, s1)

        out = self.gap(s1)
        out = out.view(out.size(0), -1)  # flatten
        logits = self.linear(out)
        return logits
import torch
import torch.nn as nn

PRIMITIVES = [
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',  # identity
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5',
    'none'
]

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
    'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
    'skip_connect': lambda C, stride, affine:
        Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
    'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}


def drop_path_(x, drop_prob, training):
    if training and drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # per data point mask; assuming x is on cuda.
        mask = torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob)
        x.div_(keep_prob).mul_(mask)
    return x


class DropPath_(nn.Module):
    def __init__(self, p=0.):
        """ [!] DropPath is an inplace module

        Args:
            p: probability of a path to be zeroed.
        """
        super().__init__()
        self.p = p

    def extra_repr(self):
        return 'p={}, inplace'.format(self.p)

    def forward(self, x):
        drop_path_(x, self.p, self.training)
        return x


class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """

    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        """
        Args:
            pool_type: 'max' or 'avg'
        """
        super().__init__()
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()
        self.bn = nn.BatchNorm2d(C, affine=affine)

    def forward(self, x):
        out = self.pool(x)
        out = self.bn(out)
        return out


class StdConv(nn.Module):
    """ Standard conv
    ReLU - Conv - BN
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """

    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
            nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN

    If dilation == 2, 3x3 conv => 5x5 receptive field
                      5x5 conv => 9x9 receptive field
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
                      bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
            DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
        )

    def forward(self, x):
        return self.net(x)


class Identity(nn.Module):
    def forward(self, x):
        return x


class Zero(nn.Module):
    def __init__(self, stride):
        super().__init__()
        self.stride = stride

    def forward(self, x):
        if self.stride == 1:
            return x * 0.
        # re-sizing by stride
        return x[:, :, ::self.stride, ::self.stride] * 0.


class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise conv (stride=2).
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        x = self.relu(x)
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn
from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from utils import accuracy

if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=4, type=int)
    parser.add_argument("--nodes", default=2, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
    criterion = nn.CrossEntropyLoss()

    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    n_epochs = 50
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, n_epochs, eta_min=0.001)

    trainer = DartsTrainer(model,
                           loss=criterion,
                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                           model_optim=optim,
                           lr_scheduler=lr_scheduler,
                           num_epochs=50,
                           dataset_train=dataset_train,
                           dataset_valid=dataset_valid,
                           batch_size=args.batch_size,
                           log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()

    # augment step
    # ...
def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
import torch
import torch.nn as nn


class StdConv(nn.Module):
    def __init__(self, C_in, C_out):
        super(StdConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class PoolBranch(nn.Module):
    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = self.preproc(x)
        out = self.pool(out)
        out = self.bn(out)
        return out


class SeparableConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
                                   groups=C_in, bias=False)
        self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvBranch(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super(ConvBranch, self).__init__()
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
from argparse import ArgumentParser

import torch
import torch.nn as nn

import datasets
from ops import FactorizedReduce, ConvBranch, PoolBranch
from nni.nas.pytorch import mutables, enas


class ENASLayer(nn.Module):
    def __init__(self, layer_id, in_filters, out_filters):
        super().__init__()
        self.in_filters = in_filters
        self.out_filters = out_filters
        self.mutable = mutables.LayerChoice([
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=False),
            ConvBranch(in_filters, out_filters, 3, 1, 1, separable=True),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=False),
            ConvBranch(in_filters, out_filters, 5, 1, 2, separable=True),
            PoolBranch('avg', in_filters, out_filters, 3, 1, 1),
            PoolBranch('max', in_filters, out_filters, 3, 1, 1)
        ])
        if layer_id > 0:
            self.skipconnect = mutables.InputChoice(layer_id, n_selected=None, reduction="sum")
        else:
            self.skipconnect = None
        self.batch_norm = nn.BatchNorm2d(out_filters, affine=False)
        self.mutable_scope = mutables.MutableScope("layer_{}".format(layer_id))

    def forward(self, prev_layers):
        with self.mutable_scope:
            out = self.mutable(prev_layers[-1])
            if self.skipconnect is not None:
                connection = self.skipconnect(prev_layers[:-1],
                                              ["layer_{}".format(i) for i in range(len(prev_layers) - 1)])
                if connection is not None:
                    out += connection
            return self.batch_norm(out)


class GeneralNetwork(nn.Module):
    def __init__(self, num_layers=12, out_filters=24, in_channels=3, num_classes=10,
                 dropout_rate=0.0):
        super().__init__()
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.out_filters = out_filters

        self.stem = nn.Sequential(
            nn.Conv2d(in_channels, out_filters, 3, 1, 1, bias=False),
            nn.BatchNorm2d(out_filters)
        )

        pool_distance = self.num_layers // 3
        self.pool_layers_idx = [pool_distance - 1, 2 * pool_distance - 1]
        self.dropout_rate = dropout_rate
        self.dropout = nn.Dropout(self.dropout_rate)

        self.layers = nn.ModuleList()
        self.pool_layers = nn.ModuleList()
        for layer_id in range(self.num_layers):
            if layer_id in self.pool_layers_idx:
                self.pool_layers.append(FactorizedReduce(self.out_filters, self.out_filters))
            self.layers.append(ENASLayer(layer_id, self.out_filters, self.out_filters))

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.dense = nn.Linear(self.out_filters, self.num_classes)

    def forward(self, x):
        bs = x.size(0)
        cur = self.stem(x)

        layers = [cur]
        for layer_id in range(self.num_layers):
            cur = self.layers[layer_id](layers)
            layers.append(cur)
            if layer_id in self.pool_layers_idx:
                for i, layer in enumerate(layers):
                    layers[i] = self.pool_layers[self.pool_layers_idx.index(layer_id)](layer)
                cur = layers[-1]

        cur = self.gap(cur).view(bs, -1)
        cur = self.dropout(cur)
        logits = self.dense(cur)
        return logits


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res


def reward_accuracy(output, target, topk=(1,)):
    batch_size = target.size(0)
    _, predicted = torch.max(output.data, 1)
    return (predicted == target).sum().item() / batch_size


if __name__ == "__main__":
    parser = ArgumentParser("enas")
    parser.add_argument("--batch-size", default=3, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    model = GeneralNetwork()
    criterion = nn.CrossEntropyLoss()

    n_epochs = 310
    optim = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=n_epochs, eta_min=0.001)

    trainer = enas.EnasTrainer(model,
                               loss=criterion,
                               metrics=accuracy,
                               reward_function=reward_accuracy,
                               optimizer=optim,
                               lr_scheduler=lr_scheduler,
                               batch_size=args.batch_size,
                               num_epochs=n_epochs,
                               dataset_train=dataset_train,
                               dataset_valid=dataset_valid,
                               log_frequency=args.log_frequency)
    trainer.train()
import torch
import torch.nn as nn


class StdConv(nn.Module):
    def __init__(self, C_in, C_out):
        super(StdConv, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        return self.conv(x)


class PoolBranch(nn.Module):
    def __init__(self, pool_type, C_in, C_out, kernel_size, stride, padding, affine=False):
        super().__init__()
        self.preproc = StdConv(C_in, C_out)
        if pool_type.lower() == 'max':
            self.pool = nn.MaxPool2d(kernel_size, stride, padding)
        elif pool_type.lower() == 'avg':
            self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
        else:
            raise ValueError()
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = self.preproc(x)
        out = self.pool(out)
        out = self.bn(out)
        return out


class SeparableConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding):
        super(SeparableConv, self).__init__()
        self.depthwise = nn.Conv2d(C_in, C_in, kernel_size=kernel_size, padding=padding, stride=stride,
                                   groups=C_in, bias=False)
        self.pointwise = nn.Conv2d(C_in, C_out, kernel_size=1, bias=False)

    def forward(self, x):
        out = self.depthwise(x)
        out = self.pointwise(out)
        return out


class ConvBranch(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, separable):
        super(ConvBranch, self).__init__()
        self.preproc = StdConv(C_in, C_out)
        if separable:
            self.conv = SeparableConv(C_out, C_out, kernel_size, stride, padding)
        else:
            self.conv = nn.Conv2d(C_out, C_out, kernel_size, stride=stride, padding=padding)
        self.postproc = nn.Sequential(
            nn.BatchNorm2d(C_out, affine=False),
            nn.ReLU()
        )

    def forward(self, x):
        out = self.preproc(x)
        out = self.conv(out)
        out = self.postproc(out)
        return out


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, affine=False):
        super().__init__()
        self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
        self.bn = nn.BatchNorm2d(C_out, affine=affine)

    def forward(self, x):
        out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
from .mutator import DartsMutator
from .trainer import DartsTrainer
import torch
from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.mutables import LayerChoice
from nni.nas.pytorch.mutator import PyTorchMutator


class DartsMutator(PyTorchMutator):
    def before_build(self, model):
        self.choices = nn.ParameterDict()

    def on_init_layer_choice(self, mutable: LayerChoice):
        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)
import copy

import torch
from torch import nn as nn

from nni.nas.pytorch.trainer import Trainer
from nni.nas.utils import AverageMeterGroup, auto_device

from .mutator import DartsMutator


class DartsTrainer(Trainer):
    def __init__(self, model, loss, metrics,
                 model_optim, lr_scheduler, num_epochs, dataset_train, dataset_valid,
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
        self.model = model
        self.loss = loss
        self.metrics = metrics
        self.mutator = mutator
        if self.mutator is None:
            self.mutator = DartsMutator(model)
        self.model_optim = model_optim
        self.lr_scheduler = lr_scheduler
        self.num_epochs = num_epochs
        self.dataset_train = dataset_train
        self.dataset_valid = dataset_valid
        self.device = auto_device() if device is None else device
        self.log_frequency = log_frequency

        self.model.to(self.device)
        self.loss.to(self.device)
        self.mutator.to(self.device)

        self.ctrl_optim = torch.optim.Adam(self.mutator.parameters(), 3.0E-4, betas=(0.5, 0.999),
                                           weight_decay=1.0E-3)
        n_train = len(self.dataset_train)
        split = n_train // 2
        indices = list(range(n_train))
        train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
        valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
        self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=train_sampler,
                                                        num_workers=workers)
        self.valid_loader = torch.utils.data.DataLoader(self.dataset_train,
                                                        batch_size=batch_size,
                                                        sampler=valid_sampler,
                                                        num_workers=workers)

    def train_epoch(self, epoch):
        self.model.train()
        self.mutator.train()
        lr = self.lr_scheduler.get_lr()[0]
        meters = AverageMeterGroup()
        for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
            trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
            val_X, val_y = val_X.to(self.device), val_y.to(self.device)

            # backup model for hessian
            backup_model = copy.deepcopy(self.model.state_dict())
            # cannot deepcopy model because it will break the reference

            # phase 1. child network step
            self.model_optim.zero_grad()
            with self.mutator.forward_pass():
                logits = self.model(trn_X)
            loss = self.loss(logits, trn_y)
            loss.backward()
            # gradient clipping
            nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
            self.model_optim.step()
            new_model = copy.deepcopy(self.model.state_dict())

            # phase 2. architect step (alpha)
            self.ctrl_optim.zero_grad()
            # compute unrolled loss
            self._unrolled_backward(trn_X, trn_y, val_X, val_y, backup_model, lr)
            self.ctrl_optim.step()
            self.model.load_state_dict(new_model)

            metrics = self.metrics(logits, trn_y)
            metrics["loss"] = loss.item()
            meters.update(metrics)
            if self.log_frequency is not None and step % self.log_frequency == 0:
                print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.train_loader), meters))
        self.lr_scheduler.step()

    def validate_epoch(self, epoch):
        self.model.eval()
        self.mutator.eval()
        meters = AverageMeterGroup()
        with torch.no_grad():
            for step, (X, y) in enumerate(self.valid_loader):
                X, y = X.to(self.device), y.to(self.device)
                logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
                    print("Epoch {} Step [{}/{}] {}".format(epoch, step, len(self.valid_loader), meters))

    def train(self):
        for epoch in range(self.num_epochs):
            # training
            print("Epoch {} Training".format(epoch))
            self.train_epoch(epoch)

            # validation
            print("Epoch {} Validating".format(epoch))
            self.validate_epoch(epoch)

    def _unrolled_backward(self, trn_X, trn_y, val_X, val_y, backup_model, lr):
        """
        Compute unrolled loss and backward its gradients

        Parameters
        ----------
        backup_model: backup of the model weights before this step
        lr: learning rate for virtual gradient step (same as net lr)
        """
        with self.mutator.forward_pass():
            loss = self.loss(self.model(val_X), val_y)
        w_model = tuple(self.model.parameters())
        w_ctrl = tuple(self.mutator.parameters())
        w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
        d_model = w_grads[:len(w_model)]
        d_ctrl = w_grads[len(w_model):]

        hessian = self._compute_hessian(backup_model, d_model, trn_X, trn_y)
        with torch.no_grad():
            for param, d, h in zip(w_ctrl, d_ctrl, hessian):
                param.grad = d - lr * h

    def _compute_hessian(self, model, dw, trn_X, trn_y):
        """
        dw = dw` { L_val(w`, alpha) }
        w+ = w + eps * dw
        w- = w - eps * dw
        hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
        eps = 0.01 / ||dw||
        """
        self.model.load_state_dict(model)
        norm = torch.cat([w.view(-1) for w in dw]).norm()
        eps = 0.01 / norm

        for e in [eps, -2. * eps]:
            # w+ = w + eps*dw`, w- = w - eps*dw`
            with torch.no_grad():
                for p, d in zip(self.model.parameters(), dw):
                    p += e * d

            with self.mutator.forward_pass():
                loss = self.loss(self.model(trn_X), trn_y)
            if e > 0:
                dalpha_pos = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w+) }
            elif e < 0:
                dalpha_neg = torch.autograd.grad(loss, self.mutator.parameters())  # dalpha { L_trn(w-) }

        hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
        return hessian

    def export(self):
        pass
from .mutator import EnasMutator
from .trainer import EnasTrainer