Unverified commit 468917ca authored by QuanluZhang, committed by GitHub

Merge pull request #3155 from microsoft/dev-retiarii

[Do NOT Squash] Merge retiarii dev branch to master
parents f8424a9f d5a551c8
from ..operation import TensorFlowOperation
class Conv2D(TensorFlowOperation):
def __init__(self, type_name, parameters, _internal):
if 'padding' not in parameters:
parameters['padding'] = 'same'
super().__init__(type_name, parameters, _internal)
from ..operation import PyTorchOperation
class relu(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = nn.functional.relu({inputs[0]})'
class Flatten(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = {inputs[0]}.view({inputs[0]}.size(0), -1)'
class ToDevice(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f"{output} = {inputs[0]}.to('{self.parameters['device']}')"
class Dense(PyTorchOperation):
def to_init_code(self, field):
return f"self.{field} = nn.Linear({self.parameters['in_features']}, {self.parameters['out_features']})"
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = self.{field}({inputs[0]})'
class Softmax(PyTorchOperation):
def to_init_code(self, field):
return ''
def to_forward_code(self, field, output, *inputs) -> str:
assert len(inputs) == 1
return f'{output} = F.softmax({inputs[0]}, -1)'
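# Illustrative sketch of how these operation classes translate a graph node into
# generated module code. The PyTorchOperation constructor is not shown in this
# diff; the call below assumes it mirrors the TensorFlowOperation signature
# above (type_name, parameters, _internal).
dense = Dense('Dense', {'in_features': 128, 'out_features': 10}, True)
print(dense.to_init_code('fc1'))                 # -> self.fc1 = nn.Linear(128, 10)
print(dense.to_forward_code('fc1', 'out', 'x'))  # -> out = self.fc1(x)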
from .tpe_strategy import TPEStrategy
import abc
from typing import List
from ..graph import Model
from ..mutator import Mutator
class BaseStrategy(abc.ABC):
@abc.abstractmethod
def run(self, base_model: Model, applied_mutators: List[Mutator]) -> None:
pass
import logging
from nni.algorithms.hpo.hyperopt_tuner import HyperoptTuner
from .. import Sampler, submit_models, wait_models
from .strategy import BaseStrategy
_logger = logging.getLogger(__name__)
class TPESampler(Sampler):
def __init__(self, optimize_mode='minimize'):
self.tpe_tuner = HyperoptTuner('tpe', optimize_mode)
self.cur_sample = None
self.index = None
self.total_parameters = {}
def update_sample_space(self, sample_space):
search_space = {}
for i, each in enumerate(sample_space):
search_space[str(i)] = {'_type': 'choice', '_value': each}
self.tpe_tuner.update_search_space(search_space)
def generate_samples(self, model_id):
self.cur_sample = self.tpe_tuner.generate_parameters(model_id)
self.total_parameters[model_id] = self.cur_sample
self.index = 0
def receive_result(self, model_id, result):
self.tpe_tuner.receive_trial_result(model_id, self.total_parameters[model_id], result)
def choice(self, candidates, mutator, model, index):
chosen = self.cur_sample[str(self.index)]
self.index += 1
return chosen
class TPEStrategy(BaseStrategy):
def __init__(self):
self.tpe_sampler = TPESampler()
self.model_id = 0
def run(self, base_model, applied_mutators):
sample_space = []
new_model = base_model
for mutator in applied_mutators:
recorded_candidates, new_model = mutator.dry_run(new_model)
sample_space.extend(recorded_candidates)
self.tpe_sampler.update_sample_space(sample_space)
try:
_logger.info('strategy start...')
while True:
model = base_model
_logger.info('apply mutators...')
_logger.info('mutators: %s', str(applied_mutators))
self.tpe_sampler.generate_samples(self.model_id)
for mutator in applied_mutators:
_logger.info('mutate model...')
mutator.bind_sampler(self.tpe_sampler)
model = mutator.apply(model)
# run models
submit_models(model)
wait_models(model)
self.tpe_sampler.receive_result(self.model_id, model.metric)
self.model_id += 1
_logger.info('Strategy says: %s', model.metric)
except Exception:
_logger.exception('strategy failed')
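# Hedged usage sketch of TPEStrategy. `base_model` (a graph Model) and
# `applied_mutators` (a list of Mutator objects) are assumed to be constructed
# elsewhere; the names here are illustrative only.
strategy = TPEStrategy()
strategy.run(base_model, applied_mutators)
# run() loops until the experiment stops it: sample choices with TPE, apply the
# mutators, submit the mutated model, wait for its metric, and feed the metric
# back into the TPE sampler.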
from .interface import BaseTrainer
from .pytorch import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
import abc
from typing import Any
class BaseTrainer(abc.ABC):
"""
In this version, we plan to write our own trainers instead of using PyTorch Lightning, to
ease the burden of integrating our optimization with PyTorch Lightning, a large part of which is
opaque to us.
We will try to align with PyTorch Lightning naming conventions so that we can easily migrate to
PyTorch Lightning in the future.
Currently, our trainer = LightningModule + LightningTrainer. We might want to separate these two things
in the future.
Trainer has a ``fit`` function with no return value. Intermediate and final results should be
sent directly via ``nni.report_intermediate_result()`` and ``nni.report_final_result()``.
"""
@abc.abstractmethod
def fit(self) -> None:
pass
class BaseOneShotTrainer(BaseTrainer):
"""
Build many (possibly all) architectures into a full graph, search (with training) and export the best.
It has an extra ``export`` function that exports an object representing the final searched architecture.
"""
@abc.abstractmethod
def export(self) -> Any:
pass
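# A minimal hedged sketch of a concrete trainer that follows the conventions
# described above: ``fit`` returns nothing and results are reported through
# ``nni.report_intermediate_result`` / ``nni.report_final_result``.
# ``train_one_epoch`` and ``evaluate`` are hypothetical helpers, not part of this API.
import nni

class SketchTrainer(BaseTrainer):
    def __init__(self, model, train_loader, val_loader, max_epochs=10):
        self.model = model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.max_epochs = max_epochs

    def fit(self) -> None:
        for _ in range(self.max_epochs):
            train_one_epoch(self.model, self.train_loader)                          # hypothetical helper
            nni.report_intermediate_result(evaluate(self.model, self.val_loader))   # hypothetical helper
        nni.report_final_result(evaluate(self.model, self.val_loader))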
from .base import PyTorchImageClassificationTrainer, PyTorchMultiModelTrainer
from .darts import DartsTrainer
from .enas import EnasTrainer
from .proxyless import ProxylessTrainer
from .random import RandomTrainer, SinglePathTrainer
from typing import Any, List, Dict, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import nni
from ..interface import BaseTrainer
from ...utils import register_trainer
def get_default_transform(dataset: str) -> Any:
"""
Get the default image transformation for a specific dataset.
This is needed because transform objects cannot be directly passed as arguments.
Parameters
----------
dataset : str
Dataset class name.
Returns
-------
transform object, or ``None`` if the dataset is not supported
"""
if dataset == 'MNIST':
return transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
if dataset == 'CIFAR10':
return transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize((0.4914, 0.4822, 0.4465),
(0.2023, 0.1994, 0.2010)),
])
# unsupported dataset, return None
return None
@register_trainer()
class PyTorchImageClassificationTrainer(BaseTrainer):
"""
Image classification trainer for PyTorch.
A model, along with its corresponding dataset and optimizer config, is used to initialize the trainer.
The trainer will run for a fixed number of epochs (by default 10) and report the final result.
TODO
Support schedulers, validation every n epochs, and separate train/valid datasets.
Limitation induced by NNI: kwargs must be serializable so that they can be packed into the JSON parameters.
"""
def __init__(self, model,
dataset_cls='MNIST', dataset_kwargs=None, dataloader_kwargs=None,
optimizer_cls='SGD', optimizer_kwargs=None, trainer_kwargs=None):
"""Initialization of image classification trainer.
Parameters
----------
model : nn.Module
Model to train.
dataset_cls : str, optional
Dataset class name that is available in ``torchvision.datasets``, by default 'MNIST'
dataset_kwargs : dict, optional
Keyword arguments passed to initialization of dataset class, by default None
dataloader_kwargs : dict, optional
Keyword arguments passed to ``torch.utils.data.DataLoader``, by default None
optimizer_cls : str, optional
Optimizer class name that is available in ``torch.optim``, by default 'SGD'
optimizer_kwargs : dict, optional
Keyword arguments passed to initialization of optimizer class, by default None
trainer_kwargs: dict, optional
Keyword arguments passed to the trainer. Will be passed to a Trainer class in the future. Currently,
only the key ``max_epochs`` is used.
"""
super(PyTorchImageClassificationTrainer, self).__init__()
self._use_cuda = torch.cuda.is_available()
self.model = model
if self._use_cuda:
self.model.cuda()
self._loss_fn = nn.CrossEntropyLoss()
self._train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
self._optimizer = getattr(torch.optim, optimizer_cls)(model.parameters(), **(optimizer_kwargs or {}))
self._trainer_kwargs = trainer_kwargs or {'max_epochs': 10}
self._train_dataloader = DataLoader(self._train_dataset, **(dataloader_kwargs or {}))
self._val_dataloader = DataLoader(self._val_dataset, **(dataloader_kwargs or {}))
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
def training_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.training_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.training_step_after_model(x, y, y_hat)
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
x, y = batch
if self._use_cuda:
x, y = x.cuda(torch.device('cuda:0')), y.cuda(torch.device('cuda:0'))
return x, y
def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss
def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> Dict[str, Any]:
x, y = self.validation_step_before_model(batch, batch_idx)
y_hat = self.model(x)
return self.validation_step_after_model(x, y, y_hat)
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int):
x, y = batch
if self._use_cuda:
x, y = x.cuda(), y.cuda()
return x, y
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}
def validation_epoch_end(self, outputs: List[Dict[str, Any]]) -> Dict[str, Any]:
# We might need dict metrics in future?
avg_acc = np.mean([x['val_acc'] for x in outputs]).item()
nni.report_intermediate_result(avg_acc)
return {'val_acc': avg_acc}
def _validate(self):
validation_outputs = []
for i, batch in enumerate(self._val_dataloader):
validation_outputs.append(self.validation_step(batch, i))
return self.validation_epoch_end(validation_outputs)
def _train(self):
for i, batch in enumerate(self._train_dataloader):
self._optimizer.zero_grad()
loss = self.training_step(batch, i)
loss.backward()
self._optimizer.step()  # update weights; without this the optimizer never steps
def fit(self) -> None:
for _ in range(self._trainer_kwargs['max_epochs']):
self._train()
# assuming val_acc here
nni.report_final_result(self._validate()['val_acc'])
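# Hedged usage sketch of PyTorchImageClassificationTrainer. ``net`` stands for
# any nn.Module and is an illustrative assumption; every kwarg below is a plain
# JSON-serializable value, as required by the class docstring.
trainer = PyTorchImageClassificationTrainer(
    net,
    dataset_cls='MNIST',
    dataset_kwargs={'root': 'data/mnist', 'download': True},
    dataloader_kwargs={'batch_size': 32},
    optimizer_cls='SGD',
    optimizer_kwargs={'lr': 1e-3},
    trainer_kwargs={'max_epochs': 2},
)
trainer.fit()  # reports per-epoch validation accuracy and a final accuracy to NNI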
class PyTorchMultiModelTrainer(BaseTrainer):
def __init__(self, multi_model, kwargs=None):
self.multi_model = multi_model
self.kwargs = kwargs or {}  # kwargs is used as a dict below; avoid a mutable default argument
self._train_dataloaders = []
self._train_datasets = []
self._val_dataloaders = []
self._val_datasets = []
self._optimizers = []
self._trainers = []
self._loss_fn = nn.CrossEntropyLoss()
self.max_steps = self.kwargs['max_steps'] if 'max_steps' in self.kwargs else None
self.n_model = len(self.kwargs['model_kwargs'])
for m in self.kwargs['model_kwargs']:
if m['use_input']:
dataset_cls = m['dataset_cls']
dataset_kwargs = m['dataset_kwargs']
dataloader_kwargs = m['dataloader_kwargs']
train_dataset = getattr(datasets, dataset_cls)(train=True, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
val_dataset = getattr(datasets, dataset_cls)(train=False, transform=get_default_transform(dataset_cls),
**(dataset_kwargs or {}))
train_dataloader = DataLoader(train_dataset, **(dataloader_kwargs or {}))
val_dataloader = DataLoader(val_dataset, **(dataloader_kwargs or {}))
self._train_datasets.append(train_dataset)
self._train_dataloaders.append(train_dataloader)
self._val_datasets.append(val_dataset)
self._val_dataloaders.append(val_dataloader)
if m['use_output']:
optimizer_cls = m['optimizer_cls']
optimizer_kwargs = m['optimizer_kwargs']
m_header = f"M_{m['model_id']}"
one_model_params = []
for name, param in multi_model.named_parameters():
name_prefix = '_'.join(name.split('_')[:2])
if m_header == name_prefix:
one_model_params.append(param)
optimizer = getattr(torch.optim, optimizer_cls)(one_model_params, **(optimizer_kwargs or {}))
self._optimizers.append(optimizer)
def fit(self) -> None:
torch.autograd.set_detect_anomaly(True)
max_epochs = max([x['trainer_kwargs']['max_epochs'] for x in self.kwargs['model_kwargs']])
for _ in range(max_epochs):
self._train()
nni.report_final_result(self._validate())
def _train(self):
for batch_idx, multi_model_batch in enumerate(zip(*self._train_dataloaders)):
for opt in self._optimizers:
opt.zero_grad()
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.training_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
y_hats = self.multi_model(*xs)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
losses = []
report_loss = {}
for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
loss = self.training_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
loss = self.training_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
losses.append(loss.to("cuda:0"))
report_loss[self.kwargs['model_kwargs'][output_idx]['model_id']] = loss.item()
summed_loss = sum(losses)
summed_loss.backward()
for opt in self._optimizers:
opt.step()
if self.max_steps and batch_idx >= self.max_steps:
return
def training_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y
def training_step_after_model(self, x, y, y_hat):
loss = self._loss_fn(y_hat, y)
return loss
def _validate(self):
all_val_outputs = {idx: [] for idx in range(self.n_model)}
for batch_idx, multi_model_batch in enumerate(zip(*self._val_dataloaders)):
xs = []
ys = []
for idx, batch in enumerate(multi_model_batch):
x, y = self.validation_step_before_model(batch, batch_idx, f'cuda:{idx}')
xs.append(x)
ys.append(y)
if len(ys) != len(xs):
raise ValueError('len(ys) should be equal to len(xs)')
y_hats = self.multi_model(*xs)
for output_idx, yhat in enumerate(y_hats):
if len(ys) == len(y_hats):
acc = self.validation_step_after_model(xs[output_idx], ys[output_idx], yhat)
elif len(ys) == 1:
acc = self.validation_step_after_model(xs[0], ys[0].to(yhat.get_device()), yhat)
else:
raise ValueError('len(ys) should be either 1 or len(y_hats)')
all_val_outputs[output_idx].append(acc)
report_acc = {}
for idx in all_val_outputs:
avg_acc = np.mean([x['val_acc'] for x in all_val_outputs[idx]]).item()
report_acc[self.kwargs['model_kwargs'][idx]['model_id']] = avg_acc
nni.report_intermediate_result(report_acc)
return report_acc
def validation_step_before_model(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, device=None):
x, y = batch
if device:
x, y = x.cuda(torch.device(device)), y.cuda(torch.device(device))
return x, y
def validation_step_after_model(self, x, y, y_hat):
acc = self._accuracy(y_hat, y)
return {'val_acc': acc}
def _accuracy(self, input, target): # pylint: disable=redefined-builtin
_, predict = torch.max(input.data, 1)
correct = predict.eq(target.data).cpu().sum().item()
return correct / input.size(0)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import logging
import torch
import torch.nn as nn
from ..interface import BaseOneShotTrainer
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice
_logger = logging.getLogger(__name__)
class DartsLayerChoice(nn.Module):
def __init__(self, layer_choice):
super(DartsLayerChoice, self).__init__()
self.op_choices = nn.ModuleDict(layer_choice.named_children())
self.alpha = nn.Parameter(torch.randn(len(self.op_choices)) * 1e-3)
def forward(self, *args, **kwargs):
op_results = torch.stack([op(*args, **kwargs) for op in self.op_choices.values()])
alpha_shape = [-1] + [1] * (len(op_results.size()) - 1)
return torch.sum(op_results * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsLayerChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argmax(self.alpha).item()
class DartsInputChoice(nn.Module):
def __init__(self, input_choice):
super(DartsInputChoice, self).__init__()
self.alpha = nn.Parameter(torch.randn(input_choice.n_candidates) * 1e-3)
self.n_chosen = input_choice.n_chosen or 1
def forward(self, inputs):
inputs = torch.stack(inputs)
alpha_shape = [-1] + [1] * (len(inputs.size()) - 1)
return torch.sum(inputs * self.alpha.view(*alpha_shape), 0)
def parameters(self):
for _, p in self.named_parameters():
yield p
def named_parameters(self):
for name, p in super(DartsInputChoice, self).named_parameters():
if name == 'alpha':
continue
yield name, p
def export(self):
return torch.argsort(-self.alpha).cpu().numpy().tolist()[:self.n_chosen]
class DartsTrainer(BaseOneShotTrainer):
"""
DARTS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
learning_rate : float
Learning rate to optimize the model.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Number of steps between logging.
arc_learning_rate : float
Learning rate of architecture parameters.
unrolled : bool
``True`` if using second-order optimization, otherwise first-order optimization is used.
"""
def __init__(self, model, loss, metrics, optimizer,
num_epochs, dataset, grad_clip=5.,
learning_rate=2.5E-3, batch_size=64, workers=4,
device=None, log_frequency=None,
arc_learning_rate=3.0E-4, unrolled=False):
self.model = model
self.loss = loss
self.metrics = metrics
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.model.to(self.device)
self.nas_modules = []
replace_layer_choice(self.model, DartsLayerChoice, self.nas_modules)
replace_input_choice(self.model, DartsInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model_optim = optimizer
self.ctrl_optim = torch.optim.Adam([m.alpha for _, m in self.nas_modules], arc_learning_rate, betas=(0.5, 0.999),
weight_decay=1.0E-3)
self.unrolled = unrolled
self.grad_clip = grad_clip
self._init_dataloader()
def _init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_one_epoch(self, epoch):
self.model.train()
meters = AverageMeterGroup()
for step, ((trn_X, trn_y), (val_X, val_y)) in enumerate(zip(self.train_loader, self.valid_loader)):
trn_X, trn_y = trn_X.to(self.device), trn_y.to(self.device)
val_X, val_y = val_X.to(self.device), val_y.to(self.device)
# phase 1. architecture step
self.ctrl_optim.zero_grad()
if self.unrolled:
self._unrolled_backward(trn_X, trn_y, val_X, val_y)
else:
self._backward(val_X, val_y)
self.ctrl_optim.step()
# phase 2: child network step
self.model_optim.zero_grad()
logits, loss = self._logits_and_loss(trn_X, trn_y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip) # gradient clipping
self.model_optim.step()
metrics = self.metrics(logits, trn_y)
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Epoch [%s/%s] Step [%s/%s] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _logits_and_loss(self, X, y):
logits = self.model(X)
loss = self.loss(logits, y)
return logits, loss
def _backward(self, val_X, val_y):
"""
Simple backward with gradient descent
"""
_, loss = self._logits_and_loss(val_X, val_y)
loss.backward()
def _unrolled_backward(self, trn_X, trn_y, val_X, val_y):
"""
Compute unrolled loss and backward its gradients
"""
backup_params = copy.deepcopy(tuple(self.model.parameters()))
# do virtual step on training data
lr = self.model_optim.param_groups[0]["lr"]
momentum = self.model_optim.param_groups[0]["momentum"]
weight_decay = self.model_optim.param_groups[0]["weight_decay"]
self._compute_virtual_model(trn_X, trn_y, lr, momentum, weight_decay)
# calculate unrolled loss on validation data
# keep gradients for model here for compute hessian
_, loss = self._logits_and_loss(val_X, val_y)
w_model, w_ctrl = tuple(self.model.parameters()), tuple([c.alpha for c in self.nas_modules])
w_grads = torch.autograd.grad(loss, w_model + w_ctrl)
d_model, d_ctrl = w_grads[:len(w_model)], w_grads[len(w_model):]
# compute hessian and final gradients
hessian = self._compute_hessian(backup_params, d_model, trn_X, trn_y)
with torch.no_grad():
for param, d, h in zip(w_ctrl, d_ctrl, hessian):
# gradient = dalpha - lr * hessian
param.grad = d - lr * h
# restore weights
self._restore_weights(backup_params)
def _compute_virtual_model(self, X, y, lr, momentum, weight_decay):
"""
Compute unrolled weights w`
"""
# don't need zero_grad, using autograd to calculate gradients
_, loss = self._logits_and_loss(X, y)
gradients = torch.autograd.grad(loss, self.model.parameters())
with torch.no_grad():
for w, g in zip(self.model.parameters(), gradients):
m = self.model_optim.state[w].get('momentum_buffer', 0.)
w.sub_(lr * (momentum * m + g + weight_decay * w))  # update in place so the virtual step actually changes the weights
def _restore_weights(self, backup_params):
with torch.no_grad():
for param, backup in zip(self.model.parameters(), backup_params):
param.copy_(backup)
def _compute_hessian(self, backup_params, dw, trn_X, trn_y):
"""
dw = dw` { L_val(w`, alpha) }
w+ = w + eps * dw
w- = w - eps * dw
hessian = (dalpha { L_trn(w+, alpha) } - dalpha { L_trn(w-, alpha) }) / (2*eps)
eps = 0.01 / ||dw||
"""
self._restore_weights(backup_params)
norm = torch.cat([w.view(-1) for w in dw]).norm()
eps = 0.01 / norm
if norm < 1E-8:
_logger.warning('In computing hessian, norm is smaller than 1E-8, causing eps to be %.6f.', eps.item())
dalphas = []
for e in [eps, -2. * eps]:
# w+ = w + eps*dw`, w- = w - eps*dw`
with torch.no_grad():
for p, d in zip(self.model.parameters(), dw):
p += e * d
_, loss = self._logits_and_loss(trn_X, trn_y)
dalphas.append(torch.autograd.grad(loss, [c.alpha for c in self.nas_modules]))
dalpha_pos, dalpha_neg = dalphas # dalpha { L_trn(w+) }, # dalpha { L_trn(w-) }
hessian = [(p - n) / (2. * eps) for p, n in zip(dalpha_pos, dalpha_neg)]
return hessian
def fit(self):
for i in range(self.num_epochs):
self._train_one_epoch(i)
@torch.no_grad()
def export(self):
result = dict()
for name, module in self.nas_modules:
if name not in result:
result[name] = module.export()
return result
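# Hedged usage sketch of DartsTrainer. ``model`` is assumed to be a PyTorch
# module containing LayerChoice/InputChoice mutables and ``accuracy`` an
# illustrative metrics callable returning a dict; both are assumptions.
import torch
import torch.nn as nn
from torchvision import datasets, transforms

dataset = datasets.CIFAR10('data/cifar10', download=True, transform=transforms.ToTensor())
optimizer = torch.optim.SGD(model.parameters(), lr=0.025, momentum=0.9, weight_decay=3.0E-4)
trainer = DartsTrainer(model=model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                       optimizer=optimizer, num_epochs=50, dataset=dataset,
                       batch_size=64, log_frequency=10)
trainer.fit()
print(trainer.export())  # maps each choice name to the selected op / input indices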
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from ..interface import BaseOneShotTrainer
from .random import PathSamplingLayerChoice, PathSamplingInputChoice
from .utils import AverageMeterGroup, replace_layer_choice, replace_input_choice, to_device
_logger = logging.getLogger(__name__)
class StackedLSTMCell(nn.Module):
def __init__(self, layers, size, bias):
super().__init__()
self.lstm_num_layers = layers
self.lstm_modules = nn.ModuleList([nn.LSTMCell(size, size, bias=bias)
for _ in range(self.lstm_num_layers)])
def forward(self, inputs, hidden):
prev_h, prev_c = hidden
next_h, next_c = [], []
for i, m in enumerate(self.lstm_modules):
curr_h, curr_c = m(inputs, (prev_h[i], prev_c[i]))
next_c.append(curr_c)
next_h.append(curr_h)
# the current implementation only supports a batch size of 1,
# but the algorithm does not necessarily have this limitation
inputs = curr_h[-1].view(1, -1)
return next_h, next_c
class ReinforceField:
"""
A field with ``name`` and ``total`` choices. ``choose_one`` is true if exactly one choice is meant to be
selected; otherwise, any number of choices can be chosen.
"""
def __init__(self, name, total, choose_one):
self.name = name
self.total = total
self.choose_one = choose_one
def __repr__(self):
return f'ReinforceField(name={self.name}, total={self.total}, choose_one={self.choose_one})'
class ReinforceController(nn.Module):
"""
A controller that mutates the graph with RL.
Parameters
----------
fields : list of ReinforceField
List of fields to choose.
lstm_size : int
Controller LSTM hidden units.
lstm_num_layers : int
Number of layers for stacked LSTM.
tanh_constant : float
Logits will be equal to ``tanh_constant * tanh(logits)``. Don't use ``tanh`` if this value is ``None``.
skip_target : float
Target probability that skipconnect will appear.
temperature : float
Temperature constant that divides the logits.
entropy_reduction : str
Can be one of ``sum`` and ``mean``. How the entropy of multi-input-choice is reduced.
"""
def __init__(self, fields, lstm_size=64, lstm_num_layers=1, tanh_constant=1.5,
skip_target=0.4, temperature=None, entropy_reduction='sum'):
super(ReinforceController, self).__init__()
self.fields = fields
self.lstm_size = lstm_size
self.lstm_num_layers = lstm_num_layers
self.tanh_constant = tanh_constant
self.temperature = temperature
self.skip_target = skip_target
self.lstm = StackedLSTMCell(self.lstm_num_layers, self.lstm_size, False)
self.attn_anchor = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.attn_query = nn.Linear(self.lstm_size, self.lstm_size, bias=False)
self.v_attn = nn.Linear(self.lstm_size, 1, bias=False)
self.g_emb = nn.Parameter(torch.randn(1, self.lstm_size) * 0.1)
self.skip_targets = nn.Parameter(torch.tensor([1.0 - self.skip_target, self.skip_target]), # pylint: disable=not-callable
requires_grad=False)
assert entropy_reduction in ['sum', 'mean'], 'Entropy reduction must be one of sum and mean.'
self.entropy_reduction = torch.sum if entropy_reduction == 'sum' else torch.mean
self.cross_entropy_loss = nn.CrossEntropyLoss(reduction='none')
self.soft = nn.ModuleDict({
field.name: nn.Linear(self.lstm_size, field.total, bias=False) for field in fields
})
self.embedding = nn.ModuleDict({
field.name: nn.Embedding(field.total, self.lstm_size) for field in fields
})
def resample(self):
self._initialize()
result = dict()
for field in self.fields:
result[field.name] = self._sample_single(field)
return result
def _initialize(self):
self._inputs = self.g_emb.data
self._c = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self._h = [torch.zeros((1, self.lstm_size),
dtype=self._inputs.dtype,
device=self._inputs.device) for _ in range(self.lstm_num_layers)]
self.sample_log_prob = 0
self.sample_entropy = 0
self.sample_skip_penalty = 0
def _lstm_next_step(self):
self._h, self._c = self.lstm(self._inputs, (self._h, self._c))
def _sample_single(self, field):
self._lstm_next_step()
logit = self.soft[field.name](self._h[-1])
if self.temperature is not None:
logit /= self.temperature
if self.tanh_constant is not None:
logit = self.tanh_constant * torch.tanh(logit)
if field.choose_one:
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
log_prob = self.cross_entropy_loss(logit, sampled)
self._inputs = self.embedding[field.name](sampled)
else:
logit = logit.view(-1, 1)
logit = torch.cat([-logit, logit], 1) # pylint: disable=invalid-unary-operand-type
sampled = torch.multinomial(F.softmax(logit, dim=-1), 1).view(-1)
skip_prob = torch.sigmoid(logit)
kl = torch.sum(skip_prob * torch.log(skip_prob / self.skip_targets))
self.sample_skip_penalty += kl
log_prob = self.cross_entropy_loss(logit, sampled)
sampled = sampled.nonzero().view(-1)
if sampled.sum().item():
self._inputs = (torch.sum(self.embedding[field.name](sampled.view(-1)), 0) / (1. + torch.sum(sampled))).unsqueeze(0)
else:
self._inputs = torch.zeros(1, self.lstm_size, device=self.embedding[field.name].weight.device)
sampled = sampled.detach().numpy().tolist()
self.sample_log_prob += self.entropy_reduction(log_prob)
entropy = (log_prob * torch.exp(-log_prob)).detach() # pylint: disable=invalid-unary-operand-type
self.sample_entropy += self.entropy_reduction(entropy)
if len(sampled) == 1:
sampled = sampled[0]
return sampled
class EnasTrainer(BaseOneShotTrainer):
"""
ENAS trainer.
Parameters
----------
model : nn.Module
PyTorch model to be trained.
loss : callable
Receives logits and ground truth label, return a loss tensor.
metrics : callable
Receives logits and ground truth label, return a dict of metrics.
reward_function : callable
Receives logits and ground truth label, return a tensor, which will be fed to the RL controller as the reward.
optimizer : Optimizer
The optimizer used for optimizing the model.
num_epochs : int
Number of epochs planned for training.
dataset : Dataset
Dataset for training. Will be split for training weights and architecture weights.
batch_size : int
Batch size.
workers : int
Workers for data loading.
device : torch.device
``torch.device("cpu")`` or ``torch.device("cuda")``.
log_frequency : int
Number of steps between logging.
grad_clip : float
Gradient clipping. Set to 0 to disable. Default: 5.
entropy_weight : float
Weight of sample entropy loss.
skip_weight : float
Weight of skip penalty loss.
baseline_decay : float
Decay factor of baseline. New baseline will be equal to ``baseline_decay * baseline_old + reward * (1 - baseline_decay)``.
ctrl_lr : float
Learning rate for RL controller.
ctrl_steps_aggregate : int
Number of steps that will be aggregated into one mini-batch for RL controller.
ctrl_steps : int
Number of mini-batches for each epoch of RL controller learning.
ctrl_kwargs : dict
Optional kwargs that will be passed to :class:`ReinforceController`.
"""
def __init__(self, model, loss, metrics, reward_function,
optimizer, num_epochs, dataset,
batch_size=64, workers=4, device=None, log_frequency=None,
grad_clip=5., entropy_weight=0.0001, skip_weight=0.8, baseline_decay=0.999,
ctrl_lr=0.00035, ctrl_steps_aggregate=20, ctrl_kwargs=None):
self.model = model
self.loss = loss
self.metrics = metrics
self.optimizer = optimizer
self.num_epochs = num_epochs
self.dataset = dataset
self.batch_size = batch_size
self.workers = workers
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device
self.log_frequency = log_frequency
self.nas_modules = []
replace_layer_choice(self.model, PathSamplingLayerChoice, self.nas_modules)
replace_input_choice(self.model, PathSamplingInputChoice, self.nas_modules)
for _, module in self.nas_modules:
module.to(self.device)
self.model.to(self.device)
self.nas_fields = [ReinforceField(name, len(module),
isinstance(module, PathSamplingLayerChoice) or module.n_chosen == 1)
for name, module in self.nas_modules]
self.controller = ReinforceController(self.nas_fields, **(ctrl_kwargs or {}))
self.grad_clip = grad_clip
self.reward_function = reward_function
self.ctrl_optim = optim.Adam(self.controller.parameters(), lr=ctrl_lr)
self.batch_size = batch_size
self.workers = workers
self.entropy_weight = entropy_weight
self.skip_weight = skip_weight
self.baseline_decay = baseline_decay
self.baseline = 0.
self.ctrl_steps_aggregate = ctrl_steps_aggregate
self.init_dataloader()
def init_dataloader(self):
n_train = len(self.dataset)
split = n_train // 2
indices = list(range(n_train))
train_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[:-split])
valid_sampler = torch.utils.data.sampler.SubsetRandomSampler(indices[-split:])
self.train_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=train_sampler,
num_workers=self.workers)
self.valid_loader = torch.utils.data.DataLoader(self.dataset,
batch_size=self.batch_size,
sampler=valid_sampler,
num_workers=self.workers)
def _train_model(self, epoch):
self.model.train()
self.controller.eval()
meters = AverageMeterGroup()
for step, (x, y) in enumerate(self.train_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self.optimizer.zero_grad()
self._resample()
logits = self.model(x)
metrics = self.metrics(logits, y)
loss = self.loss(logits, y)
loss.backward()
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clip)
self.optimizer.step()
metrics['loss'] = loss.item()
meters.update(metrics)
if self.log_frequency is not None and step % self.log_frequency == 0:
_logger.info('Model Epoch [%d/%d] Step [%d/%d] %s', epoch + 1,
self.num_epochs, step + 1, len(self.train_loader), meters)
def _train_controller(self, epoch):
self.model.eval()
self.controller.train()
meters = AverageMeterGroup()
self.ctrl_optim.zero_grad()
for ctrl_step, (x, y) in enumerate(self.valid_loader):
x, y = to_device(x, self.device), to_device(y, self.device)
self._resample()
with torch.no_grad():
logits = self.model(x)
metrics = self.metrics(logits, y)
reward = self.reward_function(logits, y)
if self.entropy_weight:
reward += self.entropy_weight * self.controller.sample_entropy.item()
self.baseline = self.baseline * self.baseline_decay + reward * (1 - self.baseline_decay)
loss = self.controller.sample_log_prob * (reward - self.baseline)
if self.skip_weight:
loss += self.skip_weight * self.controller.sample_skip_penalty
metrics['reward'] = reward
metrics['loss'] = loss.item()
metrics['ent'] = self.controller.sample_entropy.item()
metrics['log_prob'] = self.controller.sample_log_prob.item()
metrics['baseline'] = self.baseline
metrics['skip'] = self.controller.sample_skip_penalty
loss /= self.ctrl_steps_aggregate
loss.backward()
meters.update(metrics)
if (ctrl_step + 1) % self.ctrl_steps_aggregate == 0:
if self.grad_clip > 0:
nn.utils.clip_grad_norm_(self.controller.parameters(), self.grad_clip)
self.ctrl_optim.step()
self.ctrl_optim.zero_grad()
if self.log_frequency is not None and ctrl_step % self.log_frequency == 0:
_logger.info('RL Epoch [%d/%d] Step [%d/%d] %s', epoch + 1, self.num_epochs,
ctrl_step + 1, len(self.valid_loader), meters)
def _resample(self):
result = self.controller.resample()
for name, module in self.nas_modules:
module.sampled = result[name]
def fit(self):
for i in range(self.num_epochs):
self._train_model(i)
self._train_controller(i)
def export(self):
self.controller.eval()
with torch.no_grad():
return self.controller.resample()
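# Hedged usage sketch of EnasTrainer, in the same spirit as the DARTS example.
# ``model``, ``accuracy`` and ``reward_accuracy`` are illustrative assumptions:
# a mutable PyTorch model, a metrics callable returning a dict, and a reward
# callable returning a scalar from (logits, labels).
import torch
import torch.nn as nn
from torchvision import datasets, transforms

dataset = datasets.CIFAR10('data/cifar10', download=True, transform=transforms.ToTensor())
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9, weight_decay=1.0E-4)
trainer = EnasTrainer(model=model, loss=nn.CrossEntropyLoss(), metrics=accuracy,
                      reward_function=reward_accuracy, optimizer=optimizer,
                      num_epochs=10, dataset=dataset, batch_size=128, log_frequency=10)
trainer.fit()
print(trainer.export())  # a single architecture sampled from the trained controller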
"""
Entry point for trials.
The execution engine is BaseExecutionEngine by default; CGOExecutionEngine is used when the ``CGO`` environment variable is ``'true'``.
"""
import os
from .execution.base import BaseExecutionEngine
from .execution.cgo_engine import CGOExecutionEngine
if __name__ == '__main__':
if os.environ.get('CGO') == 'true':
CGOExecutionEngine.trial_execute_graph()
else:
BaseExecutionEngine.trial_execute_graph()
...@@ -41,7 +41,7 @@ jobs:
python3 -m pip install --upgrade pygments
python3 -m pip install --upgrade torch>=1.7.0+cpu torchvision>=0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html
python3 -m pip install --upgrade tensorflow
-python3 -m pip install --upgrade gym onnx peewee thop
+python3 -m pip install --upgrade gym onnx peewee thop graphviz
python3 -m pip install sphinx==1.8.3 sphinx-argparse==0.2.5 sphinx-markdown-tables==0.0.9 sphinx-rtd-theme==0.4.2 sphinxcontrib-websupport==1.1.0 recommonmark==0.5.0 nbsphinx
sudo apt-get install swig -y
python3 -m pip install -e .[SMAC,BOHB]
...
...@@ -2,4 +2,9 @@ __pycache__
tuner_search_space.json
tuner_result.txt
assessor_result.txt
\ No newline at end of file
+_generated_model.py
+data
+generated