Unverified commit 9f32a06f, authored by Yuge Zhang, committed by GitHub
Browse files

[Retiarii] NAS-Bench-101 (#3871)

parent dde4d862
import math
import torch.nn as nn
def truncated_normal_(tensor, mean=0, std=1):
    """
    Fill ``tensor`` in place with samples from a normal distribution truncated at two standard deviations.

    Four standard-normal candidates are drawn per element and one that falls inside ``(-2, 2)``
    is kept; the kept values are then scaled by ``std`` and shifted by ``mean``.
    Adapted from https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15

    Parameters
    ----------
    tensor : torch.Tensor
        Tensor to overwrite.
    mean : float
        Mean of the resulting distribution.
    std : float
        Standard deviation of the (untruncated) distribution.
    """
    # One extra trailing axis holding the four candidates of every element.
    candidates = tensor.new_empty(tensor.shape + (4,)).normal_()
    in_range = (candidates > -2) & (candidates < 2)
    # ``max`` over the boolean mask reports the position of an in-range candidate
    # (if all four are out of range, the reported position is still used as-is).
    _, picked = in_range.max(-1, keepdim=True)
    tensor.data.copy_(candidates.gather(-1, picked).squeeze(-1))
    tensor.data.mul_(std).add_(mean)
class ConvBnRelu(nn.Module):
    """
    Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an in-place ``ReLU``.

    Convolution weights are initialised from a truncated normal with ``std = sqrt(1 / fan_in)``;
    batch-norm weights are set to one and biases to zero.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
        super(ConvBnRelu, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False)
        norm = nn.BatchNorm2d(out_channels)
        activation = nn.ReLU(inplace=True)
        self.conv_bn_relu = nn.Sequential(conv, norm, activation)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every convolution and batch-norm layer contained in this module."""
        for layer in self.modules():
            if isinstance(layer, nn.Conv2d):
                kernel_h, kernel_w = layer.kernel_size
                fan_in = kernel_h * kernel_w * layer.in_channels
                truncated_normal_(layer.weight.data, mean=0., std=math.sqrt(1. / fan_in))
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.fill_(1)
                layer.bias.data.zero_()

    def forward(self, x):
        """Apply convolution, batch-norm and ReLU to ``x``."""
        return self.conv_bn_relu(x)
class Conv3x3BnRelu(ConvBnRelu):
    """``ConvBnRelu`` with a 3x3 kernel, stride 1 and padding 1."""

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
class Conv1x1BnRelu(ConvBnRelu):
    """``ConvBnRelu`` with a 1x1 kernel, stride 1 and no padding."""

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
# ``Projection`` is simply another name for the 1x1 conv-bn-relu block.
Projection = Conv1x1BnRelu
import click
import nni
import nni.retiarii.evaluator.pytorch.lightning as pl
import torch.nn as nn
import torchmetrics
from nni.retiarii import model_wrapper, serialize, serialize_cls
from nni.retiarii.experiment.pytorch import RetiariiExperiment, RetiariiExeConfig
from nni.retiarii.nn.pytorch import NasBench101Cell
from nni.retiarii.strategy import Random
from pytorch_lightning.callbacks import LearningRateMonitor
from timm.optim import RMSpropTF
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torchvision.datasets import CIFAR10
from base_ops import Conv3x3BnRelu, Conv1x1BnRelu, Projection
@model_wrapper
class NasBench101(nn.Module):
    """
    Image classifier assembled from stacked ``NasBench101Cell`` modules.

    A 3x3 conv-bn-relu stem (3 input channels) is followed by ``num_stacks`` stacks of
    ``num_modules_per_stack`` cells. Every stack after the first starts with a 2x2 max-pool
    and doubles the channel count. Global average pooling and a linear layer produce the logits.
    All cells are created with the same label (``'cell'``).

    Parameters
    ----------
    stem_out_channels : int
        Output channels of the stem convolution (and of the first stack).
    num_stacks : int
        Number of stacks.
    num_modules_per_stack : int
        Number of cells in each stack.
    max_num_vertices : int
        Passed to ``NasBench101Cell`` as the maximum number of nodes.
    max_num_edges : int
        Passed to ``NasBench101Cell`` as the maximum number of edges.
    num_labels : int
        Number of output classes.
    bn_eps : float
        ``eps`` written to every ``BatchNorm2d`` in the network.
    bn_momentum : float
        ``momentum`` written to every ``BatchNorm2d`` in the network.
    """

    def __init__(self,
                 stem_out_channels: int = 128,
                 num_stacks: int = 3,
                 num_modules_per_stack: int = 3,
                 max_num_vertices: int = 7,
                 max_num_edges: int = 9,
                 num_labels: int = 10,
                 bn_eps: float = 1e-5,
                 bn_momentum: float = 0.003):
        super().__init__()

        # Remembered so that reset_parameters() can re-apply them later.
        self.bn_eps = bn_eps
        self.bn_momentum = bn_momentum

        op_candidates = {
            'conv3x3': lambda num_features: Conv3x3BnRelu(num_features, num_features),
            'conv1x1': lambda num_features: Conv1x1BnRelu(num_features, num_features),
            'maxpool': lambda num_features: nn.MaxPool2d(3, 1, 1)
        }

        # initial stem convolution
        self.stem_conv = Conv3x3BnRelu(3, stem_out_channels)

        layers = []
        in_channels = out_channels = stem_out_channels
        for stack_num in range(num_stacks):
            if stack_num > 0:
                # Halve the resolution and double the width between stacks.
                downsample = nn.MaxPool2d(kernel_size=2, stride=2)
                layers.append(downsample)
                out_channels *= 2
            for _ in range(num_modules_per_stack):
                cell = NasBench101Cell(op_candidates, in_channels, out_channels,
                                       lambda cin, cout: Projection(cin, cout),
                                       max_num_vertices, max_num_edges, label='cell')
                layers.append(cell)
                in_channels = out_channels

        self.features = nn.ModuleList(layers)
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(out_channels, num_labels)

        self.reset_parameters()

    def forward(self, x):
        """Return the class logits for a batch of images ``x``."""
        bs = x.size(0)
        out = self.stem_conv(x)
        for layer in self.features:
            out = layer(out)
        out = self.gap(out).view(bs, -1)
        out = self.classifier(out)
        return out

    def reset_parameters(self):
        """Write the configured ``eps`` and ``momentum`` to every ``BatchNorm2d`` in the network."""
        for module in self.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eps = self.bn_eps
                module.momentum = self.bn_momentum
class AccuracyWithLogits(torchmetrics.Accuracy):
    """Accuracy metric that accepts raw logits and converts them to probabilities first."""

    def update(self, pred, target):
        """Apply softmax over the class axis of ``pred`` and forward to ``torchmetrics.Accuracy.update``."""
        # The class axis is spelled out: calling softmax without ``dim`` is deprecated.
        # For 2-D ``(batch, classes)`` logits this is the axis the implicit form picked.
        return super().update(nn.functional.softmax(pred, dim=1), target)
@serialize_cls
class NasBench101TrainingModule(pl.LightningModule):
    """
    Lightning module that trains a model with cross-entropy loss, RMSpropTF and cosine annealing,
    reporting validation accuracy to NNI.

    Parameters
    ----------
    max_epochs : int
        Period (``T_max``) handed to ``CosineAnnealingLR``.
    learning_rate : float
        Initial learning rate of the optimizer.
    weight_decay : float
        Weight decay of the optimizer.
    """

    def __init__(self, max_epochs=108, learning_rate=0.1, weight_decay=1e-4):
        super().__init__()
        self.save_hyperparameters('learning_rate', 'weight_decay', 'max_epochs')
        self.criterion = nn.CrossEntropyLoss()
        self.accuracy = AccuracyWithLogits()

    def forward(self, x):
        """Run the wrapped model on ``x``."""
        # NOTE(review): ``self.model`` is not assigned in this class; presumably it is
        # attached by the Retiarii evaluator before training -- confirm.
        y_hat = self.model(x)
        return y_hat

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and accuracy; return the loss."""
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)
        self.log('train_loss', loss, prog_bar=True)
        self.log('train_accuracy', self.accuracy(y_hat, y), prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and accuracy of one batch."""
        x, y = batch
        y_hat = self(x)
        self.log('val_loss', self.criterion(y_hat, y), prog_bar=True)
        self.log('val_accuracy', self.accuracy(y_hat, y), prog_bar=True)

    def configure_optimizers(self):
        """Return the optimizer together with its learning-rate scheduler."""
        optimizer = RMSpropTF(self.parameters(), lr=self.hparams.learning_rate,
                              weight_decay=self.hparams.weight_decay,
                              momentum=0.9, alpha=0.9, eps=1.0)
        return {
            'optimizer': optimizer,
            # Lightning looks the scheduler up under 'lr_scheduler'; any other key is not
            # picked up, which would leave the learning rate constant.
            'lr_scheduler': CosineAnnealingLR(optimizer, self.hparams.max_epochs)
        }

    def on_validation_epoch_end(self):
        """Report the latest validation accuracy to NNI as an intermediate result."""
        nni.report_intermediate_result(self.trainer.callback_metrics['val_accuracy'].item())

    def teardown(self, stage):
        """Report the last validation accuracy to NNI as the final result once fitting is over."""
        if stage == 'fit':
            nni.report_final_result(self.trainer.callback_metrics['val_accuracy'].item())
# Command-line entry: builds the CIFAR-10 datasets, the Lightning evaluator and a random-search
# Retiarii experiment over ``NasBench101``, then runs it on the given port.
# (Documented with comments on purpose: click would print a docstring as ``--help`` text.)
@click.command()
@click.option('--epochs', default=108, help='Training length.')
@click.option('--batch_size', default=256, help='Batch size.')
@click.option('--port', default=8081, help='On which port the experiment is run.')
def _multi_trial_test(epochs, batch_size, port):
    # Initialize the datasets. The 50k training / 10k test split is used here,
    # which differs slightly from the paper.
    augmentation = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip()
    ]
    to_normalized_tensor = [
        transforms.ToTensor(),
        transforms.Normalize([0.49139968, 0.48215827, 0.44653124], [0.24703233, 0.24348505, 0.26158768])
    ]
    train_transform = transforms.Compose(augmentation + to_normalized_tensor)
    train_dataset = serialize(CIFAR10, 'data', train=True, download=True, transform=train_transform)
    test_transform = transforms.Compose(to_normalized_tensor)
    test_dataset = serialize(CIFAR10, 'data', train=False, transform=test_transform)

    # Training hyper-parameters.
    training_module = NasBench101TrainingModule(max_epochs=epochs)

    # FIXME: need to fix a bug in serializer for this to work
    # lr_monitor = serialize(LearningRateMonitor, logging_interval='step')
    trainer = pl.Trainer(max_epochs=epochs, gpus=1)

    train_loader = pl.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = pl.DataLoader(test_dataset, batch_size=batch_size)
    lightning = pl.Lightning(
        lightning_module=training_module,
        trainer=trainer,
        train_dataloader=train_loader,
        val_dataloaders=val_loader,
    )

    strategy = Random()
    model = NasBench101()
    exp = RetiariiExperiment(model, lightning, [], strategy)

    # Local training service: 2 concurrent trials, 20 trials in total, one GPU each.
    exp_config = RetiariiExeConfig('local')
    exp_config.trial_concurrency = 2
    exp_config.max_trial_number = 20
    exp_config.trial_gpu_number = 1
    exp_config.training_service.use_active_gpu = False

    exp.run(exp_config, port)
