Unverified commit 468917ca authored by QuanluZhang, committed by GitHub

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
from ..operation import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = nn.functional.relu({inputs[0]})'
class Flatten(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
class ToDevice(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
class Dense(PyTorchOperation):
def to_init_code(self, field):
return f"self.{field} = nn.Linear({self.parameters['in_features']}, {self.parameters['out_features']})"
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = self.{field}({inputs[0]})'
class Softmax(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = F.softmax({inputs[0]}, -1)'
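# A minimal sketch of how the operation classes above emit PyTorch code for a graph node.
# It assumes the base PyTorchOperation constructor accepts (type_name, parameters, _internal),
# as the Conv2D override above suggests; the parameter values are purely illustrative.
def _codegen_example():
    op = Dense('Dense', {'in_features': 128, 'out_features': 10}, True)
    init_line = op.to_init_code('fc')                   # "self.fc = nn.Linear(128, 10)"
    forward_line = op.to_forward_code('fc', 'h1', 'x')  # "h1 = self.fc(x)"
    return init_line, forward_line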
from .tpe_strategy import TPEStrategy
import abc
from typing import List
from ..graph import Model
from ..mutator import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
import logging
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, wait_models
from .strategy import BaseStrategy
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample = None
self.index = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
search_space = {}
for i, each in enumerate(sample_space):
search_space[str(i)] = {'_type': 'choice', '_value': each}
self.tpe_tuner.update_search_space(search_space)
def generate_samples(self, model_id):
self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
self.total_parameters[model_id] = self.cur_sample
self.index = 0
def receive_result(self, model_id, result):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
class TPEStrategy(BaseStrategy):
def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
def run(self, base_model, applied_mutators):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
recorded_candidates, new_model = mutator.dry_run(new_model)
sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space)
try:
_logger.info('strategy start...')
while True:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
_logger.info('mutate model...')
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
wait_models(model)
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1
_logger.info('Strategy says: %s', model.metric)
except Exception:
_logger.exception('strategy failed')
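# A minimal sketch of how this strategy is expected to be driven. ``base_model`` (a graph
# Model) and ``mutators`` are assumed to be produced elsewhere by the Retiarii framework.
def _run_tpe_example(base_model, mutators):
    strategy = TPEStrategy()
    strategy.run(base_model, mutators)  # keeps submitting mutated models via submit_models()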
from .interface import BaseTrainer
from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
import abc
from typing import Any
class BaseTrainer(abc.ABC):
"""
In this version, we plan to write our own trainers instead of using PyTorch Lightning, to
ease the burden of integrating our optimization with PyTorch Lightning, a large part of which is
opaque to us.
We will try to align with PyTorch Lightning naming conventions so that we can easily migrate to
PyTorch Lightning in the future.
Currently, our trainer = LightningModule + LightningTrainer. We might want to separate these two
in the future.
Trainer has a ``fit`` function with no return value. Intermediate results and final results should be
directly sent via ``nni.report_intermediate_result()`` and ``nni.report_final_result()`` functions.
"""
@abc.abstractmethod
def fit(self) -> None:
pass
class BaseOneShotTrainer(BaseTrainer):
"""
Build many (possibly all) architectures into a full graph, search (with train) and export the best.
It has an extra ``export`` function that exports an object representing the final searched architecture.
"""
@abc.abstractmethod
def export(self) -> Any:
pass
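# A minimal concrete trainer, sketching the ``fit``-only contract described above.
class _NoopTrainer(BaseTrainer):
    def __init__(self, model):
        self.model = model

    def fit(self) -> None:
        # A real trainer would run its training loop here and report results through
        # nni.report_intermediate_result() / nni.report_final_result().
        pass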
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer
from typing import Any, List, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import nni
from ..interface import BaseTrainer
from ...utils import register_trainer
def get_default_transform(dataset: str) -> Any:
"""
Get the default image transform for a specific dataset.
This is needed because transform objects cannot be directly passed as arguments.
Parameters
----------
dataset : str
Dataset class name.
Returns
-------
transform object
"""
if dataset == 'MNIST':
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
if dataset == 'CIFAR10':
return transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
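# A small usage sketch: the returned transform is handed to a torchvision dataset, which is
# how the trainer below consumes it. The ``root`` / ``download`` arguments are illustrative.
def _transform_example():
    transform = get_default_transform('MNIST')
    return datasets.MNIST('data', train=True, download=True, transform=transform)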
@register_trainer()
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
A model, along with the corresponding dataset and optimizer config, is used to initialize the trainer.
The trainer will run for a fixed number of epochs (by default 10), and report the final result.
TODO
Support scheduler, validate every n epochs, train/valid dataset
Limitation induced by NNI: kwargs must be JSON-serializable so that they can be packed into trial parameters.
"""
def __init__(self, model,
dataset_cls='MNIST', dataset_kwargs=None, dataloader_kwargs=None,
optimizer_cls='SGD', optimizer_kwargs=None, trainer_kwargs=None):
"""Initialization of image classification trainer.
Parameters
----------
model : nn.Module
Model to train.
dataset_cls : str, optional
Dataset class name that is available in ``torchvision.datasets``, by default 'MNIST'
dataset_kwargs : dict, optional
Keyword arguments passed to initialization of dataset class, by default None
dataloader_kwargs : dict, optional
Keyword arguments passed to ``torch.utils.data.DataLoader``, by default None
optimizer_cls : str, optional
Optimizer class name that is available in ``torch.optim``, by default 'SGD'
optimizer_kwargs : dict, optional
Keyword arguments passed to initialization of optimizer class, by default None
trainer_kwargs: dict, optional
Keyword arguments passed to trainer. Will be passed to Trainer class in future. Currently,
only the key ``max_epochs`` is useful.
"""
super(PyTorchImageClassificationTrainer, self).__init__()
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss()
self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
x, y = batch
if self._use_cuda:
x, y = x.cuda(torch.device('cuda:0')), y.cuda(torch.device('cuda:0'))
return x, y
def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.validation_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.validation_step_after_model(x, y, y_hat)
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
x, y = batch
if self._use_cuda:
x, y = x.cuda(), y.cuda()
return x, y
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}
def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
# We might need dict metrics in future?
avg_acc = np.mean([x['val_acc'] for x in outputs]).item()
nni.report_intermediate_result(avg_acc)
return {'val_acc': avg_acc}
def _validate(self):
validation_outputs = []
for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs)
def _train(self):
for i, batch in enumerate(self._train_dataloader):
loss = self.training_step(batch, i)
loss.backward()
def fit(self) -> None:
for _ in range(self._trainer_kwargs['max_epochs']):
self._train()
# assuming val_acc here
nni.report_final_result(self._validate()['val_acc'])
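# A usage sketch for the trainer above: train a plain ``nn.Module`` on MNIST for one epoch.
# All keyword argument values below are illustrative, not defaults of this trainer.
def _image_classification_example():
    toy_model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))
    trainer = PyTorchImageClassificationTrainer(
        toy_model,
        dataset_kwargs={'root': 'data/mnist', 'download': True},
        dataloader_kwargs={'batch_size': 32, 'shuffle': True},
        optimizer_kwargs={'lr': 1e-3},
        trainer_kwargs={'max_epochs': 1})
    trainer.fit()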
class PyTorchMultiModelTrainer(BaseTrainer):
def __init__(self, multi_model, kwargs=None):
self.multi_model = multi_model
self.kwargs = kwargs or {}
self._train_dataloaders = []
self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = []
self._trainers = []
self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = self.kwargs['max_steps'] if 'max_steps' in self.kwargs else None
self.n_model = len(self.kwargs['model_kwargs'])
for m in self.kwargs['model_kwargs']:
if m['use_input']:
dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs']
train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)
self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)
if m['use_output']:
optimizer_cls = m['optimizer_cls']
optimizer_kwargs = m['optimizer_kwargs']
m_header = f"M_{m['model_id']}"
one_model_params = []
for name, param in multi_model.named_parameters():
name_prefix = '_'.join(name.split('_')[:2])
if m_header == name_prefix:
one_model_params.append(param)
optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
self._optimizers.append(optimizer)
def fit(self) -> None:
torch.autograd.set_detect_anomaly(True)
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs):
self._train()
nni.report_final_result(self._validate())
def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers:
opt.zero_grad()
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
y_hats = self.multi_model(*xs)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
losses = []
report_loss = {}
for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
loss = self.training_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
loss = self.training_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
losses.append(loss.to("cuda:0"))
report_loss[self.kwargs['model_kwargs'][output_idx]['model_id']] = loss.item()
summed_loss = sum(losses)
summed_loss.backward()
for opt in self._optimizers:
opt.step()
if self.max_steps and batch_idx >= self.max_steps:
return
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y
def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss
def _validate(self):
all_val_outputs = {idx: [] for idx in range(self.n_model)}
for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
y_hats = self.multi_model(*xs)
for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)
report_acc = {}
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsLayerChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argmax(self.alpha).item()
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
def forward(self, inputs):
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsInputChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
class DartsTrainer(BaseOneShotTrainer):
"""
DARTS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, returns a loss tensor.
metrics : callable
Receives logits and ground truth label, returns a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
learning_rate : float
Learning rate to optimize the model.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
``True`` if using second-order optimization, ``False`` for first-order optimization.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, grad_clip=5.,
learning_rate=2.5E-3, batch_size=64, workers=4,
device=None, log_frequency=None,
arc_learning_rate=3.0E-4, unrolled=False):
self.model = model
self.loss = loss
self.metrics = metrics
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, DartsLayerChoice, self.nas_modules)
replace_input_choice(self.model, DartsInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model_optim = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
self.grad_clip = grad_clip
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.model_optim.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) # gradient clipping
self.model_optim.step()
metrics = self.metrics(logits, trn_y)
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Epoch [%s/%s] Step [%s/%s] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.model_optim.param_groups[0]["lr"]
momentum = self.model_optim.param_groups[0]["momentum"]
weight_decay = self.model_optim.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple(m.alpha for _, m in self.nas_modules)
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.model_optim.state[w].get('momentum_buffer', 0.)
w -= lr * (momentum * m + g + weight_decay * w)  # in-place update of the virtual weights
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
_logger.warning('In computing hessian, norm is smaller than 1E-8, cause eps to be %.6f.', norm.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, [m.alpha for _, m in self.nas_modules]))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
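# A usage sketch for DartsTrainer: ``model`` is assumed to be a search space containing
# LayerChoice / InputChoice modules, and ``dataset`` a torchvision-style classification
# dataset; hyper-parameter values below are illustrative.
def _darts_example(model, dataset):
    def accuracy(logits, y):
        return {'acc': (logits.argmax(dim=-1) == y).float().mean().item()}

    optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3e-4)
    trainer = DartsTrainer(model=model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                           optimizer=optimizer, num_epochs=50, dataset=dataset,
                           batch_size=64, log_frequency=10)
    trainer.fit()
    return trainer.export()  # maps each choice key to the selected index / input order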
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
_logger = logging.getLogger(__name__)
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# the current implementation only supports a batch size of 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
A field with ``name`` and ``total`` choices. ``choose_one`` is true if exactly one choice is meant to be
selected; otherwise, any number of choices can be chosen.
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. ``tanh`` is not applied if this value is ``None``.
skip_target : float
Target probability that a skip-connect will appear.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
sampled = sampled.detach().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled
class EnasTrainer(BaseOneShotTrainer):
"""
ENAS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, returns a loss tensor.
metrics : callable
Receives logits and ground truth label, returns a dict of metrics.
reward_function : callable
Receives logits and ground truth label, returns a tensor, which will be fed to the RL controller as the reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_steps : int
Number of mini-batches for each epoch of RL controller learning.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
"""
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset,
batch_size=64, workers=4, device=None, log_frequency=None,
grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model.to(self.device)
self.nas_fields = [ReinforceField(name, len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
for name, module in self.nas_modules]
self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
self.grad_clip = grad_clip
self.reward_function = reward_function
self.ctrl_optim = optim.Adam(self.controller.parameters(), lr=ctrl_lr)
self.batch_size = batch_size
self.workers = workers
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_model(self, epoch):
self.model.train()
self.controller.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
self._resample()
logits = self.model(x)
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.optimizer.step()
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Model Epoch [%d/%d] Step [%d/%d] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
meters = AverageMeterGroup()
self.ctrl_optim.zero_grad()
for ctrl_step, (x, y) in enumerate(self.valid_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self._resample()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
metrics['reward'] = reward
metrics['loss'] = loss.item()
metrics['ent'] = self.controller.sample_entropy.item()
metrics['log_prob'] = self.controller.sample_log_prob.item()
metrics['baseline'] = self.baseline
metrics['skip'] = self.controller.sample_skip_penalty
loss /= self.ctrl_steps_aggregate
loss.backward()
meters.update(metrics)
if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
self.ctrl_optim.step()
self.ctrl_optim.zero_grad()
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
ctrl_step + 1, len(self.valid_loader), meters)
def _resample(self):
result = self.controller.resample()
for name, module in self.nas_modules:
module.sampled = result[name]
def fit(self):
for i in range(self.num_epochs):
self._train_model(i)
self._train_controller(i)
def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()
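# A usage sketch for EnasTrainer: ``model`` is a search space with LayerChoice / InputChoice
# modules and ``dataset`` a classification dataset; values are illustrative. The reward fed
# to the controller is simply the batch accuracy here.
def _enas_example(model, dataset):
    def accuracy(logits, y):
        return {'acc': (logits.argmax(dim=-1) == y).float().mean().item()}

    def reward(logits, y):
        return (logits.argmax(dim=-1) == y).float().mean().item()

    optimizer = optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1e-4)
    trainer = EnasTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                          reward_function=reward, optimizer=optimizer,
                          num_epochs=10, dataset=dataset, batch_size=64, log_frequency=10)
    trainer.fit()
    return trainer.export()  # one architecture sampled by the trained controller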
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class ArchGradientFunction(torch.autograd.Function):
@staticmethod
def forward(ctx, x, binary_gates, run_func, backward_func):
ctx.run_func = run_func
ctx.backward_func = backward_func
detached_x = x.detach()
detached_x.requires_grad = x.requires_grad
with torch.enable_grad():
output = run_func(detached_x)
ctx.save_for_backward(detached_x, output)
return output.data
@staticmethod
def backward(ctx, grad_output):
detached_x, output = ctx.saved_tensors
grad_x = torch.autograd.grad(output, detached_x, grad_output, only_inputs=True)
# compute gradients w.r.t. binary_gates
binary_grads = ctx.backward_func(detached_x.data, output.data, grad_output.data)
return grad_x[0], binary_grads, None, None
class ProxylessLayerChoice(nn.Module):
def __init__(self, ops):
super(ProxylessLayerChoice, self).__init__()
self.ops = nn.ModuleList(ops)
self.alpha = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self._binary_gates = nn.Parameter(torch.randn(len(self.ops)) * 1E-3)
self.sampled = None
def forward(self, *args):
def run_function(ops, active_id):
def forward(_x):
return ops[active_id](_x)
return forward
def backward_function(ops, active_id, binary_gates):
def backward(_x, _output, grad_output):
binary_grads = torch.zeros_like(binary_gates.data)
with torch.no_grad():
for k in range(len(ops)):
if k != active_id:
out_k = ops[k](_x.data)
else:
out_k = _output.data
grad_k = torch.sum(out_k * grad_output)
binary_grads[k] = grad_k
return binary_grads
return backward
assert len(args) == 1
x = args[0]
return ArchGradientFunction.apply(
x, self._binary_gates, run_function(self.ops, self.sampled),
backward_function(self.ops, self.sampled, self._binary_gates)
)
def resample(self):
probs = F.softmax(self.alpha, dim=-1)
sample = torch.multinomial(probs, 1)[0].item()
self.sampled = sample
with torch.no_grad():
self._binary_gates.zero_()
self._binary_gates.grad = torch.zeros_like(self._binary_gates.data)
self._binary_gates.data[sample] = 1.0
def finalize_grad(self):
binary_grads = self._binary_gates.grad
with torch.no_grad():
if self.alpha.grad is None:
self.alpha.grad = torch.zeros_like(self.alpha.data)
probs = F.softmax(self.alpha, dim=-1)
for i in range(len(self.ops)):
for j in range(len(self.ops)):
self.alpha.grad[i] += binary_grads[j] * probs[j] * (int(i == j) - probs[i])
def export(self):
return torch.argmax(self.alpha).item()
class ProxylessInputChoice(nn.Module):
def __init__(self, *args, **kwargs):
raise NotImplementedError('Input choice is not supported for ProxylessNAS.')
class ProxylessTrainer(BaseOneShotTrainer):
"""
Proxyless trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, returns a loss tensor.
metrics : callable
Receives logits and ground truth label, returns a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
warmup_epochs : int
Number of epochs to warmup model parameters.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Step count per logging.
arc_learning_rate : float
Learning rate of architecture parameters.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, warmup_epochs=0,
batch_size=64, workers=4, device=None, log_frequency=None,
arc_learning_rate=1.0E-3):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.warmup_epochs = warmup_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, ProxylessLayerChoice, self.nas_modules)
replace_input_choice(self.model, ProxylessInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.optimizer = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate,
weight_decay=0, betas=(0, 0.999), eps=1e-8)
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
if epoch >= self.warmup_epochs:
# 1) train architecture parameters
for _, module in self.nas_modules:
module.resample()
self.ctrl_optim.zero_grad()
logits, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
for _, module in self.nas_modules:
module.finalize_grad()
self.ctrl_optim.step()
# 2) train model parameters
for _, module in self.nas_modules:
module.resample()
self.optimizer.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, trn_y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
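# A usage sketch for ProxylessTrainer: ``model`` is a search space containing LayerChoice
# modules (InputChoice is not supported above) and ``dataset`` a classification dataset;
# values are illustrative.
def _proxyless_example(model, dataset):
    def accuracy(logits, y):
        return {'acc': (logits.argmax(dim=-1) == y).float().mean().item()}

    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    trainer = ProxylessTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                               optimizer=optimizer, num_epochs=10, dataset=dataset,
                               warmup_epochs=1, batch_size=64, log_frequency=10)
    trainer.fit()
    return trainer.export()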
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import random
import torch
import torch.nn as nn
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
def _get_mask(sampled, total):
multihot = [i == sampled or (isinstance(sampled, list) and i in sampled) for i in range(total)]
return torch.tensor(multihot, dtype=torch.bool) # pylint: disable=not-callable
class PathSamplingLayerChoice(nn.Module):
"""
Mixed module, in which the forward pass is decided by one or several (sampled) modules.
If multiple modules are selected, their results are summed and returned.
Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""
def __init__(self, layer_choice):
super(PathSamplingLayerChoice, self).__init__()
self.op_names = []
for name, module in layer_choice.named_children():
self.add_module(name, module)
self.op_names.append(name)
assert self.op_names, 'There has to be at least one op to choose from.'
self.sampled = None # sampled can be either a list of indices or an index
def forward(self, *args, **kwargs):
assert self.sampled is not None, 'At least one path needs to be sampled before fprop.'
if isinstance(self.sampled, list):
return sum([getattr(self, self.op_names[i])(*args, **kwargs) for i in self.sampled]) # pylint: disable=not-an-iterable
else:
return getattr(self, self.op_names[self.sampled])(*args, **kwargs) # pylint: disable=invalid-sequence-index
def __len__(self):
return len(self.op_names)
@property
def mask(self):
return _get_mask(self.sampled, len(self))
class PathSamplingInputChoice(nn.Module):
"""
Mixed input. Takes a list of tensors as input, selects some of them and returns their sum.
Attributes
----------
sampled : int or list of int
Sampled module indices.
mask : tensor
A multi-hot bool 1D-tensor representing the sampled mask.
"""
def __init__(self, input_choice):
super(PathSamplingInputChoice, self).__init__()
self.n_candidates = input_choice.n_candidates
self.n_chosen = input_choice.n_chosen
self.sampled = None
def forward(self, input_tensors):
if isinstance(self.sampled, list):
return sum([input_tensors[t] for t in self.sampled]) # pylint: disable=not-an-iterable
else:
return input_tensors[self.sampled]
def __len__(self):
return self.n_candidates
@property
def mask(self):
return _get_mask(self.sampled, len(self))
class SinglePathTrainer(BaseOneShotTrainer):
"""
Single-path trainer. Samples a path every time and backpropagates on that path.
Parameters
----------
model : nn.Module
Model with mutables.
loss : callable
Called with logits and targets. Returns a loss tensor.
metrics : callable
Returns a dict that maps metrics keys to metrics data.
optimizer : Optimizer
Optimizer that optimizes the model.
num_epochs : int
Number of epochs of training.
dataset_train : Dataset
Dataset of training.
dataset_valid : Dataset
Dataset of validation.
batch_size : int
Batch size.
workers : int
Number of workers for data preprocessing. Not used by this trainer; may be removed in the future.
device : torch.device
Device object. Either ``torch.device("cuda")`` or ``torch.device("cpu")``. When ``None``, the trainer
automatically detects and prefers GPU.
log_frequency : int
Number of mini-batches to log metrics.
"""
def __init__(self, model, loss, metrics,
optimizer, num_epochs, dataset_train, dataset_valid,
mutator=None, batch_size=64, workers=4, device=None, log_frequency=None):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.dataset_train = dataset_train
self.dataset_valid = dataset_valid
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.train_loader = torch.utils.data.DataLoader(self.dataset_train,
batch_size=batch_size,
num_workers=workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset_valid,
batch_size=batch_size,
num_workers=workers)
def _resample(self):
result = {}
for name, module in self.nas_modules:
if name not in result:
result[name] = random.randint(0, len(module) - 1)
module.sampled = result[name]
return result
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = x.to(self.device), y.to(self.device)
self.optimizer.zero_grad()
self._resample()
logits = self.model(x)
loss = self.loss(logits, y)
loss.backward()
self.optimizer.step()
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _validate_one_epoch(self, epoch):
self.model.eval()
meters = AverageMeterGroup()
with torch.no_grad():
for step, (x, y) in enumerate(self.valid_loader):
x, y = x.to(self.device), y.to(self.device)
self._resample()
logits = self.model(x)
loss = self.loss(logits, y)
metrics = self.metrics(logits, y)
metrics["loss"] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info("Epoch [%s/%s] Validation Step [%s/%s] %s", epoch + 1,
self.num_epochs, step + 1, len(self.valid_loader), meters)
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
self._validate_one_epoch(i)
def export(self):
return self._resample()
RandomTrainer = SinglePathTrainer
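# A usage sketch for the single-path (random) trainer above: ``model`` is a search space,
# ``train_set`` / ``valid_set`` are datasets; values are illustrative.
def _random_search_example(model, train_set, valid_set):
    def accuracy(logits, y):
        return {'acc': (logits.argmax(dim=-1) == y).float().mean().item()}

    optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
    trainer = SinglePathTrainer(model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                                optimizer=optimizer, num_epochs=10,
                                dataset_train=train_set, dataset_valid=valid_set,
                                batch_size=64, log_frequency=10)
    trainer.fit()
    return trainer.export()  # one randomly sampled architecture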
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
from collections import OrderedDict
import numpy as np
import torch
from nni.nas.pytorch.mutables import InputChoice, LayerChoice
_logger = logging.getLogger(__name__)
def to_device(obj, device):
"""
Move a tensor, tuple, list, or dict onto device.
"""
if torch.is_tensor(obj):
return obj.to(device)
if isinstance(obj, tuple):
return tuple(to_device(t, device) for t in obj)
if isinstance(obj, list):
return [to_device(t, device) for t in obj]
if isinstance(obj, dict):
return {k: to_device(v, device) for k, v in obj.items()}
if isinstance(obj, (int, float, str)):
return obj
raise ValueError("'%s' has unsupported type '%s'" % (obj, type(obj)))
def to_list(arr):
if torch.is_tensor(arr):
return arr.cpu().numpy().tolist()
if isinstance(arr, np.ndarray):
return arr.tolist()
if isinstance(arr, (list, tuple)):
return list(arr)
return arr
class AverageMeterGroup:
"""
Average meter group for multiple average meters.
"""
def __init__(self):
self.meters = OrderedDict()
def update(self, data):
"""
Update the meter group with a dict of metrics.
Average meters that do not exist yet will be created automatically.
"""
for k, v in data.items():
if k not in self.meters:
self.meters[k] = AverageMeter(k, ":4f")
self.meters[k].update(v)
def __getattr__(self, item):
return self.meters[item]
def __getitem__(self, item):
return self.meters[item]
def __str__(self):
return " ".join(str(v) for v in self.meters.values())
def summary(self):
"""
Return a summary string of group data.
"""
return " ".join(v.summary() for v in self.meters.values())
class AverageMeter:
"""
Computes and stores the average and current value.
Parameters
----------
name : str
Name to display.
fmt : str
Format string to print the values.
"""
def __init__(self, name, fmt=':f'):
self.name = name
self.fmt = fmt
self.reset()
def reset(self):
"""
Reset the meter.
"""
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
"""
Update with value and weight.
Parameters
----------
val : float or int
The new value to be accounted in.
n : int
The weight of the new value.
"""
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
def __str__(self):
fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
return fmtstr.format(**self.__dict__)
def summary(self):
fmtstr = '{name}: {avg' + self.fmt + '}'
return fmtstr.format(**self.__dict__)
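# A small self-contained sketch of the meter classes above: per-step metrics are pushed in
# as dicts and averaged per key.
def _meter_example():
    meters = AverageMeterGroup()
    meters.update({'loss': 0.9, 'acc': 0.50})
    meters.update({'loss': 0.7, 'acc': 0.60})
    print(meters)              # current value and running average for each metric
    print(meters['acc'].avg)   # 0.55
    print(meters.summary())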
def _replace_module_with_type(root_module, init_fn, type_name, modules):
if modules is None:
modules = []
def apply(m):
for name, child in m.named_children():
if isinstance(child, type_name):
setattr(m, name, init_fn(child))
modules.append((child.key, getattr(m, name)))
else:
apply(child)
apply(root_module)
return modules
def replace_layer_choice(root_module, init_fn, modules=None):
"""
Replace layer choice modules with modules that are initiated with init_fn.
Parameters
----------
root_module : nn.Module
Root module to traverse.
init_fn : Callable
Initializing function.
modules : list, optional
A list to which the replaced ``(key, module)`` pairs are appended; a new list is created if not provided.
Returns
-------
List[Tuple[str, nn.Module]]
A list of ``(layer choice key, replaced module)`` tuples.
"""
return _replace_module_with_type(root_module, init_fn, LayerChoice, modules)
def replace_input_choice(root_module, init_fn, modules=None):
"""
Replace input choice modules with modules that are initiated with init_fn.
Parameters
----------
root_module : nn.Module
Root module to traverse.
init_fn : Callable
Initializing function.
modules : list, optional
A list to which the replaced ``(key, module)`` pairs are appended; a new list is created if not provided.
Returns
-------
List[Tuple[str, nn.Module]]
A list of ``(input choice key, replaced module)`` tuples.
"""
return _replace_module_with_type(root_module, init_fn, InputChoice, modules)
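# A small sketch of replace_layer_choice: every LayerChoice (the legacy
# nni.nas.pytorch.mutables.LayerChoice imported at the top of this file) in a model is
# swapped for a module built by ``init_fn``. ``FirstOpOnly`` is a made-up stand-in that
# always runs the first candidate; the one-shot trainers above pass their own choice modules.
def _replace_example():
    import torch.nn as nn

    class FirstOpOnly(nn.Module):
        def __init__(self, layer_choice):
            super().__init__()
            self.op = next(iter(layer_choice.children()))

        def forward(self, *args, **kwargs):
            return self.op(*args, **kwargs)

    net = nn.Sequential(
        LayerChoice([nn.Conv2d(3, 8, 3, padding=1), nn.Conv2d(3, 8, 5, padding=2)]),
        nn.ReLU())
    nas_modules = replace_layer_choice(net, FirstOpOnly)
    print([key for key, _ in nas_modules])
    print(net(torch.randn(1, 3, 32, 32)).shape)  # torch.Size([1, 8, 32, 32])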
"""
Entrypoint for trials.
Assuming execution engine is BaseExecutionEngine.
"""
import os
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
if __name__ == '__main__':
if os.environ.get('CGO') == 'true':
CGOExecutionEngine.trial_execute_graph()
else:
BaseExecutionEngine.trial_execute_graph()
import inspect
from collections import defaultdict
from typing import Any
def import_(target: str, allow_none: bool = False) -> Any:
if target is None:
return None
path, identifier = target.rsplit('.', 1)
module = __import__(path, globals(), locals(), [identifier])
return getattr(module, identifier)
_records = {}
def get_records():
global _records
return _records
def add_record(key, value):
"""
"""
global _records
if _records is not None:
assert key not in _records, '{} already in _records'.format(key)
_records[key] = value
def _register_module(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
full_args[argname_list[i]] = arg
add_record(id(self), full_args)
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_module():
"""
Register a module.
"""
# use it as a decorator: @register_module()
def _register(cls):
m = _register_module(
original_class=cls)
return m
return _register
def _register_trainer(original_class):
orig_init = original_class.__init__
argname_list = list(inspect.signature(original_class).parameters.keys())
# Make copy of original __init__, so we can call it without recursion
full_class_name = original_class.__module__ + '.' + original_class.__name__
def __init__(self, *args, **kws):
full_args = {}
full_args.update(kws)
for i, arg in enumerate(args):
# TODO: support both pytorch and tensorflow
from .nn.pytorch import Module
if isinstance(args[i], Module):
# ignore the base model object
continue
full_args[argname_list[i]] = arg
add_record(id(self), {'modulename': full_class_name, 'args': full_args})
orig_init(self, *args, **kws) # Call the original __init__
original_class.__init__ = __init__ # Set the class' __init__ to the new one
return original_class
def register_trainer():
def _register(cls):
m = _register_trainer(
original_class=cls)
return m
return _register
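# A small sketch of the registration mechanism above: the wrapped ``__init__`` records the
# constructor arguments keyed by object id, which can later be retrieved via get_records().
def _register_example():
    @register_module()
    class MyOp:
        def __init__(self, channels, kernel_size=3):
            self.channels = channels
            self.kernel_size = kernel_size

    op = MyOp(16, kernel_size=5)
    return get_records()[id(op)]  # {'kernel_size': 5, 'channels': 16}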
_last_uid = defaultdict(int)
def uid(namespace: str = 'default') -> int:
_last_uid[namespace] += 1
return _last_uid[namespace]
@@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
python3 -m pip install --upgrade gym onnx peewee thop graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
...
@@ -2,4 +2,9 @@ __pycache__
tuner_search_space.json
tuner_result.txt
assessor_result.txt
\ No newline at end of file
_generated_model.py
data
generated
from collections import OrderedDict
from typing import (List, Optional)
import torch
import torch.nn as torch_nn
#sys.path.append(str(Path(__file__).resolve().parents[2]))
import ops
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
class AuxiliaryHead(nn.Module):
""" Auxiliary head in 2/3 place of network to let the gradient flow well """
def __init__(self, input_size, C, n_classes):
""" assuming input size 7x7 or 8x8 """
assert input_size in [7, 8]
super().__init__()
self.net = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=input_size - 5, padding=0, count_include_pad=False), # 2x2 out
nn.Conv2d(C, 128, kernel_size=1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, kernel_size=2, bias=False), # 1x1 out
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.linear = nn.Linear(768, n_classes)
def forward(self, x):
out = self.net(x)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
return logits
@register_module()
class Node(nn.Module):
def __init__(self, node_id, num_prev_nodes, channels, num_downsample_connect):
super().__init__()
self.ops = nn.ModuleList()
choice_keys = []
for i in range(num_prev_nodes):
stride = 2 if i < num_downsample_connect else 1
choice_keys.append("{}_p{}".format(node_id, i))
self.ops.append(
nn.LayerChoice([
ops.PoolBN('max', channels, 3, stride, 1, affine=False),
ops.PoolBN('avg', channels, 3, stride, 1, affine=False),
nn.Identity() if stride == 1 else ops.FactorizedReduce(channels, channels, affine=False),
ops.SepConv(channels, channels, 3, stride, 1, affine=False),
ops.SepConv(channels, channels, 5, stride, 2, affine=False),
ops.DilConv(channels, channels, 3, stride, 2, 2, affine=False),
ops.DilConv(channels, channels, 5, stride, 4, 2, affine=False)
]))
self.drop_path = ops.DropPath()
self.input_switch = nn.InputChoice(n_chosen=2)
def forward(self, prev_nodes: List[torch.Tensor]) -> torch.Tensor:
#assert self.ops.__len__() == len(prev_nodes)
#out = [op(node) for op, node in zip(self.ops, prev_nodes)]
out = []
for i, op in enumerate(self.ops):
out.append(op(prev_nodes[i]))
#out = [self.drop_path(o) if o is not None else None for o in out]
return self.input_switch(out)
@register_module()
class Cell(nn.Module):
def __init__(self, n_nodes, channels_pp, channels_p, channels, reduction_p, reduction):
super().__init__()
self.reduction = reduction
self.n_nodes = n_nodes
# If the previous cell was a reduction cell, the current input size does not match the
# output size of cell[k-2], so output[k-2] has to be downsampled by the preprocessing step.
if reduction_p:
self.preproc0 = ops.FactorizedReduce(channels_pp, channels, affine=False)
else:
self.preproc0 = ops.StdConv(channels_pp, channels, 1, 1, 0, affine=False)
self.preproc1 = ops.StdConv(channels_p, channels, 1, 1, 0, affine=False)
# generate dag
self.mutable_ops = nn.ModuleList()
for depth in range(2, self.n_nodes + 2):
self.mutable_ops.append(Node("{}_n{}".format("reduce" if reduction else "normal", depth),
depth, channels, 2 if reduction else 0))
def forward(self, s0, s1):
# s0 and s1 are the outputs of the cell before last and of the previous cell, respectively.
tensors = [self.preproc0(s0), self.preproc1(s1)]
new_tensors = []
for node in self.mutable_ops:
tmp = tensors + new_tensors
cur_tensor = node(tmp)
new_tensors.append(cur_tensor)
output = torch.cat(new_tensors, dim=1)
return output
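# A hedged arithmetic note (illustrative values, no model execution): every intermediate node
# emits `channels` channels and forward() concatenates the n_nodes node outputs along dim=1,
# so the cell output width is n_nodes * channels, matching c_cur_out = c_cur * n_nodes in CNN below.
def _example_cell_output_width():
    n_nodes, channels = 4, 16
    assert n_nodes * channels == 64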
@register_module()
class CNN(nn.Module):
def __init__(self, input_size, in_channels, channels, n_classes, n_layers, n_nodes=4,
stem_multiplier=3, auxiliary=False):
super().__init__()
self.in_channels = in_channels
self.channels = channels
self.n_classes = n_classes
self.n_layers = n_layers
self.aux_pos = 2 * n_layers // 3 if auxiliary else -1
c_cur = stem_multiplier * self.channels
self.stem = nn.Sequential(
nn.Conv2d(in_channels, c_cur, 3, 1, 1, bias=False),
nn.BatchNorm2d(c_cur)
)
# for the first cell, stem is used for both s0 and s1
# [!] channels_pp and channels_p are output channel sizes, while c_cur is an input channel size.
channels_pp, channels_p, c_cur = c_cur, c_cur, channels
self.cells = nn.ModuleList()
reduction_p, reduction = False, False
for i in range(n_layers):
reduction_p, reduction = reduction, False
# Reduce the feature map size and double the channels at the 1/3 and 2/3 layers.
if i in [n_layers // 3, 2 * n_layers // 3]:
c_cur *= 2
reduction = True
cell = Cell(n_nodes, channels_pp, channels_p, c_cur, reduction_p, reduction)
self.cells.append(cell)
c_cur_out = c_cur * n_nodes
channels_pp, channels_p = channels_p, c_cur_out
#if i == self.aux_pos:
# self.aux_head = AuxiliaryHead(input_size // 4, channels_p, n_classes)
self.gap = nn.AdaptiveAvgPool2d(1)
self.linear = nn.Linear(channels_p, n_classes)
def forward(self, x):
s0 = s1 = self.stem(x)
#aux_logits = None
for i, cell in enumerate(self.cells):
s0, s1 = s1, cell(s0, s1)
#if i == self.aux_pos and self.training:
# aux_logits = self.aux_head(s1)
out = self.gap(s1)
out = out.view(out.size(0), -1) # flatten
logits = self.linear(out)
#if aux_logits is not None:
# return logits, aux_logits
return logits
def drop_path_prob(self, p):
for module in self.modules():
if isinstance(module, ops.DropPath):
module.p = p
if __name__ == '__main__':
base_model = CNN(32, 3, 16, 10, 8)
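# Hedged sanity check (illustrative, not in the original script): with n_layers=8 the two
# reduction cells land at layer indices 8 // 3 == 2 and 2 * 8 // 3 == 5, the points where
# c_cur doubles in CNN.__init__.
assert [i for i in range(8) if i in (8 // 3, 2 * 8 // 3)] == [2, 5]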
import torch
import nni.retiarii.nn.pytorch as nn
from nni.retiarii import register_module
@register_module()
class DropPath(nn.Module):
def __init__(self, p=0.):
"""
Drop path with probability.
Parameters
----------
p : float
Probability of a path being zeroed.
"""
super(DropPath, self).__init__()
self.p = p
def forward(self, x):
if self.training and self.p > 0.:
keep_prob = 1. - self.p
# per data point mask
mask = torch.zeros((x.size(0), 1, 1, 1), device=x.device).bernoulli_(keep_prob)
return x / keep_prob * mask
return x
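# A hedged example (not part of the original file) of the two DropPath regimes: in eval mode it
# is the identity, and in train mode each sample is either zeroed entirely or rescaled by
# 1 / keep_prob so the expected output matches the input.
def _example_drop_path():
    x = torch.ones(4, 8, 2, 2)
    drop = DropPath(p=0.2)
    drop.eval()
    assert torch.equal(drop(x), x)   # no-op at inference time
    drop.train()
    y = drop(x)                      # per sample: all-zero with prob 0.2, otherwise x / 0.8
    assert y.shape == x.shape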
@register_module()
class PoolBN(nn.Module):
"""
AvgPool or MaxPool with BN. `pool_type` must be `max` or `avg`.
"""
def __init__(self, pool_type, C, kernel_size, stride, padding, affine=True):
super(PoolBN, self).__init__()
if pool_type.lower() == 'max':
self.pool = nn.MaxPool2d(kernel_size, stride, padding)
elif pool_type.lower() == 'avg':
self.pool = nn.AvgPool2d(kernel_size, stride, padding, count_include_pad=False)
else:
raise ValueError(f"unsupported pool_type: {pool_type}")
self.bn = nn.BatchNorm2d(C, affine=affine)
def forward(self, x):
out = self.pool(x)
out = self.bn(out)
return out
@register_module()
class StdConv(nn.Module):
"""
Standard conv: ReLU - Conv - BN
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(StdConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_out, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class FacConv(nn.Module):
"""
Factorized conv: ReLU - Conv(Kx1) - Conv(1xK) - BN
"""
def __init__(self, C_in, C_out, kernel_length, stride, padding, affine=True):
super(FacConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, (kernel_length, 1), stride, padding, bias=False),
nn.Conv2d(C_in, C_out, (1, kernel_length), stride, padding, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class DilConv(nn.Module):
"""
(Dilated) depthwise separable conv.
ReLU - (Dilated) depthwise separable - Pointwise - BN.
If dilation == 2, 3x3 conv => 5x5 receptive field, 5x5 conv => 9x9 receptive field.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super(DilConv, self).__init__()
self.net = nn.Sequential(
nn.ReLU(),
nn.Conv2d(C_in, C_in, kernel_size, stride, padding, dilation=dilation, groups=C_in,
bias=False),
nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine)
)
def forward(self, x):
return self.net(x)
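# A hedged shape check (not part of the original file) for the receptive-field note in the
# docstring: a dilated conv covers dilation * (kernel_size - 1) + 1 pixels, so a 3x3 kernel with
# dilation 2 acts like a 5x5 one, and padding=2 then preserves the spatial size at stride 1.
def _example_dil_conv():
    assert 2 * (3 - 1) + 1 == 5 and 2 * (5 - 1) + 1 == 9
    conv = DilConv(16, 32, kernel_size=3, stride=1, padding=2, dilation=2, affine=False)
    x = torch.randn(1, 16, 8, 8)
    assert conv(x).shape == (1, 32, 8, 8)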
@register_module()
class SepConv(nn.Module):
"""
Depthwise separable conv.
DilConv(dilation=1) * 2.
"""
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super(SepConv, self).__init__()
self.net = nn.Sequential(
DilConv(C_in, C_in, kernel_size, stride, padding, dilation=1, affine=affine),
DilConv(C_in, C_out, kernel_size, 1, padding, dilation=1, affine=affine)
)
def forward(self, x):
return self.net(x)
@register_module()
class FactorizedReduce(nn.Module):
"""
Reduce feature map size by factorized pointwise (stride=2).
"""
def __init__(self, C_in, C_out, affine=True):
super(FactorizedReduce, self).__init__()
self.relu = nn.ReLU()
self.conv1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
def forward(self, x):
x = self.relu(x)
out = torch.cat([self.conv1(x), self.conv2(x[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
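# A hedged shape check (illustrative only): the two stride-2 1x1 convs each halve H and W and
# produce C_out // 2 channels; concatenating them along dim=1 restores C_out at half resolution.
def _example_factorized_reduce():
    reduce_op = FactorizedReduce(C_in=16, C_out=32, affine=False)
    x = torch.randn(1, 16, 8, 8)
    assert reduce_op(x).shape == (1, 32, 4, 4)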