Unverified commit 542a660d, authored by Yuge Zhang, committed by GitHub

[Retiarii] NAS-Bench-201 (#3920)

parent 7eedec46
import torch
import torch.nn as nn
# Candidate operations used on each edge of the NAS-Bench-201 cell.
# Each entry maps (C_in, C_out, stride) to a freshly constructed module.
OPS_WITH_STRIDE = {
    'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
    'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
    'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
    'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
    'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
    'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
                    else FactorizedReduce(C_in, C_out, stride),
}
PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
class ReLUConvBN(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(ReLUConvBN, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
                      padding=padding, dilation=dilation, bias=False),
            nn.BatchNorm2d(C_out)
        )

    def forward(self, x):
        return self.op(x)


class SepConv(nn.Module):
    def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
        super(SepConv, self).__init__()
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
                      padding=padding, dilation=dilation, groups=C_in, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out),
        )

    def forward(self, x):
        return self.op(x)


class Pooling(nn.Module):
    def __init__(self, C_in, C_out, stride, mode):
        super(Pooling, self).__init__()
        if C_in == C_out:
            self.preprocess = None
        else:
            self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
        if mode == 'avg':
            self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
        elif mode == 'max':
            self.op = nn.MaxPool2d(3, stride=stride, padding=1)
        else:
            raise ValueError('Invalid mode={:} in Pooling'.format(mode))

    def forward(self, x):
        if self.preprocess:
            x = self.preprocess(x)
        return self.op(x)
class Zero(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(Zero, self).__init__()
        self.C_in = C_in
        self.C_out = C_out
        self.stride = stride
        self.is_zero = True

    def forward(self, x):
        if self.C_in == self.C_out:
            if self.stride == 1:
                return x.mul(0.)
            else:
                return x[:, :, ::self.stride, ::self.stride].mul(0.)
        else:
            shape = list(x.shape)
            shape[1] = self.C_out
            zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
            return zeros


class FactorizedReduce(nn.Module):
    def __init__(self, C_in, C_out, stride):
        super(FactorizedReduce, self).__init__()
        self.stride = stride
        self.C_in = C_in
        self.C_out = C_out
        self.relu = nn.ReLU(inplace=False)
        if stride == 2:
            C_outs = [C_out // 2, C_out - C_out // 2]
            self.convs = nn.ModuleList()
            for i in range(2):
                self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
            self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
        else:
            raise ValueError('Invalid stride : {:}'.format(stride))
        self.bn = nn.BatchNorm2d(C_out)

    def forward(self, x):
        x = self.relu(x)
        y = self.pad(x)
        out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
        out = self.bn(out)
        return out
class ResNetBasicblock(nn.Module):
    def __init__(self, inplanes, planes, stride):
        super(ResNetBasicblock, self).__init__()
        assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
        self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
        self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
        if stride == 2:
            self.downsample = nn.Sequential(
                nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
                nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
        elif inplanes != planes:
            self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
        else:
            self.downsample = None
        self.in_dim = inplanes
        self.out_dim = planes
        self.stride = stride
        self.num_conv = 2

    def forward(self, inputs):
        basicblock = self.conv_a(inputs)
        basicblock = self.conv_b(basicblock)
        if self.downsample is not None:
            inputs = self.downsample(inputs)  # residual
        return inputs + basicblock
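

# Minimal smoke test (not part of the original commit). It assumes this file is run
# directly and only checks output shapes: stride-1 candidates preserve the spatial
# resolution, while the stride-2 ResNet block halves it and doubles the channels.
if __name__ == '__main__':
    x = torch.randn(2, 16, 32, 32)
    for name, builder in OPS_WITH_STRIDE.items():
        op = builder(16, 16, 1)
        assert op(x).shape == (2, 16, 32, 32), name
    reduction = ResNetBasicblock(16, 32, 2)
    assert reduction(x).shape == (2, 32, 16, 16)
    print('base_ops smoke test passed')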
import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench201Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR100
from base_ops import ResNetBasicblock, PRIMITIVES, OPS_WITH_STRIDE
@model_wrapper
class NasBench201(nn.Module):
    def __init__(self,
                 stem_out_channels: int = 16,
                 num_modules_per_stack: int = 5,
                 num_labels: int = 100):
        super().__init__()
        self.channels = C = stem_out_channels
        self.num_modules = N = num_modules_per_stack
        self.num_labels = num_labels

        self.stem = nn.Sequential(
            nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(C)
        )

        # Three stages of N searched cells each, separated by two fixed reduction blocks.
        layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
        layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N

        C_prev = C
        self.cells = nn.ModuleList()
        for C_curr, reduction in zip(layer_channels, layer_reductions):
            if reduction:
                cell = ResNetBasicblock(C_prev, C_curr, 2)
            else:
                # Bind ``prim`` as a default argument so every lambda keeps its own
                # primitive instead of all closing over the last loop value.
                cell = NasBench201Cell({prim: lambda C_in, C_out, prim=prim: OPS_WITH_STRIDE[prim](C_in, C_out, 1)
                                        for prim in PRIMITIVES},
                                       C_prev, C_curr, label='cell')
            self.cells.append(cell)
            C_prev = C_curr

        self.lastact = nn.Sequential(
            nn.BatchNorm2d(C_prev),
            nn.ReLU(inplace=True)
        )
        self.global_pooling = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(C_prev, self.num_labels)

    def forward(self, inputs):
        feature = self.stem(inputs)
        for cell in self.cells:
            feature = cell(feature)

        out = self.lastact(feature)
        out = self.global_pooling(out)
        out = out.view(out.size(0), -1)
        logits = self.classifier(out)
        return logits
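

# Note (not part of the original script): after the search finishes, a sampled
# architecture can be re-instantiated as a plain PyTorch module with
# ``nni.retiarii.fixed_arch``. The choice keys follow the ``cell__{j}_{i}`` labels
# created by NasBench201Cell; the selection below is only a hypothetical example.
#
#     import torch
#     from nni.retiarii import fixed_arch
#     arch = {f'cell__{j}_{i}': 'conv_3x3' for i in range(1, 4) for j in range(i)}
#     with fixed_arch(arch):
#         model = NasBench201()
#     print(model(torch.randn(1, 3, 32, 32)).shape)  # expected: torch.Size([1, 100])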
class AccuracyWithLogits(torchmetrics.Accuracy):
    def update(self, pred, target):
        # Softmax over the class dimension; the metric then takes the argmax internally.
        return super().update(nn.functional.softmax(pred, dim=-1), target)
@serialize_cls
class NasBench201TrainingModule(pl.LightningModule):
    def __init__(self, max_epochs=200, learning_rate=0.1, weight_decay=5e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = AccuracyWithLogits()

    def forward(self, x):
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay,
                              momentum=0.9, alpha=0.9, eps=1.0)
        return {
            'optimizer': optimizer,
            # PyTorch Lightning expects the scheduler under the 'lr_scheduler' key.
            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
        }

    def on_validation_epoch_end(self):
        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

    def teardown(self, stage):
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())
@click.command()
@click.option('--epochs', default=12, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # Initialize the dataset. Note that the 50k/10k split is used, which differs slightly from the paper.
    transf = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    normalize = [
        transforms.ToTensor(),
        transforms.Normalize([x / 255 for x in [129.3, 124.1, 112.4]], [x / 255 for x in [68.2, 65.4, 70.4]])
    ]
    train_dataset = serialize(CIFAR100, 'data', train=True, download=True, transform=transforms.Compose(transf + normalize))
    test_dataset = serialize(CIFAR100, 'data', train=False, transform=transforms.Compose(normalize))

    # specify training hyper-parameters
    training_module = NasBench201TrainingModule(max_epochs=epochs)
    # FIXME: need to fix a bug in serializer for this to work
    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
    trainer = pl.Trainer(max_epochs=epochs, gpus=1)
    lightning = pl.Lightning(
        lightning_module=training_module,
        trainer=trainer,
        train_dataloader=pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True),
        val_dataloaders=pl.DataLoader(test_dataset, batch_size=batch_size),
    )

    strategy = Random()
    model = NasBench201()

    exp = RetiariiExperiment(model, lightning, [], strategy)
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, port)


if __name__ == '__main__':
    _multi_trial_test()
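

# Example invocation (hypothetical script name; assumes NNI with Retiarii, pytorch-lightning,
# timm and torchvision are installed and a GPU is available):
#
#     python nasbench201_example.py --epochs 12 --batch_size 256 --port 8081
#
# This launches a local Retiarii experiment that randomly samples up to 20 architectures
# from the NAS-Bench-201 space and trains each on CIFAR-100, reporting accuracy to NNI.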
import copy
from collections import OrderedDict
from typing import Callable, List, Union, Tuple, Optional
import torch
@@ -12,7 +13,7 @@ from .utils import generate_new_label, get_fixed_value
from ...utils import NoContextError
-__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator']
+__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator', 'NasBench201Cell']
class Repeat(nn.Module):
@@ -147,3 +148,77 @@ class Cell(nn.Module):
            current_state = torch.sum(torch.stack(current_state), 0)
            states.append(current_state)
        return torch.cat(states[self.num_predecessors:], 1)
class NasBench201Cell(nn.Module):
"""
Cell structure that is proposed in NAS-Bench-201 [nasbench201]_ .
This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is tensor.
For every i < j, there is an edge from i-th node to j-th node.
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
and returns a ``Module``.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. For example,
The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
and :math:`N` is defined by ``num_tensors``.
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function accepts input feature and output feature, returning nn.Module.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
num_tensors : int
Number of tensors in the cell (input included). Default: 4
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
References
----------
.. [nasbench201] Dong, X. and Yang, Y., 2020. Nas-bench-201: Extending the scope of reproducible neural architecture search.
arXiv preprint arXiv:2001.00326.
"""
    @staticmethod
    def _make_dict(x):
        if isinstance(x, list):
            return OrderedDict([(str(i), t) for i, t in enumerate(x)])
        return OrderedDict(x)

    def __init__(self, op_candidates: List[Callable[[int, int], nn.Module]],
                 in_features: int, out_features: int, num_tensors: int = 4,
                 label: Optional[str] = None):
        super().__init__()
        self._label = generate_new_label(label)

        self.layers = nn.ModuleList()
        self.in_features = in_features
        self.out_features = out_features
        self.num_tensors = num_tensors

        op_candidates = self._make_dict(op_candidates)

        for tid in range(1, num_tensors):
            node_ops = nn.ModuleList()
            for j in range(tid):
                inp = in_features if j == 0 else out_features
                op_choices = OrderedDict([(key, cls(inp, out_features))
                                          for key, cls in op_candidates.items()])
                node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}'))
            self.layers.append(node_ops)

    def forward(self, inputs):
        tensors = [inputs]
        for layer in self.layers:
            current_tensor = []
            for i, op in enumerate(layer):
                current_tensor.append(op(tensors[i]))
            current_tensor = torch.sum(torch.stack(current_tensor), 0)
            tensors.append(current_tensor)
        return tensors[-1]
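

# Hypothetical usage sketch (not part of this module): a cell over two linear candidates.
# With the default ``num_tensors=4`` there are 6 searchable edges, i.e. 2 ** 6 = 64 possible cells.
#
#     cell = NasBench201Cell([lambda fin, fout: nn.Linear(fin, fout),
#                             lambda fin, fout: nn.Linear(fin, fout, bias=False)],
#                            in_features=10, out_features=16)
#     out = cell(torch.randn(2, 10))  # shape: [2, 16]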
@@ -493,6 +493,27 @@ class GraphIR(unittest.TestCase):
                model = mutator.bind_sampler(sampler).apply(model)
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
    def test_nasbench201_cell(self):
        @self.get_serializer()
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.cell = nn.NasBench201Cell([
                    lambda x, y: nn.Linear(x, y),
                    lambda x, y: nn.Linear(x, y, bias=False)
                ], 10, 16)

            def forward(self, x):
                return self.cell(x)

        raw_model, mutators = self._get_model_with_mutators(Net())
        for _ in range(10):
            sampler = EnumerateSampler()
            model = raw_model
            for mutator in mutators:
                model = mutator.bind_sampler(sampler).apply(model)
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
class Python(GraphIR):
    def _get_converted_pytorch_model(self, model_ir):
......