# Run the experiment only when executed as a script; the options are supplied by click.
if __name__ == '__main__':
    _multi_trial_test()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import (Any, Iterable, List, Optional)
from typing import (Any, Iterable, List, Optional, Tuple)
from .graph import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator']
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
Choice = Any
......@@ -77,7 +77,7 @@ class Mutator:
self._cur_choice_idx = None
return copy
def dry_run(self, model: Model) -> List[List[Choice]]:
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on same or different models,
......@@ -115,3 +115,7 @@ class _RecorderSampler(Sampler):
def choice(self, candidates: List[Choice], *args) -> Choice:
    """Record the offered ``candidates`` and always answer with the first one."""
    self.recorded_candidates.append(candidates)
    return candidates[0]
class InvalidMutation(Exception):
    """Raised when a sampled mutation does not describe a valid model."""
......@@ -3,13 +3,13 @@
import copy
import warnings
from collections import OrderedDict
from typing import Any, List, Union, Dict, Optional
import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import NoContextError
from .utils import generate_new_label, get_fixed_value
......@@ -26,6 +26,8 @@ class LayerChoice(nn.Module):
----------
candidates : list of nn.Module or OrderedDict
A module list to be selected from.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the layer choice.
......@@ -55,17 +57,21 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
# FIXME: prior is designed but not supported yet
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
            prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
    """
    Return the chosen candidate directly when a fixed context is active,
    otherwise create a regular ``LayerChoice`` instance.
    """
    try:
        chosen = get_fixed_value(label)
        if isinstance(candidates, list):
            # List candidates are addressed by position.
            return candidates[int(chosen)]
        else:
            return candidates[chosen]
    except NoContextError:
        # No fixed context: fall back to normal object construction.
        return super().__new__(cls)
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
......@@ -75,10 +81,12 @@ class LayerChoice(nn.Module):
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
self.names = []
if isinstance(candidates, OrderedDict):
if isinstance(candidates, dict):
for name, module in candidates.items():
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
......@@ -169,17 +177,23 @@ class InputChoice(nn.Module):
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the input choice.
"""
def __new__(cls, n_candidates: int, n_chosen: Optional[int] = 1,
            reduction: str = 'sum', *,
            prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
    """
    Return a ``ChosenInputs`` built from the fixed value when a fixed context is active,
    otherwise create a regular ``InputChoice`` instance.
    """
    try:
        return ChosenInputs(get_fixed_value(label), reduction=reduction)
    except NoContextError:
        # No fixed context: fall back to normal object construction.
        return super().__new__(cls)
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
......@@ -191,6 +205,7 @@ class InputChoice(nn.Module):
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction = reduction
self.prior = prior or [1 / n_candidates for _ in range(n_candidates)]
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = generate_new_label(label)
......@@ -277,19 +292,25 @@ class ValueChoice(Translatable, nn.Module):
----------
candidates : list
List of values to choose from.
prior : list of float
Prior distribution to sample from.
label : str
Identifier of the value choice.
"""
# FIXME: prior is designed but not supported yet
def __new__(cls, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
    """
    Return the fixed value itself when a fixed context is active,
    otherwise create a regular ``ValueChoice`` instance.
    """
    try:
        return get_fixed_value(label)
    except NoContextError:
        # No fixed context: fall back to normal object construction.
        return super().__new__(cls)
def __init__(self, candidates: List[Any], label: Optional[str] = None):
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
self._accessor = []
......@@ -323,7 +344,7 @@ class ValueChoice(Translatable, nn.Module):
return self
def __deepcopy__(self, memo):
    """Create a new ``ValueChoice`` with the same candidates, prior and label and a copied accessor list."""
    # ``label`` and ``prior`` are keyword-only; the prior is forwarded so the copy
    # does not silently fall back to a uniform one.
    new_item = ValueChoice(self.candidates, prior=self.prior, label=self.label)
    new_item._accessor = [*self._accessor]
    return new_item
......
......@@ -7,10 +7,12 @@ import torch.nn as nn
from .api import LayerChoice, InputChoice
from .nn import ModuleList
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import generate_new_label, get_fixed_value
from ...utils import NoContextError
__all__ = ['Repeat', 'Cell']
__all__ = ['Repeat', 'Cell', 'NasBench101Cell', 'NasBench101Mutator']
class Repeat(nn.Module):
......@@ -33,7 +35,7 @@ class Repeat(nn.Module):
try:
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
except AssertionError:
except NoContextError:
return super().__new__(cls)
def __init__(self,
......
......@@ -9,7 +9,7 @@ import torch.nn as nn
from ...mutator import Mutator
from ...graph import Cell, Graph, Model, ModelStatus, Node
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat
from .component import Repeat, NasBench101Cell, NasBench101Mutator
from ...utils import uid
......@@ -47,7 +47,12 @@ class InputChoiceMutator(Mutator):
n_candidates = self.nodes[0].operation.parameters['n_candidates']
n_chosen = self.nodes[0].operation.parameters['n_chosen']
candidates = list(range(n_candidates))
chosen = [self.choice(candidates) for _ in range(n_chosen)]
if n_chosen is None:
chosen = [i for i in candidates if self.choice([False, True])]
# FIXME This is a hack to make choice align with the previous format
self._cur_samples = chosen
else:
chosen = [self.choice(candidates) for _ in range(n_chosen)]
for node in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation('__torch__.nni.retiarii.nn.pytorch.ChosenInputs',
......@@ -199,8 +204,15 @@ class ManyChooseManyMutator(Mutator):
def mutate(self, model: Model):
    """
    Record the choices for this mutator's label in the mutation history.

    Only the first node carrying the label is inspected. When its number of chosen is
    ``None``, every candidate is kept or dropped by an independent True/False choice.
    """
    # this mutate does not have any effect, but it is recorded in the mutation history
    for node in model.get_nodes_by_label(self.label):
        n_chosen = self.number_of_chosen(node)
        if n_chosen is None:
            candidates = [i for i in self.candidates(node) if self.choice([False, True])]
            # FIXME This is a hack to make choice align with the previous format
            # For example, it will convert [False, True, True] into [1, 2].
            self._cur_samples = candidates
        else:
            for _ in range(n_chosen):
                self.choice(self.candidates(node))
        break
......@@ -242,6 +254,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
'candidates': list(range(module.min_depth, module.max_depth + 1))
})
node.label = module.label
if isinstance(module, NasBench101Cell):
node = graph.add_node(name, 'NasBench101Cell', {
'max_num_edges': module.max_num_edges
})
node.label = module.label
if isinstance(module, Placeholder):
raise NotImplementedError('Placeholder is not supported in python execution mode.')
......@@ -250,13 +267,17 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
return model, None
mutators = []
mutators_final = []
for nodes in _group_by_label_and_type(graph.hidden_nodes):
assert _is_all_equal(map(lambda n: n.operation.type, nodes)), \
f'Node with label "{nodes[0].label}" does not all have the same type.'
assert _is_all_equal(map(lambda n: n.operation.parameters, nodes)), \
f'Node with label "{nodes[0].label}" does not agree on parameters.'
mutators.append(ManyChooseManyMutator(nodes[0].label))
return model, mutators
if nodes[0].operation.type == 'NasBench101Cell':
mutators_final.append(NasBench101Mutator(nodes[0].label))
else:
mutators.append(ManyChooseManyMutator(nodes[0].label))
return model, mutators + mutators_final
# utility functions
......
import logging
from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict
import numpy as np
import torch
import torch.nn as nn
from .api import InputChoice, ValueChoice, LayerChoice
from .utils import generate_new_label, get_fixed_dict
from ...mutator import InvalidMutation, Mutator
from ...graph import Model
from ...utils import NoContextError
_logger = logging.getLogger(__name__)
def compute_vertex_channels(input_channels, output_channels, matrix):
    """
    Compute the channel count of every vertex of a cell.

    This is (almost) copied from the original NAS-Bench-101 implementation.
    The output channels are split among the interior vertices wired directly to the output;
    when the split is uneven, the first such vertices get one extra channel each.
    Every other interior vertex takes the largest channel count among the interior
    vertices it feeds.

    Parameters
    ----------
    input_channels : int
        Input channel count.
    output_channels : int
        Output channel count.
    matrix : np.ndarray
        Adjacency matrix of the module (pruned by model_spec); axis 0 is the source vertex.

    Returns
    -------
    list of int
        Channel counts, in vertex order.
    """
    num_vertices = np.shape(matrix)[0]
    out_vertex = num_vertices - 1

    vertex_channels = [0] * num_vertices
    vertex_channels[0] = input_channels
    vertex_channels[out_vertex] = output_channels

    if num_vertices == 2:
        # Only the input and output vertices exist.
        return vertex_channels

    # Column sums over rows 1.. give each vertex's in-degree with the input vertex left out.
    in_degree = np.sum(matrix[1:], axis=0)
    interior_channels = output_channels // in_degree[out_vertex]
    correction = output_channels % in_degree[out_vertex]  # leftover channels to hand out

    # Vertices that flow directly into the output.
    for v in range(1, out_vertex):
        if not matrix[v, out_vertex]:
            continue
        vertex_channels[v] = interior_channels
        if correction:
            vertex_channels[v] += 1
            correction -= 1

    # Remaining vertices, walked backwards; vertex (num_vertices - 2) is skipped
    # because it can only connect to the output.
    for v in reversed(range(1, num_vertices - 2)):
        if not matrix[v, out_vertex]:
            downstream = [vertex_channels[dst] for dst in range(v + 1, out_vertex) if matrix[v, dst]]
            vertex_channels[v] = max([vertex_channels[v]] + downstream)
        assert vertex_channels[v] > 0

    _logger.debug('vertex_channels: %s', str(vertex_channels))

    # Sanity check: channels never increase along an edge and the output fan-in adds up.
    final_fan_in = 0
    for v in range(1, out_vertex):
        if matrix[v, out_vertex]:
            final_fan_in += vertex_channels[v]
        for dst in range(v + 1, out_vertex):
            if matrix[v, dst]:
                assert vertex_channels[v] >= vertex_channels[dst]
    # num_vertices == 2 means only input/output nodes, so 0 fan-in
    assert final_fan_in == output_channels or num_vertices == 2

    return vertex_channels
def prune(matrix, ops):
    """
    Drop the vertices that do not take part in the computation.

    A vertex is kept only if it is reachable from the input (first vertex) and the output
    (last vertex) is reachable from it. Rows and columns of all other vertices are deleted,
    which also renumbers the remaining vertices consecutively. ``ops`` is edited in place.

    Returns
    -------
    tuple
        The pruned adjacency matrix and the (same) ``ops`` list with the dropped entries removed.

    Raises
    ------
    InvalidMutation
        If fewer than two vertices remain, i.e. the input is not connected to the output.
    """
    num_vertices = np.shape(matrix)[0]

    # (A + I)^V is non-zero at [i, j] exactly when j is reachable from i within V steps.
    reachability = np.linalg.matrix_power(matrix + np.eye(num_vertices), num_vertices)

    reachable_from_input = {v for v in range(num_vertices) if reachability[0, v]}
    reaching_output = {v for v in range(num_vertices) if reachability[v, -1]}

    # Anything not on a path from input to output is extraneous.
    useful = reachable_from_input & reaching_output
    extraneous = set(range(num_vertices)) - useful

    if len(extraneous) > num_vertices - 2:
        raise InvalidMutation('Non-extraneous graph is less than 2 vertices, '
                              'the input is not connected to the output and the spec is invalid.')

    dropped = list(extraneous)
    matrix = np.delete(np.delete(matrix, dropped, axis=0), dropped, axis=1)
    # Delete from the back so earlier indices stay valid.
    for index in sorted(extraneous, reverse=True):
        del ops[index]

    return matrix, ops
def truncate(inputs, channels):
    """
    Cut ``inputs`` down to ``channels`` channels along dimension 1.

    Raises ``ValueError`` when ``inputs`` has fewer channels than requested; asserts that
    at most one channel has to be removed.
    """
    available = inputs.size(1)
    if available < channels:
        raise ValueError('input channel < output channels for truncate')
    if available == channels:
        # Nothing to cut.
        return inputs
    # Truncation should only be needed when dividing the channels left some vertices
    # with one extra channel. The input vertex should always be projected to the
    # minimum channel count.
    assert available - channels == 1
    return inputs[:, :channels]
class _NasBench101CellFixed(nn.Module):
    """
    The fixed version of NAS-Bench-101 Cell, used in python-version execution engine.

    Parameters
    ----------
    operations : list of callable
        One factory per hidden node; each takes the node's feature count and returns a module.
    adjacency_list : list of list of int
        For node ``i + 1``, the indices of the earlier nodes feeding it.
    in_features : int
        Channel count of the cell input.
    out_features : int
        Channel count of the cell output.
    num_nodes : int
        Number of nodes before pruning, input and output included.
    projection : callable
        Takes input and output feature counts and returns the module applied to the cell input.
    """

    def __init__(self, operations: List[Callable[[int], nn.Module]],
                 adjacency_list: List[List[int]],
                 in_features: int, out_features: int, num_nodes: int,
                 projection: Callable[[int, int], nn.Module]):
        super().__init__()

        assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1

        self.operations = ['IN'] + operations + ['OUT']  # add pseudo nodes
        self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
        del num_nodes  # raw number of nodes is no longer used

        # From here on vertex indices refer to the pruned graph.
        self.connection_matrix, self.operations = prune(self.connection_matrix, self.operations)

        self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)

        self.num_nodes = len(self.connection_matrix)
        self.in_features = in_features
        self.out_features = out_features
        _logger.info('Prund number of nodes: %d', self.num_nodes)
        _logger.info('Pruned connection matrix: %s', str(self.connection_matrix))

        # Index 0 is a placeholder so that list positions match vertex indices.
        self.projections = nn.ModuleList([nn.Identity()])
        self.ops = nn.ModuleList([nn.Identity()])
        for i in range(1, self.num_nodes):
            self.projections.append(projection(in_features, self.hidden_features[i]))

        for i in range(1, self.num_nodes - 1):
            # Use the pruned operation list: indexing the raw ``operations`` argument would
            # pick the wrong factory whenever pruning removed an earlier vertex.
            self.ops.append(self.operations[i](self.hidden_features[i]))

    @staticmethod
    def build_connection_matrix(adjacency_list, num_nodes):
        """Turn per-node input lists into a ``num_nodes x num_nodes`` 0/1 matrix (row = source, column = destination)."""
        adjacency_list = [[]] + adjacency_list  # add adjacency for first node
        connections = np.zeros((num_nodes, num_nodes), dtype='int')
        for i, lst in enumerate(adjacency_list):
            # Only edges from earlier nodes are allowed.
            assert all([0 <= k < i for k in lst])
            for k in lst:
                connections[k, i] = 1
        return connections

    def forward(self, inputs):
        """Evaluate the pruned cell on ``inputs`` and return the output node's tensor."""
        tensors = [inputs]
        for t in range(1, self.num_nodes - 1):

            # Create interior connections, truncating if necessary
            add_in = [truncate(tensors[src], self.hidden_features[t])
                      for src in range(1, t) if self.connection_matrix[src, t]]

            # Create add connection from projected input
            if self.connection_matrix[0, t]:
                add_in.append(self.projections[t](tensors[0]))

            if len(add_in) == 1:
                vertex_input = add_in[0]
            else:
                vertex_input = sum(add_in)

            # Perform op at vertex t
            vertex_out = self.ops[t](vertex_input)
            tensors.append(vertex_out)

        # Construct final output tensor by concatenating all fan-in and adding input.
        if np.sum(self.connection_matrix[:, -1]) == 1:
            # A single edge into the output: pass that tensor through (projected if it is the input).
            src = np.where(self.connection_matrix[:, -1] == 1)[0][0]
            return self.projections[-1](tensors[0]) if src == 0 else tensors[src]

        outputs = torch.cat([tensors[src] for src in range(1, self.num_nodes - 1) if self.connection_matrix[src, -1]], 1)
        if self.connection_matrix[0, -1]:
            outputs += self.projections[-1](tensors[0])
        assert outputs.size(1) == self.out_features
        return outputs
