Unverified commit 165756cc, authored by Yuge Zhang, committed by GitHub

[Retiarii] Weight-sharing trainers (#3137)

* First commit for WS in Retiarii

* Refactor ENAS trainer

* Fix DARTS trainer

* Fix ENAS trainer

* Fix issues in DARTS and Proxyless

* Fix ProxylessNAS and Random trainer

* Refactor mask
parent 990364b7
@@ -11,9 +11,9 @@ import torch.nn as nn
import datasets
from model import CNN
from nni.nas.pytorch.callbacks import ArchitectureCheckpoint, LRSchedulerCallback
-from nni.algorithms.nas.pytorch.darts import DartsTrainer
from utils import accuracy

logger = logging.getLogger('nni')

if __name__ == "__main__":
@@ -25,6 +25,7 @@ if __name__ == "__main__":
    parser.add_argument("--channels", default=16, type=int)
    parser.add_argument("--unrolled", default=False, action="store_true")
    parser.add_argument("--visualization", default=False, action="store_true")
+    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
@@ -35,17 +36,35 @@ if __name__ == "__main__":
    optim = torch.optim.SGD(model.parameters(), 0.025, momentum=0.9, weight_decay=3.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, args.epochs, eta_min=0.001)

-    trainer = DartsTrainer(model,
-                           loss=criterion,
-                           metrics=lambda output, target: accuracy(output, target, topk=(1,)),
-                           optimizer=optim,
-                           num_epochs=args.epochs,
-                           dataset_train=dataset_train,
-                           dataset_valid=dataset_valid,
-                           batch_size=args.batch_size,
-                           log_frequency=args.log_frequency,
-                           unrolled=args.unrolled,
-                           callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
-    if args.visualization:
-        trainer.enable_visualization()
-    trainer.train()
+    if args.v1:
+        from nni.algorithms.nas.pytorch.darts import DartsTrainer
+        trainer = DartsTrainer(model,
+                               loss=criterion,
+                               metrics=lambda output, target: accuracy(output, target, topk=(1,)),
+                               optimizer=optim,
+                               num_epochs=args.epochs,
+                               dataset_train=dataset_train,
+                               dataset_valid=dataset_valid,
+                               batch_size=args.batch_size,
+                               log_frequency=args.log_frequency,
+                               unrolled=args.unrolled,
+                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")])
+        if args.visualization:
+            trainer.enable_visualization()
+        trainer.train()
+    else:
+        from nni.retiarii.trainer.pytorch import DartsTrainer
+        trainer = DartsTrainer(
+            model=model,
+            loss=criterion,
+            metrics=lambda output, target: accuracy(output, target, topk=(1,)),
+            optimizer=optim,
+            num_epochs=args.epochs,
+            dataset=dataset_train,
+            batch_size=args.batch_size,
+            log_frequency=args.log_frequency,
+            unrolled=args.unrolled
+        )
+        trainer.fit()
+        print('Final architecture:', trainer.export())
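
For reference, the new trainer's ``export()`` returns a plain dict mapping each choice key to the selected candidate(s), so the search result can be persisted with the standard library. A minimal sketch (the file name is arbitrary):

    import json
    with open('final_architecture.json', 'w') as f:
        json.dump(trainer.export(), f)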
@@ -48,9 +48,15 @@ class Cell(nn.Module):
        ], key=cell_name + "_op")

    def forward(self, prev_layers):
-        chosen_input, chosen_mask = self.input_choice(prev_layers)
-        cell_out = self.op_choice(chosen_input)
-        return cell_out, chosen_mask
+        from nni.retiarii.trainer.pytorch.random import PathSamplingInputChoice
+        out = self.input_choice(prev_layers)
+        if isinstance(self.input_choice, PathSamplingInputChoice):
+            # Retiarii pattern
+            return out, self.input_choice.mask
+        else:
+            chosen_input, chosen_mask = out
+            cell_out = self.op_choice(chosen_input)
+            return cell_out, chosen_mask


class Node(mutables.MutableScope):
@@ -71,7 +77,7 @@ class Calibration(nn.Module):
        self.process = None
        if in_channels != out_channels:
            self.process = StdConv(in_channels, out_channels)

    def forward(self, x):
        if self.process is None:
            return x
@@ -83,7 +89,7 @@ class ReductionLayer(nn.Module):
        super().__init__()
        self.reduce0 = FactorizedReduce(in_channels_pp, out_channels, affine=False)
        self.reduce1 = FactorizedReduce(in_channels_p, out_channels, affine=False)

    def forward(self, pprev, prev):
        return self.reduce0(pprev), self.reduce1(prev)
@@ -109,7 +115,7 @@ class ENASLayer(nn.Module):
        nn.init.kaiming_normal_(self.final_conv_w)

    def forward(self, pprev, prev):
        pprev_, prev_ = self.preproc0(pprev), self.preproc1(prev)

        prev_nodes_out = [pprev_, prev_]
        nodes_used_mask = torch.zeros(self.num_nodes + 2, dtype=torch.bool, device=prev.device)
......
@@ -26,17 +26,22 @@ if __name__ == "__main__":
    parser.add_argument("--search-for", choices=["macro", "micro"], default="macro")
    parser.add_argument("--epochs", default=None, type=int, help="Number of epochs (default: macro 310, micro 150)")
    parser.add_argument("--visualization", default=False, action="store_true")
+    parser.add_argument("--v1", default=False, action="store_true")
    args = parser.parse_args()

    dataset_train, dataset_valid = datasets.get_dataset("cifar10")
+    mutator = None
+    ctrl_kwargs = {}
    if args.search_for == "macro":
        model = GeneralNetwork()
        num_epochs = args.epochs or 310
-        mutator = None
    elif args.search_for == "micro":
-        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=True)
+        model = MicroNetwork(num_layers=6, out_channels=20, num_nodes=5, dropout_rate=0.1, use_aux_heads=False)
        num_epochs = args.epochs or 150
-        mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
+        if args.v1:
+            mutator = enas.EnasMutator(model, tanh_constant=1.1, cell_exit_extra_step=True)
+        else:
+            ctrl_kwargs = {"tanh_constant": 1.1}
    else:
        raise AssertionError
@@ -44,18 +49,32 @@ if __name__ == "__main__":
    optimizer = torch.optim.SGD(model.parameters(), 0.05, momentum=0.9, weight_decay=1.0E-4)
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=0.001)

-    trainer = enas.EnasTrainer(model,
-                               loss=criterion,
-                               metrics=accuracy,
-                               reward_function=reward_accuracy,
-                               optimizer=optimizer,
-                               callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
-                               batch_size=args.batch_size,
-                               num_epochs=num_epochs,
-                               dataset_train=dataset_train,
-                               dataset_valid=dataset_valid,
-                               log_frequency=args.log_frequency,
-                               mutator=mutator)
-    if args.visualization:
-        trainer.enable_visualization()
-    trainer.train()
+    if args.v1:
+        trainer = enas.EnasTrainer(model,
+                                   loss=criterion,
+                                   metrics=accuracy,
+                                   reward_function=reward_accuracy,
+                                   optimizer=optimizer,
+                                   callbacks=[LRSchedulerCallback(lr_scheduler), ArchitectureCheckpoint("./checkpoints")],
+                                   batch_size=args.batch_size,
+                                   num_epochs=num_epochs,
+                                   dataset_train=dataset_train,
+                                   dataset_valid=dataset_valid,
+                                   log_frequency=args.log_frequency,
+                                   mutator=mutator)
+        if args.visualization:
+            trainer.enable_visualization()
+        trainer.train()
+    else:
+        from nni.retiarii.trainer.pytorch.enas import EnasTrainer
+        trainer = EnasTrainer(model,
+                              loss=criterion,
+                              metrics=accuracy,
+                              reward_function=reward_accuracy,
+                              optimizer=optimizer,
+                              batch_size=args.batch_size,
+                              num_epochs=num_epochs,
+                              dataset=dataset_train,
+                              log_frequency=args.log_frequency,
+                              ctrl_kwargs=ctrl_kwargs)
+        trainer.fit()
@@ -6,7 +6,7 @@ import torchvision
import torchvision.transforms as transforms

from nni.nas.pytorch.mutables import LayerChoice, InputChoice
-from nni.nas.pytorch.darts import DartsTrainer
+from nni.algorithms.nas.pytorch.darts import DartsTrainer


class Net(nn.Module):
......
+import logging
import os
import sys
-import logging
from argparse import ArgumentParser

import torch
-import datasets
+from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
+from torchvision import transforms

-from putils import get_parameters
+import datasets
from model import SearchMobileNet
from nni.algorithms.nas.pytorch.proxylessnas import ProxylessNasTrainer
+from putils import LabelSmoothingLoss, accuracy, get_parameters
from retrain import Retrain

logger = logging.getLogger('nni_proxylessnas')
@@ -30,7 +33,7 @@ if __name__ == "__main__":
    parser.add_argument("--resize_scale", default=0.08, type=float)
    parser.add_argument("--distort_color", default='normal', type=str, choices=['normal', 'strong', 'None'])
    # configurations for training mode
-    parser.add_argument("--train_mode", default='search', type=str, choices=['search', 'retrain'])
+    parser.add_argument("--train_mode", default='search', type=str, choices=['search_v1', 'search', 'retrain'])
    # configurations for search
    parser.add_argument("--checkpoint_path", default='./search_mobile_net.pt', type=str)
    parser.add_argument("--arch_path", default='./arch_path.pt', type=str)
@@ -80,6 +83,26 @@ if __name__ == "__main__":
    optimizer = torch.optim.SGD(get_parameters(model), lr=0.05, momentum=momentum, nesterov=nesterov, weight_decay=4e-5)

    if args.train_mode == 'search':
+        from nni.retiarii.trainer.pytorch import ProxylessTrainer
+        from torchvision.datasets import ImageNet
+        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
+                                         std=[0.229, 0.224, 0.225])
+        dataset = ImageNet(args.data_path, transform=transforms.Compose([
+            transforms.RandomResizedCrop(224),
+            transforms.RandomHorizontalFlip(),
+            transforms.ToTensor(),
+            normalize,
+        ]))
+        trainer = ProxylessTrainer(model,
+                                   loss=LabelSmoothingLoss(),
+                                   dataset=dataset,
+                                   optimizer=optimizer,
+                                   metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
+                                   num_epochs=120,
+                                   log_frequency=10)
+        trainer.fit()
+        print('Final architecture:', trainer.export())
+    elif args.train_mode == 'search_v1':
        # this is architecture search
        logger.info('Creating ProxylessNasTrainer...')
        trainer = ProxylessNasTrainer(model,
......
@@ -58,11 +58,9 @@ class SearchMobileNet(nn.Module):
                    # if it is not the first one
                    op_candidates += [ops.OPS['Zero'](input_channel, width, stride)]
                    conv_op = nas.mutables.LayerChoice(op_candidates,
-                                                      return_mask=True,
                                                       key="s{}_c{}".format(stage_cnt, i))
                else:
                    conv_op = nas.mutables.LayerChoice(op_candidates,
-                                                      return_mask=True,
                                                       key="s{}_c{}".format(stage_cnt, i))
                # shortcut
                if stride == 1 and input_channel == width:
......
@@ -39,19 +39,13 @@ class MobileInvertedResidualBlock(nn.Module):
        self.op_candidates_list = op_candidates_list

    def forward(self, x):
-        out, idx = self.mobile_inverted_conv(x)
-        # TODO: unify idx format
-        if not isinstance(idx, int):
-            idx = (idx == 1).nonzero()
-        if self.op_candidates_list[idx].is_zero_layer():
-            res = x
-        elif self.shortcut is None:
-            res = out
-        else:
-            conv_x = out
-            skip_x = self.shortcut(x)
-            res = skip_x + conv_x
-        return res
+        out = self.mobile_inverted_conv(x)
+        if torch.sum(torch.abs(out)).item() == 0 and x.size() == out.size():
+            # is zero layer
+            return x
+        if self.shortcut is None:
+            return out
+        return out + self.shortcut(x)


class ShuffleLayer(nn.Module):
......
+import torch
import torch.nn as nn


def get_parameters(model, keys=None, mode='include'):
    if keys is None:
        for name, param in model.named_parameters():
@@ -36,6 +38,7 @@ def get_same_padding(kernel_size):
    assert kernel_size % 2 > 0, 'kernel size should be odd number'
    return kernel_size // 2


def build_activation(act_func, inplace=True):
    if act_func == 'relu':
        return nn.ReLU(inplace=inplace)
@@ -65,3 +68,40 @@ def make_divisible(v, divisor, min_val=None):
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
def accuracy(output, target, topk=(1,)):
""" Computes the precision@k for the specified values of k """
maxk = max(topk)
batch_size = target.size(0)
_, pred = output.topk(maxk, 1, True, True)
pred = pred.t()
# one-hot case
if target.ndimension() > 1:
target = target.max(1)[1]
correct = pred.eq(target.view(1, -1).expand_as(pred))
res = dict()
for k in topk:
correct_k = correct[:k].view(-1).float().sum(0)
res["acc{}".format(k)] = correct_k.mul_(1.0 / batch_size).item()
return res
class LabelSmoothingLoss(nn.Module):
def __init__(self, smoothing=0.1, dim=-1):
super(LabelSmoothingLoss, self).__init__()
self.confidence = 1.0 - smoothing
self.smoothing = smoothing
self.dim = dim
def forward(self, pred, target):
pred = pred.log_softmax(dim=self.dim)
num_classes = pred.size(self.dim)
with torch.no_grad():
true_dist = torch.zeros_like(pred)
true_dist.fill_(self.smoothing / (num_classes - 1))
true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
return torch.mean(torch.sum(-true_dist * pred, dim=self.dim))
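
A quick sanity check for the two helpers above — an illustrative sketch on random data, not part of the commit:

if __name__ == '__main__':
    logits = torch.randn(4, 10)
    target = torch.randint(0, 10, (4,))
    print(accuracy(logits, target, topk=(1, 5)))              # e.g. {'acc1': 0.25, 'acc5': 0.5}
    print(LabelSmoothingLoss(smoothing=0.1)(logits, target))  # scalar loss tensor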
@@ -109,7 +109,13 @@ class MutableScope(Mutable):
    def __init__(self, key):
        super().__init__(key=key)

+    def _check_built(self):
+        return True  # bypass the test because it's deprecated
+
    def __call__(self, *args, **kwargs):
+        if not hasattr(self, 'mutator'):
+            return super().__call__(*args, **kwargs)
+        warnings.warn("`MutableScope` is deprecated in Retiarii.", DeprecationWarning)
        try:
            self._check_built()
            self.mutator.enter_mutable_scope(self)
......
import abc
+from typing import *


class BaseTrainer(abc.ABC):

@@ -20,3 +21,15 @@ class BaseTrainer(abc.ABC):
    @abc.abstractmethod
    def fit(self) -> None:
        pass
class BaseOneShotTrainer(BaseTrainer):
"""
Build many (possibly all) architectures into one full graph, search (while training) and export the best.
It has an extra ``export`` function that exports an object representing the final searched architecture.
"""
@abc.abstractmethod
def export(self) -> Any:
pass
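
The concrete one-shot trainers added below (DARTS, ENAS, ProxylessNAS, random) all implement this contract. For illustration only, the smallest conforming subclass could look like:

class _ConstantTrainer(BaseOneShotTrainer):
    # Illustrative sketch (not part of the commit): trains nothing and
    # always exports a fixed, user-supplied architecture.
    def __init__(self, architecture):
        self._architecture = architecture

    def fit(self) -> None:
        pass  # nothing to train

    def export(self) -> Any:
        return self._architecture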
from .base import PyTorchImageClassificationTrainer
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer
@@ -9,7 +9,7 @@ from torchvision import datasets, transforms
import nni

-from .interface import BaseTrainer
+from ..interface import BaseTrainer


def get_default_transform(dataset: str) -> Any:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from nni.nas.pytorch.mutables import LayerChoice
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsLayerChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argmax(self.alpha).item()
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
def forward(self, inputs):
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsInputChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
class DartsTrainer(BaseOneShotTrainer):
"""
DARTS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
learning_rate : float
Learning rate to optimize the model.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
    ``True`` to use second-order (unrolled) optimization, ``False`` for first-order optimization.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, grad_clip=5.,
learning_rate=2.5E-3, batch_size=64, workers=4,
device=None, log_frequency=None,
arc_learning_rate=3.0E-4, unrolled=False):
self.model = model
self.loss = loss
self.metrics = metrics
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, DartsLayerChoice, self.nas_modules)
replace_input_choice(self.model, DartsInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model_optim = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
self.grad_clip = grad_clip
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.model_optim.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) # gradient clipping
self.model_optim.step()
metrics = self.metrics(logits, trn_y)
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Epoch [%s/%s] Step [%s/%s] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.model_optim.param_groups[0]["lr"]
momentum = self.model_optim.param_groups[0]["momentum"]
weight_decay = self.model_optim.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for c in self.nas_modules])
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.model_optim.state[w].get('momentum_buffer', 0.)
w -= lr * (momentum * m + g + weight_decay * w)  # in-place update; plain rebinding would be a no-op
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
_logger.warning('In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.', eps.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, [c.alpha for c in self.nas_modules]))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
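
An illustrative smoke test for the trainer above (not part of the commit): one epoch of DARTS search over a toy model with a single ``LayerChoice``, on synthetic data. ``ToyNet`` and the random dataset are placeholders.

if __name__ == '__main__':
    import torch.utils.data as data

    class ToyNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = LayerChoice([nn.Conv2d(3, 8, 3, padding=1),
                                     nn.Conv2d(3, 8, 5, padding=2)])
            self.head = nn.Linear(8, 10)

        def forward(self, x):
            return self.head(self.conv(x).mean([2, 3]))

    net = ToyNet()
    dataset = data.TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 10, (64,)))
    trainer = DartsTrainer(net,
                           loss=nn.CrossEntropyLoss(),
                           metrics=lambda out, y: {'acc': (out.argmax(1) == y).float().mean().item()},
                           optimizer=torch.optim.SGD(net.parameters(), 0.025),
                           num_epochs=1, dataset=dataset, batch_size=16)
    trainer.fit()
    print(trainer.export())  # one entry per choice key, e.g. {'<auto-generated key>': 0}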
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
_logger = logging.getLogger(__name__)
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# the current implementation only supports a batch size of 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
A field with a ``name`` and ``total`` choices. ``choose_one`` is true if exactly one choice is meant to be
selected; otherwise, any number of choices can be chosen.
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that a skip connection will appear.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]),
requires_grad=False) # pylint: disable=not-callable
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
sampled = sampled.detach().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled
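
As a standalone check, the controller above can sample an architecture without any model attached — an illustrative sketch with made-up field names:

fields = [ReinforceField('op_1', total=3, choose_one=True),     # pick exactly one of 3 ops
          ReinforceField('skip_2', total=4, choose_one=False)]  # pick any subset of 4 inputs
controller = ReinforceController(fields)
print(controller.resample())  # e.g. {'op_1': 2, 'skip_2': [0, 3]}
# sample_log_prob / sample_entropy accumulate on the controller for the RL update below.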
class EnasTrainer(BaseOneShotTrainer):
"""
ENAS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as the reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_steps : int
Number of mini-batches for each epoch of RL controller learning.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
"""
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset,
batch_size=64, workers=4, device=None, log_frequency=None,
grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model.to(self.device)
self.nas_fields = [ReinforceField(name, len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
for name, module in self.nas_modules]
self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
self.grad_clip = grad_clip
self.reward_function = reward_function
self.ctrl_optim = optim.Adam(self.controller.parameters(), lr=ctrl_lr)
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_model(self, epoch):
self.model.train()
self.controller.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
self._resample()
logits = self.model(x)
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.optimizer.step()
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Model Epoch [%d/%d] Step [%d/%d] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
meters = AverageMeterGroup()
self.ctrl_optim.zero_grad()
for ctrl_step, (x, y) in enumerate(self.valid_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self._resample()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
metrics['reward'] = reward
metrics['loss'] = loss.item()
metrics['ent'] = self.controller.sample_entropy.item()
metrics['log_prob'] = self.controller.sample_log_prob.item()
metrics['baseline'] = self.baseline
metrics['skip'] = self.controller.sample_skip_penalty
loss /= self.ctrl_steps_aggregate
loss.backward()
meters.update(metrics)
if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
self.ctrl_optim.step()
self.ctrl_optim.zero_grad()
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
ctrl_step + 1, len(self.valid_loader), meters)
def _resample(self):
result = self.controller.resample()
for name, module in self.nas_modules:
module.sampled = result[name]
def fit(self):
for i in range(self.num_epochs):
self._train_model(i)
self._train_controller(i)
def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()
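
End-to-end usage mirrors the updated ENAS search script earlier in this diff; in short (names as in that example):

trainer = EnasTrainer(model,
                      loss=criterion,
                      metrics=accuracy,
                      reward_function=reward_accuracy,
                      optimizer=optimizer,
                      batch_size=args.batch_size,
                      num_epochs=num_epochs,
                      dataset=dataset_train,
                      log_frequency=args.log_frequency,
                      ctrl_kwargs={'tanh_constant': 1.1})
trainer.fit()
best = trainer.export()  # one more controller sample, taken in eval mode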
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class ProxylessLayerChoice(nn.Module):
def __init__(self, ops):
super(ProxylessLayerChoice, self).__init__()
self.ops = nn.ModuleList(ops)
self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self.sampled = None
def forward(self, *args):
def run_function(ops, active_id):
def forward(_x):
return ops[active_id](_x)
return forward
def backward_function(ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(ops)):
if k != active_id:
out_k = ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
assert len(args) == 1
x = args[0]
return ArchGradientFunction.apply(
x, self._binary_gates, run_function(self.ops, self.sampled),
backward_function(self.ops, self.sampled, self._binary_gates)
)
def resample(self):
probs = F.softmax(self.alpha, dim=-1)
sample = torch.multinomial(probs, 1)[0].item()
self.sampled = sample
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[sample] = 1.0
def finalize_grad(self):
binary_grads = self._binary_gates.grad
with torch.no_grad():
if self.alpha.grad is None:
self.alpha.grad = torch.zeros_like(self.alpha.data)
probs = F.softmax(self.alpha, dim=-1)
for i in range(len(self.ops)):
for j in range(len(self.ops)):
self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
def export(self):
return torch.argmax(self.alpha).item()
class ProxylessInputChoice(nn.Module):
def __init__(self, *args, **kwargs):
raise NotImplementedError('Input choice is not supported for ProxylessNAS.')
class ProxylessTrainer(BaseOneShotTrainer):
"""
Proxyless trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
warmup_epochs : int
Number of epochs to warmup model parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
arc_learning_rate : float
Learning rate of architecture parameters.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, warmup_epochs=0,
batch_size=64, workers=4, device=None, log_frequency=None,
arc_learning_rate=1.0E-3):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.warmup_epochs = warmup_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, ProxylessLayerChoice, self.nas_modules)
replace_input_choice(self.model, ProxylessInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
weight_decay=0, betas=(0, 0.999), eps=1e-8)
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
if epoch >= self.warmup_epochs:
# 1) train architecture parameters
for _, module in self.nas_modules:
module.resample()
self.ctrl_optim.zero_grad()
logits, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
for _, module in self.nas_modules:
module.finalize_grad()
self.ctrl_optim.step()
# 2) train model parameters
for _, module in self.nas_modules:
module.resample()
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
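
For reference, ``finalize_grad`` above implements the ProxylessNAS architecture-gradient estimate dL/dalpha_i = sum_j g_j * p_j * (1[i == j] - p_i), where g are the binary-gate gradients and p = softmax(alpha). Usage mirrors the updated proxylessnas example earlier in this diff:

trainer = ProxylessTrainer(model,
                           loss=LabelSmoothingLoss(),
                           dataset=dataset,
                           optimizer=optimizer,
                           metrics=lambda output, target: accuracy(output, target, topk=(1, 5,)),
                           num_epochs=120, log_frequency=10)
trainer.fit()
print('Final architecture:', trainer.export())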
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import random
import torch
import torch.nn as nn
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
return torch.tensor(multihot, dtype=torch.bool)
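
A quick illustration of the helper above:

# _get_mask(1, 3)       -> tensor([False,  True, False])
# _get_mask([0, 2], 3)  -> tensor([ True, False,  True])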
class PathSamplingLayerChoice(nn.Module):
"""
Mixed module, whose forward pass is computed by exactly one, or multiple, sampled inner modules.
If multiple modules are selected, their results are summed and returned.
Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""
def __init__(self, layer_choice):
super(PathSamplingLayerChoice, self).__init__()
self.op_names = []
for name, module in layer_choice.named_children():
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.sampled = None # sampled can be either a list of indices or an index
def forward(self, *args, **kwargs):
assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
if isinstance(self.sampled, list):
return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled])
else:
return getattr(self, self.op_names[self.sampled])(*args, **kwargs)
def __len__(self):
return len(self.op_names)
@property
def mask(self):
return _get_mask(self.sampled, len(self))
class PathSamplingInputChoice(nn.Module):
"""
Mixed input. Takes a list of tensors as input, selects some of them and returns the sum.
Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""
def __init__(self, input_choice):
super(PathSamplingInputChoice, self).__init__()
self.n_candidates = input_choice.n_candidates
self.n_chosen = input_choice.n_chosen
self.sampled = None
def forward(self, input_tensors):
if isinstance(self.sampled, list):
return sum([input_tensors[t] for t in self.sampled])
else:
return input_tensors[self.sampled]
def __len__(self):
return self.n_candidates
@property
def mask(self):
return _get_mask(self.sampled, len(self))
class SinglePathTrainer(BaseOneShotTrainer):
"""
Single-path trainer. Samples a path every time and backpropagates on that path.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : Dataset
Dataset of training.
dataset_valid : Dataset
Dataset of validation.
batch_size : int
Batch size.
workers : int
    Number of threads for data preprocessing. Not used by this trainer; may be removed in the future.
device : torch.device
    Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
    automatically detects and prefers GPU.
log_frequency : int
Number of mini-batches to log metrics.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def _resample(self):
result = {}
for name, module in self.nas_modules:
if name not in result:
result[name] = random.randint(0, len(module) - 1)
module.sampled = result[name]
return result
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self._resample()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self._resample()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
self._validate_one_epoch(i)
def export(self):
return self._resample()
RandomTrainer = SinglePathTrainer
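
A minimal usage sketch (model and datasets are placeholders; the model only needs LayerChoice/InputChoice mutables):

trainer = SinglePathTrainer(model,
                            loss=nn.CrossEntropyLoss(),
                            metrics=lambda logits, y: {'acc': (logits.argmax(1) == y).float().mean().item()},
                            optimizer=torch.optim.SGD(model.parameters(), 0.05),
                            num_epochs=10,
                            dataset_train=train_set,
                            dataset_valid=valid_set)
trainer.fit()
print(trainer.export())  # note: export() draws a fresh random sample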
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
import numpy as np
import torch
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__)
def to_device(obj, device):
"""
Move a tensor, tuple, list, or dict onto device.
"""
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, tuple):
return tuple(to_device(t, device) for t in obj)
if isinstance(obj, list):
return [to_device(t, device) for t in obj]
if isinstance(obj, dict):
return {k: to_device(v, device) for k, v in obj.items()}
if isinstance(obj, (int, float, str)):
return obj
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
def to_list(arr):
if torch.is_tensor(arr):
return arr.cpu().numpy().tolist()
if isinstance(arr, np.ndarray):
return arr.tolist()
if isinstance(arr, (list, tuple)):
return list(arr)
return arr
class AverageMeterGroup:
"""
Average meter group for multiple average meters.
"""
def __init__(self):
self.meters = OrderedDict()
def update(self, data):
"""
Update the meter group with a dict of metrics.
Average meters that do not yet exist will be created automatically.
"""
for k, v in data.items():
if k not in self.meters:
self.meters[k] = AverageMeter(k, ":4f")
self.meters[k].update(v)
def __getattr__(self, item):
return self.meters[item]
def __getitem__(self, item):
return self.meters[item]
def __str__(self):
return " ".join(str(v) for v in self.meters.values())
def summary(self):
"""
Return a summary string of group data.
"""
return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
"""
Computes and stores the average and current value.
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
"""
Reset the meter.
"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""
Update with value and weight.
Parameters
----------
val : float or int
The new value to be accounted in.
n : int
The weight of the new value.
"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = '{name}: {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
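
The two meter classes are used by every trainer's logging loop; a tiny illustrative example:

meters = AverageMeterGroup()
meters.update({'loss': 0.9, 'acc': 0.5})
meters.update({'loss': 0.7, 'acc': 0.6})
print(meters)             # loss 0.700000 (0.800000) acc 0.600000 (0.550000)
print(meters['acc'].avg)  # 0.55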
def _replace_module_with_type(root_module, init_fn, type, modules):
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
if isinstance(child, type):
setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name)))
else:
apply(child)
apply(root_module)
return modules
def replace_layer_choice(root_module, init_fn, modules=None):
"""
Replace layer choice modules with modules that are initialized with ``init_fn``.
Parameters
----------
root_module : nn.Module
Root module to traverse.
init_fn : Callable
Initializing function.
modules : list, optional
    If provided, the replaced modules are appended to this list.

Returns
-------
List[Tuple[str, nn.Module]]
    A list of (layer choice key, replaced module) tuples.
"""
return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
def replace_input_choice(root_module, init_fn, modules=None):
"""
Replace input choice modules with modules that are initialized with ``init_fn``.
Parameters
----------
root_module : nn.Module
Root module to traverse.
init_fn : Callable
Initializing function.
modules : list, optional
    If provided, the replaced modules are appended to this list.

Returns
-------
List[Tuple[str, nn.Module]]
    A list of (input choice key, replaced module) tuples.
"""
return _replace_module_with_type(root_module, init_fn, InputChoice, modules)
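
To make the replacement helpers concrete, an illustrative sketch (toy model; ``PathSamplingLayerChoice`` comes from random.py above):

import torch.nn as nn
from nni.retiarii.trainer.pytorch.random import PathSamplingLayerChoice

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.op = LayerChoice([nn.ReLU(), nn.Sigmoid()])

model = Toy()
replaced = replace_layer_choice(model, PathSamplingLayerChoice)
print(replaced)  # [('<auto-generated key>', PathSamplingLayerChoice(...))]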