Commit 3ddab980 authored by quzha

Merge branch 'dev-nas-refactor' of github.com:Microsoft/nni into dev-nas-refactor

parents 594924a9 d1d10de7
@@ -80,6 +80,7 @@ venv.bak/
# VSCode
.vscode
+.vs
# In case you place source code in ~/nni/
/experiments
from argparse import ArgumentParser
-import datasets
import torch
import torch.nn as nn
-from model import SearchCNN
-from nni.nas.pytorch.darts import DartsTrainer
+import datasets
+from nni.nas.pytorch.darts import CnnNetwork, DartsTrainer
from utils import accuracy

if __name__ == "__main__":
    parser = ArgumentParser("darts")
-    parser.add_argument("--layers", default=4, type=int)
-    parser.add_argument("--nodes", default=2, type=int)
+    parser.add_argument("--layers", default=5, type=int)
+    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
-    model = SearchCNN(3, 16, 10, args.layers, n_nodes=args.nodes)
+    model = CnnNetwork(3, 16, 10, args.layers, n_nodes=args.nodes)
    criterion = nn.CrossEntropyLoss()
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
...
from torchvision import transforms
from torchvision.datasets import CIFAR10


def get_dataset(cls):
    MEAN = [0.49139968, 0.48215827, 0.44653124]
    STD = [0.24703233, 0.24348505, 0.26158768]
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize(MEAN, STD)
    ]
    train_transform = transforms.Compose(transf + normalize)
    valid_transform = transforms.Compose(normalize)

    if cls == "cifar10":
        dataset_train = CIFAR10(root="./data", train=True, download=True, transform=train_transform)
        dataset_valid = CIFAR10(root="./data", train=False, download=True, transform=valid_transform)
    else:
        raise NotImplementedError
    return dataset_train, dataset_valid
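A minimal usage sketch for this helper. Wrapping the splits in DataLoaders as below is only an illustration and not part of this commit (DartsTrainer and PdartsTrainer take the datasets directly):

import torch
import datasets

train_set, valid_set = datasets.get_dataset("cifar10")
# Hypothetical DataLoader wrapping, e.g. for a quick sanity check of the pipeline.
train_loader = torch.utils.data.DataLoader(train_set, batch_size=128, shuffle=True, num_workers=4)
images, labels = next(iter(train_loader))
print(images.shape, labels.shape)  # torch.Size([128, 3, 32, 32]) torch.Size([128])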
from argparse import ArgumentParser

import datasets
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.pdarts import PdartsTrainer
from nni.nas.pytorch.darts import CnnNetwork, CnnCell


def accuracy(output, target, topk=(1,)):
    """ Computes the precision@k for the specified values of k """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    # one-hot case
    if target.ndimension() > 1:
        target = target.max(1)[1]

    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = dict()
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0)
        res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
    return res
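A quick sanity check of accuracy() on made-up tensors; the logits and labels below are illustrative assumptions, not part of the commit:

import torch

logits = torch.tensor([[0.1, 0.7, 0.2],
                       [0.9, 0.05, 0.05]])  # batch of 2 samples, 3 classes
labels = torch.tensor([1, 2])               # first prediction correct at top-1, second only at top-2
print(accuracy(logits, labels, topk=(1, 2)))  # expected: {'acc1': 0.5, 'acc2': 1.0}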
if __name__ == "__main__":
    parser = ArgumentParser("darts")
    parser.add_argument("--layers", default=5, type=int)
    parser.add_argument('--add_layers', action='append',
                        default=[0, 6, 12], help='add layers')
    parser.add_argument("--nodes", default=4, type=int)
    parser.add_argument("--batch-size", default=128, type=int)
    parser.add_argument("--log-frequency", default=1, type=int)
    args = parser.parse_args()
    dataset_train, dataset_valid = datasets.get_dataset("cifar10")

    def model_creator(layers, n_nodes):
        model = CnnNetwork(3, 16, 10, layers, n_nodes=n_nodes, cell_type=CnnCell)
        loss = nn.CrossEntropyLoss()

        model_optim = torch.optim.SGD(model.parameters(), 0.025,
                                      momentum=0.9, weight_decay=3.0E-4)
        n_epochs = 50
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(model_optim, n_epochs, eta_min=0.001)
        return model, loss, model_optim, lr_scheduler

    trainer = PdartsTrainer(model_creator,
                            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
                            num_epochs=50,
                            pdarts_num_layers=[0, 6, 12],
                            pdarts_num_to_drop=[3, 2, 2],
                            dataset_train=dataset_train,
                            dataset_valid=dataset_valid,
                            layers=args.layers,
                            n_nodes=args.nodes,
                            batch_size=args.batch_size,
                            log_frequency=args.log_frequency)
    trainer.train()
    trainer.export()
from .mutator import DartsMutator
from .trainer import DartsTrainer
from .cnn_cell import CnnCell
from .cnn_network import CnnNetwork
import torch
import torch.nn as nn

import nni.nas.pytorch as nas
from nni.nas.pytorch.modules import RankedModule

from .cnn_ops import OPS, PRIMITIVES, FactorizedReduce, StdConv


class CnnCell(RankedModule):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialize a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in the current DAG.
        channels_pp: int
            Number of output channels from the cell before the previous cell.
        channels_p: int
            Number of output channels from the previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is a reduction cell.
        reduction: bool
            Flag for whether the current cell is a reduction cell.
        """
        super(CnnCell, self).__init__(rank=1, reduction=reduction)
        self.n_nodes = n_nodes

        # If the previous cell is a reduction cell, the current input size does not match the
        # output size of cell[k-2], so the output of cell[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # stride 2 (reduction) should be used only for the input nodes
                stride = 2 if reduction and i < 2 else 1
                m_ops = []
                for primitive in PRIMITIVES:
                    op = OPS[primitive](channels, stride, False)
                    m_ops.append(op)
                op = nas.mutables.LayerChoice(m_ops, key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of the cell before the previous cell and of the previous cell.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
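A note on the channel arithmetic this forward pass implies. The short sketch below is only illustrative; it restates what the code above and the `c_cur_out = c_cur * n_nodes` bookkeeping in CnnNetwork already do, with example sizes that are assumptions:

# The cell keeps two preprocessed input tensors and appends one tensor per node,
# then concatenates only the n_nodes intermediate tensors along the channel axis.
n_nodes, channels = 4, 16          # example values matching the search defaults
out_channels = n_nodes * channels  # 64; mirrors c_cur_out = c_cur * n_nodes in CnnNetwork
print(out_channels)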
import torch
import torch.nn as nn

import ops
from nni.nas import pytorch as nas


class SearchCell(nn.Module):
    """
    Cell for search.
    """

    def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
        """
        Initialization a search cell.

        Parameters
        ----------
        n_nodes: int
            Number of nodes in current DAG.
        channels_pp: int
            Number of output channels from previous previous cell.
        channels_p: int
            Number of output channels from previous cell.
        channels: int
            Number of channels that will be used in the current DAG.
        reduction_p: bool
            Flag for whether the previous cell is reduction cell or not.
        reduction: bool
            Flag for whether the current cell is reduction cell or not.
        """
        super().__init__()
        self.reduction = reduction
        self.n_nodes = n_nodes

        # If previous cell is reduction cell, current input size does not match with
        # output size of cell[k-2]. So the output[k-2] should be reduced by preprocessing.
        if reduction_p:
            self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
        else:
            self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
        self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)

        # generate dag
        self.mutable_ops = nn.ModuleList()
        for depth in range(self.n_nodes):
            self.mutable_ops.append(nn.ModuleList())
            for i in range(2 + depth):  # include 2 input nodes
                # reduction should be used only for input node
                stride = 2 if reduction and i < 2 else 1
                op = nas.mutables.LayerChoice([ops.PoolBN('max', channels, 3, stride, 1, affine=False),
                                               ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
                                               ops.Identity() if stride == 1 else
                                               ops.FactorizedReduce(channels, channels, affine=False),
                                               ops.SepConv(channels, channels, 3, stride, 1, affine=False),
                                               ops.SepConv(channels, channels, 5, stride, 2, affine=False),
                                               ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
                                               ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False),
                                               ops.Zero(stride)],
                                              key="r{}_d{}_i{}".format(reduction, depth, i))
                self.mutable_ops[depth].append(op)

    def forward(self, s0, s1):
        # s0, s1 are the outputs of previous previous cell and previous cell, respectively.
        tensors = [self.preproc0(s0), self.preproc1(s1)]
        for ops in self.mutable_ops:
            assert len(ops) == len(tensors)
            cur_tensor = sum(op(tensor) for op, tensor in zip(ops, tensors))
            tensors.append(cur_tensor)

        output = torch.cat(tensors[2:], dim=1)
        return output
+from .cnn_cell import CnnCell


-class SearchCNN(nn.Module):
+class CnnNetwork(nn.Module):
    """
    Search CNN model
    """

-    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3):
+    def __init__(self, in_channels, channels, n_classes, n_layers, n_nodes=4, stem_multiplier=3, cell_type=CnnCell):
        """
        Initializing a search CNN.
@@ -121,7 +53,7 @@ class SearchCNN(nn.Module):
                c_cur *= 2
                reduction = True
-            cell = SearchCell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
+            cell = cell_type(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
            self.cells.append(cell)
            c_cur_out = c_cur * n_nodes
            channels_pp, channels_p = channels_p, c_cur_out
...
import torch
import torch.nn as nn

PRIMITIVES = [
    'none',
    'max_pool_3x3',
    'avg_pool_3x3',
    'skip_connect',  # identity
    'sep_conv_3x3',
    'sep_conv_5x5',
    'dil_conv_3x3',
    'dil_conv_5x5'
]

OPS = {
    'none': lambda C, stride, affine: Zero(stride),
    'avg_pool_3x3': lambda C, stride, affine: PoolBN('avg', C, 3, stride, 1, affine=affine),
    'max_pool_3x3': lambda C, stride, affine: PoolBN('max', C, 3, stride, 1, affine=affine),
    'skip_connect': lambda C, stride, affine: Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
    'sep_conv_3x3': lambda C, stride, affine: SepConv(C, C, 3, stride, 1, affine=affine),
    'sep_conv_5x5': lambda C, stride, affine: SepConv(C, C, 5, stride, 2, affine=affine),
    'sep_conv_7x7': lambda C, stride, affine: SepConv(C, C, 7, stride, 3, affine=affine),
    'dil_conv_3x3': lambda C, stride, affine: DilConv(C, C, 3, stride, 2, 2, affine=affine),  # 5x5
    'dil_conv_5x5': lambda C, stride, affine: DilConv(C, C, 5, stride, 4, 2, affine=affine),  # 9x9
    'conv_7x1_1x7': lambda C, stride, affine: FacConv(C, C, 7, stride, 3, affine=affine)
}
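A small usage sketch for the OPS registry; it assumes the surrounding module (torch and the op classes) is importable, and the channel count, stride, and input tensor are made-up values for illustration only:

# Look up a candidate operation by name and apply it, as the search cells do internally.
op = OPS['sep_conv_3x3'](16, 1, False)   # 16 channels, stride 1, affine=False
x = torch.randn(2, 16, 32, 32)
print(op(x).shape)  # stride 1 with padding 1 keeps the spatial size: torch.Size([2, 16, 32, 32])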
@@ -60,6 +58,7 @@ class PoolBN(nn.Module):
    """
    AvgPool or MaxPool - BN
    """

    def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
        """
        Args:
@@ -85,6 +84,7 @@ class StdConv(nn.Module):
    """ Standard conv
    ReLU - Conv - BN
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
@@ -101,6 +101,7 @@ class FacConv(nn.Module):
    """ Factorized conv
    ReLU - Conv(Kx1) - Conv(1xK) - BN
    """

    def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
@@ -118,14 +119,14 @@ class DilConv(nn.Module):
    """ (Dilated) depthwise separable conv
    ReLU - (Dilated) depthwise separable - Pointwise - BN
    If dilation == 2, 3x3 conv => 5x5 receptive field
                      5x5 conv => 9x9 receptive field
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
        super().__init__()
        self.net = nn.Sequential(
            nn.ReLU(),
-            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
-                      bias=False),
+            nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine)
        )
@@ -138,6 +139,7 @@ class SepConv(nn.Module):
    """ Depthwise separable conv
    DilConv(dilation=1) * 2
    """

    def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
        super().__init__()
        self.net = nn.Sequential(
@@ -172,6 +174,7 @@ class FactorizedReduce(nn.Module):
    """
    Reduce feature map size by factorized pointwise(stride=2).
    """

    def __init__(self, C_in, C_out, affine=True):
        super().__init__()
        self.relu = nn.ReLU()
...
@@ -94,7 +94,8 @@ class DartsTrainer(Trainer):
        with torch.no_grad():
            for step, (X, y) in enumerate(self.valid_loader):
                X, y = X.to(self.device), y.to(self.device)
-                logits = self.model(X)
+                with self.mutator.forward_pass():
+                    logits = self.model(X)
                metrics = self.metrics(logits, y)
                meters.update(metrics)
                if self.log_frequency is not None and step % self.log_frequency == 0:
...
@@ -40,7 +40,7 @@ class EnasMutator(PyTorchMutator):
        self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
        self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
        self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
-        self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
+        self.skip_targets = nn.Parameter(torch.Tensor([1.0 - self.skip_target, self.skip_target]), requires_grad=False)
        self.cross_entropy_loss = nn.CrossEntropyLoss()

    def after_build(self, model):
@@ -79,7 +79,7 @@ class EnasMutator(PyTorchMutator):
        self._lstm_next_step()
        logit = self.soft(self._h[-1])
        if self.tanh_constant is not None:
            logit = self.tanh_constant * torch.tanh(logit)
        branch_id = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
        log_prob = self.cross_entropy_loss(logit, branch_id)
        self.sample_log_prob += log_prob
...
from torch import nn as nn


class RankedModule(nn.Module):
    def __init__(self, rank=None, reduction=False):
        super(RankedModule, self).__init__()
        self.rank = rank
        self.reduction = reduction
@@ -56,9 +56,6 @@ class PyTorchMutable(nn.Module):
                "Mutator not set for {}. Did you initialize a mutable on the fly in forward pass? Move to __init__"
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

-    def __repr__(self):
-        return "{} ({})".format(self.name, self.key)


class MutableScope(PyTorchMutable):
    """
@@ -85,6 +82,9 @@ class LayerChoice(PyTorchMutable):
        self.reduction = reduction
        self.return_mask = return_mask

+    def __len__(self):
+        return self.length

    def forward(self, *inputs):
        out, mask = self.mutator.on_forward(self, *inputs)
        if self.return_mask:
@@ -116,4 +116,4 @@ class InputChoice(PyTorchMutable):
    def similar(self, other):
        return type(self) == type(other) and \
            self.n_candidates == other.n_candidates and self.n_selected and other.n_selected
from .trainer import PdartsTrainer
import copy

import numpy as np
import torch
from torch import nn as nn
from torch.nn import functional as F

from nni.nas.pytorch.darts import DartsMutator
from nni.nas.pytorch.mutables import LayerChoice


class PdartsMutator(DartsMutator):

    def __init__(self, model, pdarts_epoch_index, pdarts_num_to_drop, switches=None):
        self.pdarts_epoch_index = pdarts_epoch_index
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.switches = switches

        super(PdartsMutator, self).__init__(model)

    def before_build(self, model):
        self.choices = nn.ParameterDict()
        if self.switches is None:
            self.switches = {}

    def named_mutables(self, model):
        key2module = dict()
        for name, module in model.named_modules():
            if isinstance(module, LayerChoice):
                key2module[module.key] = module
                yield name, module, True

    def drop_paths(self):
        for key in self.switches:
            prob = F.softmax(self.choices[key], dim=-1).data.cpu().numpy()
            switches = self.switches[key]
            idxs = []
            for j in range(len(switches)):
                if switches[j]:
                    idxs.append(j)
            if self.pdarts_epoch_index == len(self.pdarts_num_to_drop) - 1:
                # for the last stage, drop all Zero operations
                drop = self.get_min_k_no_zero(prob, idxs, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            else:
                drop = self.get_min_k(prob, self.pdarts_num_to_drop[self.pdarts_epoch_index])
            for idx in drop:
                switches[idxs[idx]] = False
        return self.switches

    def on_init_layer_choice(self, mutable: LayerChoice):
        switches = self.switches.get(
            mutable.key, [True for j in range(mutable.length)])
        for index in range(len(switches) - 1, -1, -1):
            if switches[index] == False:
                del mutable.choices[index]
                mutable.length -= 1
        self.switches[mutable.key] = switches

        self.choices[mutable.key] = nn.Parameter(1.0E-3 * torch.randn(mutable.length))

    def on_calc_layer_choice_mask(self, mutable: LayerChoice):
        return F.softmax(self.choices[mutable.key], dim=-1)

    def get_min_k(self, input_in, k):
        # Return the indices of the k smallest architecture weights.
        values = copy.deepcopy(input_in)
        index = []
        for _ in range(k):
            idx = np.argmin(values)
            index.append(idx)
            values[idx] = 1  # mask the selected entry so it is not picked again
        return index

    def get_min_k_no_zero(self, w_in, idxs, k):
        # Same as get_min_k, but always drops the Zero op (index 0) when it is still alive.
        w = copy.deepcopy(w_in)
        index = []
        if 0 in idxs:
            zf = True
        else:
            zf = False
        if zf:
            w = w[1:]
            index.append(0)
            k = k - 1
        for _ in range(k):
            idx = np.argmin(w)
            w[idx] = 1
            if zf:
                idx = idx + 1
            index.append(idx)
        return index
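To make the switch bookkeeping concrete, here is a small illustrative sketch of the pruning schedule on a single edge; the numbers simply restate the example configuration (8 candidate ops in PRIMITIVES, pdarts_num_to_drop=[3, 2, 2]) and add no behaviour of their own:

# Each edge starts with all 8 candidate ops enabled (switches all True).
num_ops = 8
pdarts_num_to_drop = [3, 2, 2]
remaining = num_ops
for stage, n_drop in enumerate(pdarts_num_to_drop):
    remaining -= n_drop
    print("after stage {}: {} candidate ops left per edge".format(stage, remaining))
# after stage 0: 5, after stage 1: 3, after stage 2: 1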
from nni.nas.pytorch.darts import DartsTrainer
from nni.nas.pytorch.trainer import Trainer

from .mutator import PdartsMutator


class PdartsTrainer(Trainer):

    def __init__(self, model_creator, metrics, num_epochs, dataset_train, dataset_valid,
                 layers=5, n_nodes=4, pdarts_num_layers=[0, 6, 12], pdarts_num_to_drop=[3, 2, 2],
                 mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
        self.model_creator = model_creator
        self.layers = layers
        self.n_nodes = n_nodes
        self.pdarts_num_layers = pdarts_num_layers
        self.pdarts_num_to_drop = pdarts_num_to_drop
        self.pdarts_epoch = len(pdarts_num_to_drop)
        self.trainer = None  # set by train(); export() checks this before delegating
        self.darts_parameters = {
            "metrics": metrics,
            "num_epochs": num_epochs,
            "dataset_train": dataset_train,
            "dataset_valid": dataset_valid,
            "batch_size": batch_size,
            "workers": workers,
            "device": device,
            "log_frequency": log_frequency
        }

    def train(self):
        layers = self.layers
        n_nodes = self.n_nodes
        switches = None
        for epoch in range(self.pdarts_epoch):
            layers = self.layers + self.pdarts_num_layers[epoch]
            model, loss, model_optim, lr_scheduler = self.model_creator(layers, n_nodes)
            mutator = PdartsMutator(model, epoch, self.pdarts_num_to_drop, switches)

            self.trainer = DartsTrainer(model, loss=loss, model_optim=model_optim,
                                        lr_scheduler=lr_scheduler, mutator=mutator,
                                        **self.darts_parameters)
            print("start pdarts training %s..." % epoch)

            self.trainer.train()

            # with open('log/parameters_%d.txt' % epoch, "w") as f:
            #     f.write(str(model.parameters))

            switches = mutator.drop_paths()

    def export(self):
        if (self.trainer is not None) and hasattr(self.trainer, "export"):
            self.trainer.export()
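As a worked example of the stage schedule train() runs with the example's defaults (layers=5, pdarts_num_layers=[0, 6, 12]); this is just the arithmetic implied by the loop above, not new behaviour:

layers = 5
pdarts_num_layers = [0, 6, 12]
for epoch, extra in enumerate(pdarts_num_layers):
    print("P-DARTS stage %s trains a network with %d layers" % (epoch, layers + extra))
# stage 0: 5 layers, stage 1: 11 layers, stage 2: 17 layers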