class NasBench101Cell(nn.Module):
    """
    Cell structure that is proposed in NAS-Bench-101 [nasbench101]_ .

    This cell is commonly used to evaluate NAS algorithms because a "comprehensive analysis" of its
    search space is available, including a full architecture dataset that "maps 423k unique
    architectures to metrics including run time and accuracy". The space can also be reused in a
    custom space design, where the benchmark results may help narrow the space down to a few
    efficient architectures.

    The space consists of all directed acyclic graphs on at most ``max_num_nodes`` nodes, where every
    node other than IN and OUT carries one of ``op_candidates``, and the number of edges is at most
    ``max_num_edges``. To align with the paper, the IN and OUT vertices are counted in
    ``max_num_nodes``; the defaults are 7 nodes and 9 edges.

    The input of the cell should have shape :math:`[N, C_{in}, *]` and the output has shape
    :math:`[N, C_{out}, *]`. The feature count of each hidden node is computed automatically from
    the cell structure. Each entry of ``op_candidates`` is a callable that accepts the computed
    ``num_features`` and returns a ``Module``. For example,

    .. code-block:: python

        def conv_bn_relu(num_features):
            return nn.Sequential(
                nn.Conv2d(num_features, num_features, 1),
                nn.BatchNorm2d(num_features),
                nn.ReLU()
            )

    The input of each hidden node's operation is the sum of its input nodes. The output node is the
    concatenation of its input *hidden* nodes, plus the *IN* node if IN and OUT are connected.

    Adding the input tensor to another tensor may cause a shape mismatch, so the input is first
    passed through a projection. In the paper this is a Conv1x1 followed by BN and ReLU. The
    ``projection`` parameter accepts ``in_features`` and ``out_features`` and returns a ``Module``.
    It has no default because no assumption is made that users deal with images. For example,

    .. code-block:: python

        def projection_fn(in_features, out_features):
            return nn.Conv2d(in_features, out_features, 1)

    Parameters
    ----------
    op_candidates : list of callable or dict of callable
        Operation candidates. Each is a function that accepts a number of features and returns an nn.Module.
    in_features : int
        Input dimension of cell.
    out_features : int
        Output dimension of cell.
    projection : callable
        Builds the module that preprocesses the input tensor of the whole cell.
        Accepts input feature and output feature counts and returns an nn.Module.
    max_num_nodes : int
        Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
    max_num_edges : int
        Maximum number of edges in the cell. Default: 9.
    label : str
        Identifier of the cell. Cells sharing the same label semantically share the same choice.

    References
    ----------
    .. [nasbench101] Ying, Chris, et al. "Nas-bench-101: Towards reproducible neural architecture search."
        International Conference on Machine Learning. PMLR, 2019.
    """

    @staticmethod
    def _make_dict(x):
        """Return ``x`` as an ``OrderedDict``; list entries are keyed by their index as a string."""
        if not isinstance(x, list):
            return OrderedDict(x)
        return OrderedDict((str(position), item) for position, item in enumerate(x))

    def __new__(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
                in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
                max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
        """Build the fixed cell when a fixed context is active, otherwise a regular searchable cell."""

        def as_list(value):
            if isinstance(value, list):
                return value
            return [value]

        try:
            label, selected = get_fixed_dict(label)
            op_candidates = cls._make_dict(op_candidates)
            num_nodes = selected[f'{label}/num_nodes']
            adjacency_list = [as_list(selected[f'{label}/input_{i}']) for i in range(1, num_nodes)]
            edge_count = sum(len(sources) for sources in adjacency_list)
            if edge_count > max_num_edges:
                raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')
            chosen_ops = [op_candidates[selected[f'{label}/op_{i}']] for i in range(1, num_nodes - 1)]
            return _NasBench101CellFixed(
                chosen_ops, adjacency_list, in_features, out_features, num_nodes, projection)
        except NoContextError:
            return super().__new__(cls)

    def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
                 in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
                 max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
        super().__init__()

        self._label = generate_new_label(label)

        # Prior over the node count: weight 2 ** n for n nodes, normalised to sum to one.
        node_counts = list(range(2, max_num_nodes + 1))
        weights = [2 ** count for count in node_counts]
        num_vertices_prior = (np.array(weights) / sum(weights)).tolist()
        self.num_nodes = ValueChoice(node_counts,
                                     prior=num_vertices_prior,
                                     label=f'{self._label}/num_nodes')
        self.max_num_nodes = max_num_nodes
        self.max_num_edges = max_num_edges

        op_candidates = self._make_dict(op_candidates)

        # this is only for input validation and instantiating enough layer choice and input choice
        self.hidden_features = out_features

        # Index 0 of every list is a placeholder so positions line up with node indices.
        self.projections = nn.ModuleList([nn.Identity()])
        self.ops = nn.ModuleList([nn.Identity()])
        self.inputs = nn.ModuleList([nn.Identity()])
        for _ in range(1, max_num_nodes):
            self.projections.append(projection(in_features, self.hidden_features))
        for i in range(1, max_num_nodes):
            if i < max_num_nodes - 1:
                instantiated = OrderedDict([(name, factory(self.hidden_features))
                                            for name, factory in op_candidates.items()])
                self.ops.append(LayerChoice(instantiated, label=f'{self._label}/op_{i}'))
            self.inputs.append(InputChoice(i, None, label=f'{self._label}/input_{i}'))

    @property
    def label(self):
        """Label identifying this cell."""
        return self._label

    def forward(self, x):
        """Dummy forward pass; it is not actually used."""
        tensors = [x]
        for i in range(1, self.max_num_nodes):
            projected_input = self.projections[i](tensors[0])
            node_input = self.inputs[i]([projected_input] + [t for t in tensors[1:]])
            if i < self.max_num_nodes - 1:
                node_output = self.ops[i](node_input)
            else:
                node_output = node_input
            tensors.append(node_output)
        return tensors[-1]
class NasBench101Mutator(Mutator):
    """
    Validation-only mutator for the python execution engine.

    It makes no choices of its own: it reads the samples already recorded in the model history
    for the cell and raises ``InvalidMutation`` when they do not form a valid cell.
    """

    def __init__(self, label: Optional[str]):
        super().__init__(label=label)

    @staticmethod
    def candidates(node):
        """Return the candidates of ``node``: ``range(n_candidates)`` if given, else its ``candidates`` list."""
        parameters = node.operation.parameters
        if 'n_candidates' not in parameters:
            return parameters['candidates']
        return list(range(parameters['n_candidates']))

    @staticmethod
    def number_of_chosen(node):
        """Return the node's ``n_chosen`` parameter, or 1 when it has none."""
        parameters = node.operation.parameters
        return parameters['n_chosen'] if 'n_chosen' in parameters else 1

    def mutate(self, model: Model):
        """Check the recorded samples of the cell against the edge limit and connectivity."""
        # Only the first node with this label is consulted for the edge limit.
        for node in model.get_nodes_by_label(self.label):
            max_num_edges = node.operation.parameters['max_num_edges']
            break

        samples_by_label = {record.mutator.label: record.samples for record in model.history}
        num_nodes = samples_by_label[f'{self.label}/num_nodes'][0]
        adjacency_list = [samples_by_label[f'{self.label}/input_{i}'] for i in range(1, num_nodes)]

        edge_count = sum(len(sources) for sources in adjacency_list)
        if edge_count > max_num_edges:
            raise InvalidMutation(f'Expected {max_num_edges} edges, found: {adjacency_list}')

        matrix = _NasBench101CellFixed.build_connection_matrix(adjacency_list, num_nodes)
        # dummy ops, possible to raise InvalidMutation inside
        prune(matrix, [None] * len(matrix))

    def dry_run(self, model):
        """Report no choice candidates and return the model unchanged."""
        return [], model
from typing import Optional
from typing import Any, Optional, Tuple
from ...utils import uid, get_current_context
......@@ -9,9 +9,21 @@ def generate_new_label(label: Optional[str]):
return label
def get_fixed_value(label: str) -> Any:
    """
    Look up the value fixed for ``label`` in the current ``'fixed'`` context.

    Raises
    ------
    KeyError
        If the context holds no entry for the label.
    """
    ret = get_current_context('fixed')
    try:
        return ret[generate_new_label(label)]
    except KeyError:
        raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
def get_fixed_dict(label_prefix: str) -> Tuple[str, Any]:
    """
    Collect every entry of the current ``'fixed'`` context whose key starts with ``<label_prefix>/``.

    Returns
    -------
    tuple
        The label prefix (after ``generate_new_label``) and the dict of matching entries.

    Raises
    ------
    KeyError
        If no key carries the prefix.
    """
    fixed = get_current_context('fixed')
    label_prefix = generate_new_label(label_prefix)
    selected = {k: v for k, v in fixed.items() if k.startswith(label_prefix + '/')}
    if not selected:
        # Report the whole context, not the (empty) filtered view, so the message is useful.
        raise KeyError(f'Fixed context with prefix {label_prefix} not found. Existing values are: {fixed}')
    return label_prefix, selected
......@@ -8,7 +8,7 @@ import random
import time
from typing import Any, Dict, List
from .. import Sampler, submit_models, query_available_resources, budget_exhausted
from .. import InvalidMutation, Sampler, submit_models, query_available_resources, budget_exhausted
from .base import BaseStrategy
from .utils import dry_run_for_search_space, get_targeted_model
......@@ -121,4 +121,7 @@ class Random(BaseStrategy):
if budget_exhausted():
return
time.sleep(self._polling_interval)
submit_models(get_targeted_model(base_model, applied_mutators, sample))
try:
submit_models(get_targeted_model(base_model, applied_mutators, sample))
except InvalidMutation as e:
_logger.warning(f'Invalid mutation: {e}. Skip.')
......@@ -67,6 +67,10 @@ def get_importable_name(cls, relocate_module=False):
return module_name + '.' + cls.__name__
class NoContextError(Exception):
    """Raised when a context is requested but none is active."""
class ContextStack:
"""
This is to maintain a globally-accessible context envinronment that is visible to everywhere.
......@@ -98,7 +102,8 @@ class ContextStack:
@classmethod
def top(cls, key: str) -> Any:
    """
    Return the most recently pushed value of the stack stored under ``key``.

    Raises
    ------
    NoContextError
        If that stack is empty.
    """
    # A dedicated exception (rather than an assert) lets callers catch the
    # "no context" case specifically.
    if not cls._stack[key]:
        raise NoContextError('Context is empty.')
    return cls._stack[key][-1]
......
......@@ -10,3 +10,4 @@ _generated_model
data
generated
lightning_logs
model.onnx
......@@ -5,7 +5,7 @@ from collections import Counter
import nni.retiarii.nn.pytorch as nn
import torch
import torch.nn.functional as F
from nni.retiarii import Sampler, basic_unit
from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.execution.python import _unpack_if_only_one
......@@ -518,3 +518,29 @@ class Python(GraphIR):
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
def test_nasbench101_cell(self):
    """Sample NAS-Bench-101 cells at random and check the output shape of every valid one."""
    # this is only supported in python engine for now.
    @self.get_serializer()
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Two linear candidates, 10 -> 16 features, at most 5 nodes and 7 edges.
            self.cell = nn.NasBench101Cell([lambda x: nn.Linear(x, x), lambda x: nn.Linear(x, x, bias=False)],
                                           10, 16, lambda x, y: nn.Linear(x, y), max_num_nodes=5, max_num_edges=7)

        def forward(self, x):
            return self.cell(x)

    raw_model, mutators = self._get_model_with_mutators(Net())

    succeeded = 0
    sampler = RandomSampler()
    while succeeded <= 10:
        try:
            # Apply every mutator in turn, starting again from the raw model.
            model = raw_model
            for mutator in mutators:
                model = mutator.bind_sampler(sampler).apply(model)
            succeeded += 1
        except InvalidMutation:
            # Invalid samples are expected; just draw again.
            continue
        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment