Commit 3ddab980 authored by quzha

Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

parents 594924a9 d1d10de7
......@@ -80,6 +80,7 @@ venv.bak/
# VSCode
.vscode
.vs
# In case you place source code in ~/nni/
/experiments
from argparse import ArgumentParser
import datasets
import torch
import torch.nn as nn
from model import SearchCNN
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
from utils import accuracy
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=4, type=int)
parser.add_argument("--nodes", default=2, type=int)
parser.add_argument("--layers", default=5, type=int)
parser.add_argument("--nodes", default=4, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
criterion = nn.CrossEntropyLoss()
optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
......
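For context, and not part of this diff: a minimal sketch of how the objects created above would typically be handed to the trainer. The keyword arguments are an assumption, mirrored from the DartsTrainer(...) call inside PdartsTrainer.train() further down.

# Sketch only; argument names copied from the DartsTrainer call used by PdartsTrainer below.
trainer = DartsTrainer(model,
                       loss=criterion,
                       metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                       model_optim=optim,
                       num_epochs=50,                     # assumed epoch budget
                       dataset_train=dataset_train,
                       dataset_valid=dataset_valid,
                       batch_size=args.batch_size,
                       log_frequency=args.log_frequency)
trainer.train()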
from torchvision import transforms
from torchvision.datasets import CIFAR10
def get_dataset(cls):
MEAN = [0.49139968, 0.48215827, 0.44653124]
STD = [0.24703233, 0.24348505, 0.26158768]
transf = [
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip()
]
normalize = [
transforms.ToTensor(),
transforms.Normalize(MEAN, STD)
]
train_transform = transforms.Compose(transf + normalize)
valid_transform = transforms.Compose(normalize)
if cls == "cifar10":
dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
else:
raise NotImplementedError
return dataset_train, dataset_valid
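A short usage sketch for get_dataset (not part of this diff); the DataLoader settings here are illustrative only.

import torch

dataset_train, dataset_valid = get_dataset("cifar10")
# Standard PyTorch loaders around the CIFAR-10 splits; batch size and worker count are examples.
train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=128, shuffle=True, num_workers=4)
valid_loader = torch.utils.data.DataLoader(dataset_valid, batch_size=128, shuffle=False, num_workers=4)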
from argparse import ArgumentParser
import datasets
import torch
import torch.nn as nn
import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
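A quick illustration, not part of this diff, of the dictionary that accuracy() returns.

import torch

logits = torch.tensor([[2.0, 1.0, 0.1],    # top-1 prediction is class 0 (correct)
                       [0.2, 0.1, 3.0]])   # top-1 and top-2 both miss the true class 1
target = torch.tensor([0, 1])
print(accuracy(logits, target, topk=(1, 2)))
# {'acc1': 0.5, 'acc2': 0.5}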
if __name__ == "__main__":
parser = ArgumentParser("darts")
parser.add_argument("--layers", default=5, type=int)
parser.add_argument('--add_layers', action='append',
default=[0, 6, 12], help='add layers')
parser.add_argument("--nodes", default=4, type=int)
parser.add_argument("--batch-size", default=128, type=int)
parser.add_argument("--log-frequency", default=1, type=int)
args = parser.parse_args()
dataset_train, dataset_valid = datasets.get_dataset("cifar10")
def model_creator(layers, n_nodes):
model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
loss = nn.CrossEntropyLoss()
model_optim = torch.optim.SGD(model.parameters(), 0.025,
momentum=0.9, weight_decay=3.0E-4)
n_epochs = 50
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
return model, loss, model_optim, lr_scheduler
trainer = PdartsTrainer(model_creator,
metrics=lambda output, target: accuracy(output, target, topk=(1,)),
num_epochs=50,
pdarts_num_layers=[0, 6, 12],
pdarts_num_to_drop=[3, 2, 2],
dataset_train=dataset_train,
dataset_valid=dataset_valid,
layers=args.layers,
n_nodes=args.nodes,
batch_size=args.batch_size,
log_frequency=args.log_frequency)
trainer.train()
trainer.export()
from .mutator import DartsMutator
from .trainer import DartsTrainer
from .cnn_cell import CnnCell
from .cnn_network import CnnNetwork
import torch
import torch.nn as nn
import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule
from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv
class CnnCell(RankedModule):
"""
Cell for search.
"""
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
"""
Initialize a search cell.
Parameters
----------
n_nodes: int
Number of nodes in current DAG.
channels_pp: int
Number of output channels from previous previous cell.
channels_p: int
Number of output channels from previous cell.
channels: int
Number of channels that will be used in the current DAG.
reduction_p: bool
    Whether the previous cell is a reduction cell.
reduction: bool
    Whether the current cell is a reduction cell.
"""
super(CnnCell, self).__init__(rank=1, reduction=reduction)
self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(nn.ModuleList())
for i in range(2 + depth): # include 2 input nodes
# reduction should be used only for input node
stride = 2 if reduction and i < 2 else 1
m_ops = []
for primitive in PRIMITIVES:
op = OPS[primitive](channels, stride, False)
m_ops.append(op)
op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
self.mutable_ops[depth].append(op)
def forward(self, s0, s1):
        # s0, s1 are the outputs of the previous-previous cell and the previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for ops in self.mutable_ops:
assert len(ops) == len(tensors)
cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
tensors.append(cur_tensor)
output = torch.cat(tensors[2:], dim=1)
return output
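For reference (not part of this diff): every intermediate node in the DAG outputs `channels` channels, so the concatenation above makes a cell emit n_nodes * channels channels. This is the relation CnnNetwork below relies on when it sets c_cur_out = c_cur * n_nodes, e.g.:

channels, n_nodes = 16, 4
cell_output_channels = n_nodes * channels   # 64 channels leave each cell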
import torch
import torch.nn as nn
import ops
from nni.nas import pytorch as nas
class SearchCell(nn.Module):
"""
Cell for search.
"""
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
"""
Initialize a search cell.
Parameters
----------
n_nodes: int
Number of nodes in current DAG.
channels_pp: int
Number of output channels from previous previous cell.
channels_p: int
Number of output channels from previous cell.
channels: int
Number of channels that will be used in the current DAG.
reduction_p: bool
    Whether the previous cell is a reduction cell.
reduction: bool
    Whether the current cell is a reduction cell.
"""
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so output[k-2] should be reduced by preprocessing.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(self.n_nodes):
self.mutable_ops.append(nn.ModuleList())
for i in range(2 + depth): # include 2 input nodes
# reduction should be used only for input node
stride = 2 if reduction and i < 2 else 1
op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
ops.Identity() if stride == 1 else
ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
ops.Zero(stride)],
key="r{}_d{}_i{}".format(reduction, depth, i))
self.mutable_ops[depth].append(op)
def forward(self, s0, s1):
        # s0, s1 are the outputs of the previous-previous cell and the previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
for ops in self.mutable_ops:
assert len(ops) == len(tensors)
cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
tensors.append(cur_tensor)
        output = torch.cat(tensors[2:], dim=1)
        return output
import torch.nn as nn
from .cnn_cell import CnnCell
class SearchCNN(nn.Module):
class CnnNetwork(nn.Module):
"""
Search CNN model
"""
def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell):
"""
Initialize a search CNN.
......@@ -121,7 +53,7 @@ class SearchCNN(nn.Module):
c_cur *= 2
reduction = True
cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
......
import torch
import torch.nn as nn
PRIMITIVES = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
    'skip_connect',  # identity
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
'none'
]
OPS = {
'none': lambda C, stride, affine: Zero(stride),
'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
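To illustrate (not part of this diff), an OPS entry is instantiated exactly the way CnnCell does above, i.e. OPS[primitive](channels, stride, affine):

import torch

op = OPS['sep_conv_3x3'](16, 1, False)   # 16 channels, stride 1, affine=False
out = op(torch.randn(2, 16, 32, 32))     # stride 1 preserves the shape: (2, 16, 32, 32)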
......@@ -60,6 +58,7 @@ class PoolBN(nn.Module):
"""
AvgPool or MaxPool - BN
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
"""
Args:
......@@ -85,6 +84,7 @@ class StdConv(nn.Module):
""" Standard conv
ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
......@@ -101,6 +101,7 @@ class FacConv(nn.Module):
""" Factorized conv
ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
......@@ -118,14 +119,14 @@ class DilConv(nn.Module):
""" (Dilated) depthwise separable conv
ReLU - (Dilated) depthwise separable - Pointwise - BN
If dilation == 2, 3x3 conv => 5x5 receptive field
                      5x5 conv => 9x9 receptive field
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__()
self.net = nn.Sequential(
nn.ReLU(),
            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
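The receptive-field note in the docstring follows from the usual dilated-convolution formula; a quick check (not part of this diff):

def effective_kernel(kernel_size, dilation):
    # Effective extent of a dilated kernel: k + (k - 1) * (d - 1)
    return kernel_size + (kernel_size - 1) * (dilation - 1)

assert effective_kernel(3, 2) == 5   # 3x3 conv, dilation 2 -> 5x5 receptive field
assert effective_kernel(5, 2) == 9   # 5x5 conv, dilation 2 -> 9x9 receptive field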
......@@ -138,6 +139,7 @@ class SepConv(nn.Module):
""" Depthwise separable conv
DilConv(dilation=1) * 2
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__()
self.net = nn.Sequential(
......@@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise(stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super().__init__()
self.relu = nn.ReLU()
......
......@@ -94,7 +94,8 @@ class DartsTrainer(Trainer):
with torch.no_grad():
for step, (X, y) in enumerate(self.valid_loader):
X, y = X.to(self.device), y.to(self.device)
logits = self.model(X)
with self.mutator.forward_pass():
logits = self.model(X)
metrics = self.metrics(logits, y)
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
......
......@@ -40,7 +40,7 @@ class EnasMutator(PyTorchMutator):
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
self.cross_entropy_loss = nn.CrossEntropyLoss()
def after_build(self, model):
......@@ -79,7 +79,7 @@ class EnasMutator(PyTorchMutator):
self._lstm_next_step()
logit = self.soft(self._h[-1])
if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, branch_id)
self.sample_log_prob += log_prob
......
from torch import nn as nn
class RankedModule(nn.Module):
def __init__(self, rank=None, reduction=False):
super(RankedModule, self).__init__()
self.rank = rank
self.reduction = reduction
......@@ -56,9 +56,6 @@ class PyTorchMutable(nn.Module):
"Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
"so that trainer can locate all your mutables. See NNI docs for more details.".format(self))
def __repr__(self):
return "{} ({})".format(self.name, self.key)
class MutableScope(PyTorchMutable):
"""
......@@ -85,6 +82,9 @@ class LayerChoice(PyTorchMutable):
self.reduction = reduction
self.return_mask = return_mask
def __len__(self):
return self.length
def forward(self, *inputs):
out, mask = self.mutator.on_forward(self, *inputs)
if self.return_mask:
......@@ -116,4 +116,4 @@ class InputChoice(PyTorchMutable):
def similar(self, other):
        return type(self) == type(other) and \
            self.n_candidates == other.n_candidates and self.n_selected == other.n_selected
from .trainer import PdartsTrainer
import copy
import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F
from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice
class PdartsMutator(DartsMutator):
def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
self.pdarts_epoch_index = pdarts_epoch_index
self.pdarts_num_to_drop = pdarts_num_to_drop
self.switches = switches
super(PdartsMutator, self).__init__(model)
def before_build(self, model):
self.choices = nn.ParameterDict()
if self.switches is None:
self.switches = {}
def named_mutables(self, model):
key2module = dict()
for name, module in model.named_modules():
if isinstance(module, LayerChoice):
key2module[module.key] = module
yield name, module, True
def drop_paths(self):
for key in self.switches:
prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
switches = self.switches[key]
idxs = []
for j in range(len(switches)):
if switches[j]:
idxs.append(j)
if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
# for the last stage, drop all Zero operations
drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
else:
drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
for idx in drop:
switches[idxs[idx]] = False
return self.switches
def on_init_layer_choice(self, mutable: LayerChoice):
switches = self.switches.get(
mutable.key, [True for j in range(mutable.length)])
        for index in range(len(switches) - 1, -1, -1):
            if not switches[index]:
                del mutable.choices[index]
                mutable.length -= 1
self.switches[mutable.key] = switches
self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))
def on_calc_layer_choice_mask(self, mutable: LayerChoice):
return F.softmax(self.choices[mutable.key], dim=-1)
    def get_min_k(self, input_in, k):
        # Work on a copy and mask out each chosen entry, so repeated argmin calls
        # return k distinct minima instead of the same index k times.
        w = copy.deepcopy(input_in)
        index = []
        for _ in range(k):
            idx = np.argmin(w)
            index.append(idx)
            w[idx] = 1
        return index
def get_min_k_no_zero(self, w_in, idxs, k):
w = copy.deepcopy(w_in)
index = []
if 0 in idxs:
zf = True
else:
zf = False
if zf:
w = w[1:]
index.append(0)
k = k - 1
for _ in range(k):
idx = np.argmin(w)
w[idx] = 1
if zf:
idx = idx + 1
index.append(idx)
return index
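A small worked example of the two drop helpers (not part of this diff; the probabilities are made up, and the expected indices assume the Zero/'none' op sits at index 0, as the comment in drop_paths suggests):

import numpy as np

prob = np.array([0.30, 0.05, 0.10, 0.02, 0.20, 0.08, 0.15, 0.10])
idxs = list(range(len(prob)))   # all eight candidate ops still switched on

# get_min_k(prob, 3)               -> [3, 1, 5]   (the three smallest probabilities)
# get_min_k_no_zero(prob, idxs, 3) -> [0, 3, 1]   (always drops index 0, then the two smallest others)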
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import Trainer
from .mutator import PdartsMutator
class PdartsTrainer(Trainer):
def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid,
layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
self.model_creator = model_creator
self.layers = layers
self.n_nodes = n_nodes
self.pdarts_num_layers = pdarts_num_layers
self.pdarts_num_to_drop = pdarts_num_to_drop
self.pdarts_epoch = len(pdarts_num_to_drop)
self.darts_parameters = {
"metrics": metrics,
"num_epochs": num_epochs,
"dataset_train": dataset_train,
"dataset_valid": dataset_valid,
"batch_size": batch_size,
"workers": workers,
"device": device,
"log_frequency": log_frequency
}
def train(self):
layers = self.layers
n_nodes = self.n_nodes
switches = None
for epoch in range(self.pdarts_epoch):
layers = self.layers+self.pdarts_num_layers[epoch]
model, loss, model_optim, lr_scheduler = self.model_creator(
layers, n_nodes)
mutator = PdartsMutator(
model, epoch, self.pdarts_num_to_drop, switches)
self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
lr_scheduler=lr_scheduler, mutator=mutator, **self.darts_parameters)
print("start pdrats training %s..." % epoch)
self.trainer.train()
# with open('log/parameters_%d.txt' % epoch, "w") as f:
# f.write(str(model.parameters))
switches = mutator.drop_paths()
def export(self):
if (self.trainer is not None) and hasattr(self.trainer, "export"):
self.trainer.export()