Commit a0fd0036 authored by Yuge Zhang, committed by GitHub

Merge pull request #5036 from microsoft/promote-retiarii-to-nas

[DO NOT SQUASH] Promote retiarii to NAS
parents d6dcb483 bc6d8796
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from .mutator import RegularizedDartsMutator, RegularizedMutatorParallel, DartsDiscreteMutator
from .trainer import CdartsTrainer

# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

"""Famous building blocks of search spaces."""

from .autoactivation import *
from .nasbench101 import *
from .nasbench201 import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from __future__ import annotations
from packaging.version import Version
import torch
import torch.nn as nn
from nni.nas.utils import basic_unit
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
__all__ = ['AutoActivation']
TorchVersion = '1.5.0'
# ============== unary function modules ==============
@basic_unit
class UnaryIdentity(nn.Module):
def forward(self, x):
return x
@basic_unit
class UnaryNegative(nn.Module):
def forward(self, x):
return -x
@basic_unit
class UnaryAbs(nn.Module):
def forward(self, x):
return torch.abs(x)
@basic_unit
class UnarySquare(nn.Module):
def forward(self, x):
return torch.square(x)
@basic_unit
class UnaryPow(nn.Module):
def forward(self, x):
return torch.pow(x, 3)
@basic_unit
class UnarySqrt(nn.Module):
def forward(self, x):
return torch.sqrt(x)
@basic_unit
class UnaryMul(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x * self.beta
@basic_unit
class UnaryAdd(nn.Module):
def __init__(self):
super().__init__()
# element-wise for now, will change to per-channel trainable parameter
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return x + self.beta
@basic_unit
class UnaryLogAbs(nn.Module):
def forward(self, x):
return torch.log(torch.abs(x) + 1e-7)
@basic_unit
class UnaryExp(nn.Module):
def forward(self, x):
return torch.exp(x)
@basic_unit
class UnarySin(nn.Module):
def forward(self, x):
return torch.sin(x)
@basic_unit
class UnaryCos(nn.Module):
def forward(self, x):
return torch.cos(x)
@basic_unit
class UnarySinh(nn.Module):
def forward(self, x):
return torch.sinh(x)
@basic_unit
class UnaryCosh(nn.Module):
def forward(self, x):
return torch.cosh(x)
@basic_unit
class UnaryTanh(nn.Module):
def forward(self, x):
return torch.tanh(x)
if Version(torch.__version__) < Version(TorchVersion):
@basic_unit
class UnaryAsinh(nn.Module):
def forward(self, x):
return torch.asinh(x)
@basic_unit
class UnaryAtan(nn.Module):
def forward(self, x):
return torch.atan(x)
if Version(torch.__version__) < Version(TorchVersion):
@basic_unit
class UnarySinc(nn.Module):
def forward(self, x):
return torch.sinc(x)
@basic_unit
class UnaryMax(nn.Module):
def forward(self, x):
return torch.max(x, torch.zeros_like(x))
@basic_unit
class UnaryMin(nn.Module):
def forward(self, x):
return torch.min(x, torch.zeros_like(x))
@basic_unit
class UnarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x)
@basic_unit
class UnaryLogExp(nn.Module):
def forward(self, x):
return torch.log(1 + torch.exp(x))
@basic_unit
class UnaryExpSquare(nn.Module):
def forward(self, x):
return torch.exp(-torch.square(x))
@basic_unit
class UnaryErf(nn.Module):
def forward(self, x):
return torch.erf(x)
unary_modules = ['UnaryIdentity', 'UnaryNegative', 'UnaryAbs', 'UnarySquare', 'UnaryPow',
'UnarySqrt', 'UnaryMul', 'UnaryAdd', 'UnaryLogAbs', 'UnaryExp', 'UnarySin', 'UnaryCos',
'UnarySinh', 'UnaryCosh', 'UnaryTanh', 'UnaryAtan', 'UnaryMax',
'UnaryMin', 'UnarySigmoid', 'UnaryLogExp', 'UnaryExpSquare', 'UnaryErf']
if Version(torch.__version__) < Version(TorchVersion):
unary_modules.append('UnaryAsinh')
unary_modules.append('UnarySinc')
# ============== binary function modules ==============
@basic_unit
class BinaryAdd(nn.Module):
def forward(self, x):
return x[0] + x[1]
@basic_unit
class BinaryMul(nn.Module):
def forward(self, x):
return x[0] * x[1]
@basic_unit
class BinaryMinus(nn.Module):
def forward(self, x):
return x[0] - x[1]
@basic_unit
class BinaryDivide(nn.Module):
def forward(self, x):
return x[0] / (x[1] + 1e-7)
@basic_unit
class BinaryMax(nn.Module):
def forward(self, x):
return torch.max(x[0], x[1])
@basic_unit
class BinaryMin(nn.Module):
def forward(self, x):
return torch.min(x[0], x[1])
@basic_unit
class BinarySigmoid(nn.Module):
def forward(self, x):
return torch.sigmoid(x[0]) * x[1]
@basic_unit
class BinaryExpSquare(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.square(x[0] - x[1]))
@basic_unit
class BinaryExpAbs(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return torch.exp(-self.beta * torch.abs(x[0] - x[1]))
@basic_unit
class BinaryParamAdd(nn.Module):
def __init__(self):
super().__init__()
self.beta = torch.nn.Parameter(torch.tensor(1, dtype=torch.float32)) # pylint: disable=not-callable
def forward(self, x):
return self.beta * x[0] + (1 - self.beta) * x[1]
binary_modules = ['BinaryAdd', 'BinaryMul', 'BinaryMinus', 'BinaryDivide', 'BinaryMax',
'BinaryMin', 'BinarySigmoid', 'BinaryExpSquare', 'BinaryExpAbs', 'BinaryParamAdd']
class AutoActivation(nn.Module):
"""
This module is an implementation of the paper `Searching for Activation Functions <https://arxiv.org/abs/1710.05941>`__.
Parameters
----------
unit_num : int
The number of core units.
label : str
Identifier of the module. Modules sharing the same label will semantically share the same choice.
Notes
-----
Currently, ``beta`` is not a per-channel parameter.
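Examples
--------
A minimal usage sketch (search strategy and evaluator are up to the user); ``AutoActivation``
can be placed wherever a fixed activation such as ``nn.ReLU`` would normally go:

.. code-block:: python

    act = AutoActivation(unit_num=1)
    out = act(torch.randn(4, 16, 32, 32))  # output has the same shape as the input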
"""
def __init__(self, unit_num: int = 1, label: str | None = None):
super().__init__()
self._label = generate_new_label(label)
self.unaries = nn.ModuleList()
self.binaries = nn.ModuleList()
self.first_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_0')
for i in range(unit_num):
one_unary = LayerChoice([eval('{}()'.format(unary)) for unary in unary_modules], label = f'{self.label}__unary_{i+1}')
self.unaries.append(one_unary)
for i in range(unit_num):
one_binary = LayerChoice([eval('{}()'.format(binary)) for binary in binary_modules], label = f'{self.label}__binary_{i}')
self.binaries.append(one_binary)
@property
def label(self):
return self._label
def forward(self, x):
out = self.first_unary(x)
for unary, binary in zip(self.unaries, self.binaries):
out = binary(torch.stack([out, unary(x)]))
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench101Cell', 'NasBench101Mutator']
import logging
from collections import OrderedDict
from typing import Callable, List, Optional, Union, Dict, Tuple, cast
import numpy as np
import torch
import torch.nn as nn
from nni.nas.mutable import InvalidMutation, Mutator
from nni.nas.execution.common import Model
from nni.nas.nn.pytorch import InputChoice, ValueChoice, LayerChoice
from nni.nas.nn.pytorch.mutation_utils import Mutable, generate_new_label, get_fixed_dict
_logger = logging.getLogger(__name__)
def compute_vertex_channels(input_channels, output_channels, matrix):
"""
This is (almost) copied from the original NAS-Bench-101 implementation.
Computes the number of channels at every vertex.
Given the input channels and output channels, this calculates the number of channels at each interior vertex.
Interior vertices have the same number of channels as the max of the channels of the vertices they feed into.
The output channels are divided amongst the vertices that are directly connected to the output.
When the division is not even, some vertices may receive an extra channel to compensate.
Parameters
----------
input_channels : int
Input channel count.
output_channels : int
Output channel count.
matrix : np.ndarray
adjacency matrix for the module (pruned by model_spec).
Returns
-------
list of int
list of channel counts, in order of the vertices.
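Examples
--------
A hedged illustration (not part of the original docstring): with a 5-vertex cell whose three
interior vertices all feed the output, the output channels are split as evenly as possible.

.. code-block:: python

    matrix = np.array([
        [0, 1, 1, 1, 0],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0],
    ])
    compute_vertex_channels(128, 128, matrix)  # -> [128, 43, 43, 42, 128]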
"""
num_vertices = np.shape(matrix)[0]
vertex_channels = [0] * num_vertices
vertex_channels[0] = input_channels
vertex_channels[num_vertices - 1] = output_channels
if num_vertices == 2:
# Edge case where module only has input and output vertices
return vertex_channels
# Compute the in-degree ignoring input, axis 0 is the src vertex and axis 1 is
# the dst vertex. Summing over 0 gives the in-degree count of each vertex.
in_degree = np.sum(matrix[1:], axis=0)
interior_channels = output_channels // in_degree[num_vertices - 1]
correction = output_channels % in_degree[num_vertices - 1] # Remainder to add
# Set channels of vertices that flow directly to output
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
vertex_channels[v] = interior_channels
if correction:
vertex_channels[v] += 1
correction -= 1
# Set channels for all other vertices to the max of the out edges, going backwards.
# (num_vertices - 2) index skipped because it only connects to output.
for v in range(num_vertices - 3, 0, -1):
if not matrix[v, num_vertices - 1]:
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
vertex_channels[v] = max(vertex_channels[v], vertex_channels[dst])
assert vertex_channels[v] > 0
_logger.debug('vertex_channels: %s', str(vertex_channels))
# Sanity check, verify that channels never increase and final channels add up.
final_fan_in = 0
for v in range(1, num_vertices - 1):
if matrix[v, num_vertices - 1]:
final_fan_in += vertex_channels[v]
for dst in range(v + 1, num_vertices - 1):
if matrix[v, dst]:
assert vertex_channels[v] >= vertex_channels[dst]
assert final_fan_in == output_channels or num_vertices == 2
# num_vertices == 2 means only input/output nodes, so 0 fan-in
return vertex_channels
def prune(matrix, ops) -> Tuple[np.ndarray, List[Union[str, Callable[[int], nn.Module]]]]:
"""
Prune the extraneous parts of the graph.
General procedure:
1. Remove parts of graph not connected to input.
2. Remove parts of graph not connected to output.
3. Reorder the vertices so that they are consecutive after steps 1 and 2.
These 3 steps can be combined by deleting the rows and columns of the
vertices that are not reachable from both the input and output (in reverse).
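For example (an illustrative sketch, not part of the original docstring): in a 3-vertex graph
where the only edge is ``0 -> 2``, vertex 1 is reachable from neither side and is removed.

.. code-block:: python

    matrix = np.array([[0, 0, 1],
                       [0, 0, 0],
                       [0, 0, 0]])
    matrix, ops = prune(matrix, ['IN', 'some-op', 'OUT'])
    # matrix is now 2x2 (input connected directly to output); ops == ['IN', 'OUT']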
"""
num_vertices = np.shape(matrix)[0]
# Compute reachability within num_vertices steps (matrix power of the adjacency matrix plus identity).
connections = np.linalg.matrix_power(matrix + np.eye(num_vertices), num_vertices)
visited_from_input = set([i for i in range(num_vertices) if connections[0, i]])
visited_from_output = set([i for i in range(num_vertices) if connections[i, -1]])
# Any vertex that isn't connected to both input and output is extraneous to the computation graph.
extraneous = set(range(num_vertices)).difference(
visited_from_input.intersection(visited_from_output))
if len(extraneous) > num_vertices - 2:
raise InvalidMutation('Non-extraneous graph has fewer than 2 vertices; '
'the input is not connected to the output and the spec is invalid.')
matrix = np.delete(matrix, list(extraneous), axis=0)
matrix = np.delete(matrix, list(extraneous), axis=1)
for index in sorted(extraneous, reverse=True):
del ops[index]
return matrix, ops
def truncate(inputs, channels):
input_channels = inputs.size(1)
if input_channels < channels:
raise ValueError('input channel < output channels for truncate')
elif input_channels == channels:
return inputs # No truncation necessary
else:
# Truncation should only be necessary when channel division leads to
# vertices with +1 channels. The input vertex should always be projected to
# the minimum channel count.
assert input_channels - channels == 1
return inputs[:, :channels]
class _NasBench101CellFixed(nn.Module):
"""
The fixed version of the NAS-Bench-101 cell, used in the Python-based execution engine.
"""
def __init__(self, operations: List[Callable[[int], nn.Module]],
adjacency_list: List[List[int]],
in_features: int, out_features: int, num_nodes: int,
projection: Callable[[int, int], nn.Module]):
super().__init__()
assert num_nodes == len(operations) + 2 == len(adjacency_list) + 1
raw_operations: List[Union[str, Callable[[int], nn.Module]]] = list(operations)
del operations # operations is no longer needed. Delete it to avoid misuse
# add pseudo nodes
raw_operations.insert(0, 'IN')
raw_operations.append('OUT')
self.connection_matrix = self.build_connection_matrix(adjacency_list, num_nodes)
del num_nodes # raw number of nodes is no longer used
self.connection_matrix, self.operations = prune(self.connection_matrix, raw_operations)
self.hidden_features = compute_vertex_channels(in_features, out_features, self.connection_matrix)
self.num_nodes = len(self.connection_matrix)
self.in_features = in_features
self.out_features = out_features
_logger.info('Pruned number of nodes: %d', self.num_nodes)
_logger.info('Pruned connection matrix: %s', str(self.connection_matrix))
self.projections = nn.ModuleList([nn.Identity()])
self.ops = nn.ModuleList([nn.Identity()])
for i in range(1, self.num_nodes):
self.projections.append(projection(in_features, self.hidden_features[i]))
for i in range(1, self.num_nodes - 1):
operation = cast(Callable[[int], nn.Module], self.operations[i])
self.ops.append(operation(self.hidden_features[i]))
@staticmethod
def build_connection_matrix(adjacency_list, num_nodes):
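# For example (illustrative): adjacency_list [[0], [0, 1]] with num_nodes=3 yields
# [[0, 1, 1], [0, 0, 1], [0, 0, 0]], i.e., edges 0->1, 0->2 and 1->2.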
adjacency_list = [[]] + adjacency_list # add adjacency for first node
connections = np.zeros((num_nodes, num_nodes), dtype='int')
for i, lst in enumerate(adjacency_list):
assert all([0 <= k < i for k in lst])
for k in lst:
connections[k, i] = 1
return connections
def forward(self, inputs):
tensors = [inputs]
for t in range(1, self.num_nodes - 1):
# Create interior connections, truncating if necessary
add_in = [truncate(tensors[src], self.hidden_features[t])
for src in range(1, t) if self.connection_matrix[src, t]]
# Create add connection from projected input
if self.connection_matrix[0, t]:
add_in.append(self.projections[t](tensors[0]))
if len(add_in) == 1:
vertex_input = add_in[0]
else:
vertex_input = sum(add_in)
# Perform op at vertex t
vertex_out = self.ops[t](vertex_input)
tensors.append(vertex_out)
# Construct final output tensor by concatenating all fan-in and adding input.
if np.sum(self.connection_matrix[:, -1]) == 1:
src = np.where(self.connection_matrix[:, -1] == 1)[0][0]
return self.projections[-1](tensors[0]) if src == 0 else tensors[src]
outputs = torch.cat([tensors[src] for src in range(1, self.num_nodes - 1) if self.connection_matrix[src, -1]], 1)
if self.connection_matrix[0, -1]:
outputs += self.projections[-1](tensors[0])
assert outputs.size(1) == self.out_features
return outputs
class NasBench101Cell(Mutable):
"""
Cell structure that is proposed in NAS-Bench-101.
Proposed by `NAS-Bench-101: Towards Reproducible Neural Architecture Search <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.
This cell is usually used in the evaluation of NAS algorithms, because a "comprehensive analysis" of this search space
is available: a full architecture dataset that "maps 423k unique architectures to metrics
including run time and accuracy". You can also use the space in your own design, in which case it should be possible
to leverage the benchmark results to narrow the huge space down to a few efficient architectures.
The space of this cell architecture consists of all possible directed acyclic graphs on no more than ``max_num_nodes`` nodes,
where each possible node (other than IN and OUT) has one of ``op_candidates``, representing the corresponding operation.
Edges connecting the nodes can be no more than ``max_num_edges``.
To align with the paper settings, two vertices, specially labeled as operation IN and OUT, are also counted into
``max_num_nodes`` in our implementation. The default value of ``max_num_nodes`` is 7 and that of ``max_num_edges`` is 9.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`. The shape
of each hidden node is first computed automatically, depending on the cell structure. Each of the ``op_candidates``
should be a callable that accepts the computed ``num_features`` and returns a ``Module``. For example,
.. code-block:: python
def conv_bn_relu(num_features):
return nn.Sequential(
nn.Conv2d(num_features, num_features, 1),
nn.BatchNorm2d(num_features),
nn.ReLU()
)
The output of each node is its operation applied to the sum of its input nodes, except for the last node (the output node),
whose output is the concatenation of its input *hidden* nodes, with the projected *IN* node added (if IN and OUT are connected).
When the input tensor is added to any other tensor, there could be a shape mismatch. Therefore, a projection transformation
is needed to transform the input tensor. In the paper, this is simply a Conv1x1 followed by BN and ReLU. The ``projection``
parameter accepts ``in_features`` and ``out_features`` and returns a ``Module``. This parameter has no default value,
as we make no assumption that users are dealing with images. An example for this parameter is,
.. code-block:: python
def projection_fn(in_features, out_features):
return nn.Conv2d(in_features, out_features, 1)
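Putting the two together, a cell can then be created as below (a sketch; channel counts and the
surrounding network are up to the user):

.. code-block:: python

    cell = NasBench101Cell([conv_bn_relu], in_features=128, out_features=128,
                           projection=projection_fn, max_num_nodes=7, max_num_edges=9)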
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function that accepts a number of features and returns an ``nn.Module``.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
projection : callable
Projection module that is used to preprocess the input tensor of the whole cell.
A callable that accepts input and output feature counts and returns an ``nn.Module``.
max_num_nodes : int
Maximum number of nodes in the cell, input and output included. At least 2. Default: 7.
max_num_edges : int
Maximum number of edges in the cell. Default: 9.
label : str
Identifier of the cell. Cells sharing the same label will semantically share the same choice.
Warnings
--------
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-execution-engine>`.
"""
@staticmethod
def _make_dict(x):
if isinstance(x, list):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
@classmethod
def create_fixed_module(cls, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
def make_list(x): return x if isinstance(x, list) else [x]
label, selected = get_fixed_dict(label)
op_candidates = cls._make_dict(op_candidates)
num_nodes = selected[f'{label}/num_nodes']
adjacency_list = [make_list(selected[f'{label}/input{i}']) for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected at most {max_num_edges} edges, found: {adjacency_list}')
return _NasBench101CellFixed(
[op_candidates[selected[f'{label}/op{i}']] for i in range(1, num_nodes - 1)],
adjacency_list, in_features, out_features, num_nodes, projection)
# FIXME: weight inheritance on nasbench101 is not supported yet
def __init__(self, op_candidates: Union[Dict[str, Callable[[int], nn.Module]], List[Callable[[int], nn.Module]]],
in_features: int, out_features: int, projection: Callable[[int, int], nn.Module],
max_num_nodes: int = 7, max_num_edges: int = 9, label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
num_vertices_prior = [2 ** i for i in range(2, max_num_nodes + 1)]
num_vertices_prior = (np.array(num_vertices_prior) / sum(num_vertices_prior)).tolist()
self.num_nodes = ValueChoice(list(range(2, max_num_nodes + 1)),
prior=num_vertices_prior,
label=f'{self._label}/num_nodes')
self.max_num_nodes = max_num_nodes
self.max_num_edges = max_num_edges
op_candidates = self._make_dict(op_candidates)
# this is only for input validation and instantiating enough layer choice and input choice
self.hidden_features = out_features
self.projections = nn.ModuleList([nn.Identity()])
self.ops = nn.ModuleList([nn.Identity()])
self.inputs = nn.ModuleList([nn.Identity()])
for _ in range(1, max_num_nodes):
self.projections.append(projection(in_features, self.hidden_features))
for i in range(1, max_num_nodes):
if i < max_num_nodes - 1:
self.ops.append(LayerChoice(OrderedDict([(k, op(self.hidden_features)) for k, op in op_candidates.items()]),
label=f'{self._label}/op{i}'))
self.inputs.append(InputChoice(i, None, label=f'{self._label}/input{i}'))
@property
def label(self):
return self._label
def forward(self, x):
"""
The forward pass of the model space simply selects the first option of every choice.
It shouldn't be called directly by users in most cases.
"""
tensors = [x]
for i in range(1, self.max_num_nodes):
node_input = self.inputs[i]([self.projections[i](tensors[0])] + [t for t in tensors[1:]])
if i < self.max_num_nodes - 1:
node_output = self.ops[i](node_input)
else:
node_output = node_input
tensors.append(node_output)
return tensors[-1]
class NasBench101Mutator(Mutator):
# for validation purposes
# for python execution engine
def __init__(self, label: str):
super().__init__(label=label)
@staticmethod
def candidates(node):
if 'n_candidates' in node.operation.parameters:
return list(range(node.operation.parameters['n_candidates']))
else:
return node.operation.parameters['candidates']
@staticmethod
def number_of_chosen(node):
if 'n_chosen' in node.operation.parameters:
return node.operation.parameters['n_chosen']
return 1
def mutate(self, model: Model):
max_num_edges = cast(int, None)
for node in model.get_nodes_by_label(self.label):
max_num_edges = node.operation.parameters['max_num_edges']
break
assert max_num_edges is not None
mutation_dict = {mut.mutator.label: mut.samples for mut in model.history}
num_nodes = mutation_dict[f'{self.label}/num_nodes'][0]
adjacency_list = [mutation_dict[f'{self.label}/input{i}'] for i in range(1, num_nodes)]
if sum([len(e) for e in adjacency_list]) > max_num_edges:
raise InvalidMutation(f'Expected at most {max_num_edges} edges, found: {adjacency_list}')
matrix = _NasBench101CellFixed.build_connection_matrix(adjacency_list, num_nodes)
operations = ['IN'] + [mutation_dict[f'{self.label}/op{i}'][0] for i in range(1, num_nodes - 1)] + ['OUT']
assert len(operations) == len(matrix)
matrix, operations = prune(matrix, operations) # possible to raise InvalidMutation inside
# NOTE: a hack to maintain a clean copy of what nasbench101 cell looks like
self._cur_samples = {}
for i in range(1, len(matrix)):
if i + 1 < len(matrix):
self._cur_samples[f'op{i}'] = operations[i]
self._cur_samples[f'input{i}'] = [k for k in range(i) if matrix[k, i]]
self._cur_samples = [self._cur_samples] # by design, _cur_samples is a list of samples
def dry_run(self, model):
return [], model
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
__all__ = ['NasBench201Cell']
from collections import OrderedDict
from typing import Callable, List, Dict, Union, Optional
import torch
import torch.nn as nn
from nni.nas.nn.pytorch import LayerChoice
from nni.nas.nn.pytorch.mutation_utils import generate_new_label
class NasBench201Cell(nn.Module):
"""
Cell structure that is proposed in NAS-Bench-201.
Proposed by `NAS-Bench-201: Extending the Scope of Reproducible Neural Architecture Search <https://arxiv.org/abs/2001.00326>`__.
This cell is a densely connected DAG with ``num_tensors`` nodes, where each node is a tensor.
For every i < j, there is an edge from the i-th node to the j-th node.
Each edge in this DAG is associated with an operation transforming the hidden state from the source node
to the target node. All possible operations are selected from a predefined operation set, defined in ``op_candidates``.
Each of the ``op_candidates`` should be a callable that accepts input dimension and output dimension,
and returns a ``Module``.
Input of this cell should be of shape :math:`[N, C_{in}, *]`, while output should be :math:`[N, C_{out}, *]`.
The space size of this cell would be :math:`|op|^{N(N-1)/2}`, where :math:`|op|` is the number of operation candidates,
and :math:`N` is defined by ``num_tensors``.
Parameters
----------
op_candidates : list of callable
Operation candidates. Each should be a function that accepts input and output feature counts and returns an ``nn.Module``.
in_features : int
Input dimension of cell.
out_features : int
Output dimension of cell.
num_tensors : int
Number of tensors in the cell (input included). Default: 4
label : str
Identifier of the cell. Cells sharing the same label will semantically share the same choice.
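Examples
--------
A minimal sketch (the candidate operations here are arbitrary; any callable taking input and
output feature counts works):

.. code-block:: python

    cell = NasBench201Cell(
        [lambda C_in, C_out: nn.Conv2d(C_in, C_out, 3, padding=1),
         lambda C_in, C_out: nn.Conv2d(C_in, C_out, 1)],
        in_features=16, out_features=16, num_tensors=4
    )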
"""
@staticmethod
def _make_dict(x):
if isinstance(x, list):
return OrderedDict([(str(i), t) for i, t in enumerate(x)])
return OrderedDict(x)
def __init__(self, op_candidates: Union[Dict[str, Callable[[int, int], nn.Module]], List[Callable[[int, int], nn.Module]]],
in_features: int, out_features: int, num_tensors: int = 4,
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.layers = nn.ModuleList()
self.in_features = in_features
self.out_features = out_features
self.num_tensors = num_tensors
op_candidates = self._make_dict(op_candidates)
for tid in range(1, num_tensors):
node_ops = nn.ModuleList()
for j in range(tid):
inp = in_features if j == 0 else out_features
op_choices = OrderedDict([(key, cls(inp, out_features))
for key, cls in op_candidates.items()])
node_ops.append(LayerChoice(op_choices, label=f'{self._label}__{j}_{tid}')) # put __ here to be compatible with base engine
self.layers.append(node_ops)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
"""
The forward pass of the model space simply selects the first option of every choice.
It shouldn't be called directly by users in most cases.
"""
tensors: List[torch.Tensor] = [inputs]
for layer in self.layers:
current_tensor: List[torch.Tensor] = []
for i, op in enumerate(layer): # type: ignore
current_tensor.append(op(tensors[i])) # type: ignore
tensors.append(torch.sum(torch.stack(current_tensor), 0))
return tensors[-1]
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
import torch
import torch.nn as nn
from nni.nas import model_wrapper
from .modules.nasbench101 import NasBench101Cell
__all__ = ['NasBench101']
def truncated_normal_(tensor: torch.Tensor, mean: float = 0, std: float = 1):
# https://discuss.pytorch.org/t/implementing-truncated-normal-initializer/4778/15
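# Draw 4 candidates per element, keep the first one falling within 2 standard deviations,
# then scale and shift to the requested std and mean (all in-place on `tensor`).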
size = tensor.shape
tmp = tensor.new_empty(size + (4,)).normal_()
valid = (tmp < 2) & (tmp > -2)
ind = valid.max(-1, keepdim=True)[1]
tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
tensor.data.mul_(std).add_(mean)
class ConvBNReLU(nn.Module):
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0):
super(ConvBNReLU, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.conv_bn_relu = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
self.reset_parameters()
def reset_parameters(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
fan_in = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
truncated_normal_(m.weight.data, mean=0., std=math.sqrt(1. / fan_in))
if isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
def forward(self, x):
return self.conv_bn_relu(x)
class Conv3x3BNReLU(ConvBNReLU):
def __init__(self, in_channels, out_channels):
super(Conv3x3BNReLU, self).__init__(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
class Conv1x1BNReLU(ConvBNReLU):
def __init__(self, in_channels, out_channels):
super(Conv1x1BNReLU, self).__init__(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
Projection = Conv1x1BNReLU
@model_wrapper
class NasBench101(nn.Module):
"""The full search space, proposed by `NAS-Bench-101 <http://proceedings.mlr.press/v97/ying19a/ying19a.pdf>`__.
It's simply a stack of :class:`NasBench101Cell`. The candidate operations are ``conv3x3-bn-relu``, ``conv1x1-bn-relu`` and ``maxpool3x3``.
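
A usage sketch (the search strategy is left to the user):

.. code-block:: python

    model_space = NasBench101(stem_out_channels=128, num_stacks=3, num_modules_per_stack=3)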
"""
def __init__(self,
stem_out_channels: int = 128,
num_stacks: int = 3,
num_modules_per_stack: int = 3,
max_num_vertices: int = 7,
max_num_edges: int = 9,
num_labels: int = 10,
bn_eps: float = 1e-5,
bn_momentum: float = 0.003):
super().__init__()
op_candidates = {
'conv3x3-bn-relu': lambda num_features: Conv3x3BNReLU(num_features, num_features),
'conv1x1-bn-relu': lambda num_features: Conv1x1BNReLU(num_features, num_features),
'maxpool3x3': lambda num_features: nn.MaxPool2d(3, 1, 1)
}
# initial stem convolution
self.stem_conv = Conv3x3BNReLU(3, stem_out_channels)
layers = []
in_channels = out_channels = stem_out_channels
for stack_num in range(num_stacks):
if stack_num > 0:
downsample = nn.MaxPool2d(kernel_size=2, stride=2)
layers.append(downsample)
out_channels *= 2
for _ in range(num_modules_per_stack):
cell = NasBench101Cell(op_candidates, in_channels, out_channels,
lambda cin, cout: Projection(cin, cout),
max_num_vertices, max_num_edges, label='cell')
layers.append(cell)
in_channels = out_channels
self.features = nn.ModuleList(layers)
self.gap = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(out_channels, num_labels)
for module in self.modules():
if isinstance(module, nn.BatchNorm2d):
module.eps = bn_eps
module.momentum = bn_momentum
def forward(self, x):
bs = x.size(0)
out = self.stem_conv(x)
for layer in self.features:
out = layer(out)
out = self.gap(out).view(bs, -1)
out = self.classifier(out)
return out
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import Callable, Dict
import torch
import torch.nn as nn
from nni.nas import model_wrapper
from .modules.nasbench201 import NasBench201Cell
__all__ = ['NasBench201']
OPS_WITH_STRIDE = {
'none': lambda C_in, C_out, stride: Zero(C_in, C_out, stride),
'avg_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'avg'),
'max_pool_3x3': lambda C_in, C_out, stride: Pooling(C_in, C_out, stride, 'max'),
'conv_3x3': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (3, 3), (stride, stride), (1, 1), (1, 1)),
'conv_1x1': lambda C_in, C_out, stride: ReLUConvBN(C_in, C_out, (1, 1), (stride, stride), (0, 0), (1, 1)),
'skip_connect': lambda C_in, C_out, stride: nn.Identity() if stride == 1 and C_in == C_out
else FactorizedReduce(C_in, C_out, stride),
}
PRIMITIVES = ['none', 'skip_connect', 'conv_1x1', 'conv_3x3', 'avg_pool_3x3']
class ReLUConvBN(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(ReLUConvBN, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_out, kernel_size, stride=stride,
padding=padding, dilation=dilation, bias=False),
nn.BatchNorm2d(C_out)
)
def forward(self, x):
return self.op(x)
class SepConv(nn.Module):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation):
super(SepConv, self).__init__()
self.op = nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out),
)
def forward(self, x):
return self.op(x)
class Pooling(nn.Module):
def __init__(self, C_in, C_out, stride, mode):
super(Pooling, self).__init__()
if C_in == C_out:
self.preprocess = None
else:
self.preprocess = ReLUConvBN(C_in, C_out, 1, 1, 0, 1)
if mode == 'avg':
self.op = nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False)
elif mode == 'max':
self.op = nn.MaxPool2d(3, stride=stride, padding=1)
else:
raise ValueError('Invalid mode={:} in Pooling'.format(mode))
def forward(self, x):
if self.preprocess:
x = self.preprocess(x)
return self.op(x)
class Zero(nn.Module):
def __init__(self, C_in, C_out, stride):
super(Zero, self).__init__()
self.C_in = C_in
self.C_out = C_out
self.stride = stride
self.is_zero = True
def forward(self, x):
if self.C_in == self.C_out:
if self.stride == 1:
return x.mul(0.)
else:
return x[:, :, ::self.stride, ::self.stride].mul(0.)
else:
shape = list(x.shape)
shape[1] = self.C_out
zeros = x.new_zeros(shape, dtype=x.dtype, device=x.device)
return zeros
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, stride):
super(FactorizedReduce, self).__init__()
self.stride = stride
self.C_in = C_in
self.C_out = C_out
self.relu = nn.ReLU(inplace=False)
if stride == 2:
C_outs = [C_out // 2, C_out - C_out // 2]
self.convs = nn.ModuleList()
for i in range(2):
self.convs.append(nn.Conv2d(C_in, C_outs[i], 1, stride=stride, padding=0, bias=False))
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
else:
raise ValueError('Invalid stride : {:}'.format(stride))
self.bn = nn.BatchNorm2d(C_out)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.convs[0](x), self.convs[1](y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class ResNetBasicblock(nn.Module):
def __init__(self, inplanes, planes, stride):
super(ResNetBasicblock, self).__init__()
assert stride == 1 or stride == 2, 'invalid stride {:}'.format(stride)
self.conv_a = ReLUConvBN(inplanes, planes, 3, stride, 1, 1)
self.conv_b = ReLUConvBN(planes, planes, 3, 1, 1, 1)
if stride == 2:
self.downsample = nn.Sequential(
nn.AvgPool2d(kernel_size=2, stride=2, padding=0),
nn.Conv2d(inplanes, planes, kernel_size=1, stride=1, padding=0, bias=False))
elif inplanes != planes:
self.downsample = ReLUConvBN(inplanes, planes, 1, 1, 0, 1)
else:
self.downsample = None
self.in_dim = inplanes
self.out_dim = planes
self.stride = stride
self.num_conv = 2
def forward(self, inputs):
basicblock = self.conv_a(inputs)
basicblock = self.conv_b(basicblock)
if self.downsample is not None:
inputs = self.downsample(inputs) # residual
return inputs + basicblock
@model_wrapper
class NasBench201(nn.Module):
"""The full search space proposed by `NAS-Bench-201 <https://arxiv.org/abs/2001.00326>`__.
It's a stack of :class:`NasBench201Cell`.
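
A usage sketch (the search strategy is left to the user):

.. code-block:: python

    model_space = NasBench201(stem_out_channels=16, num_modules_per_stack=5, num_labels=10)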
"""
def __init__(self,
stem_out_channels: int = 16,
num_modules_per_stack: int = 5,
num_labels: int = 10):
super().__init__()
self.channels = C = stem_out_channels
self.num_modules = N = num_modules_per_stack
self.num_labels = num_labels
self.stem = nn.Sequential(
nn.Conv2d(3, C, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(C)
)
layer_channels = [C] * N + [C * 2] + [C * 2] * N + [C * 4] + [C * 4] * N
layer_reductions = [False] * N + [True] + [False] * N + [True] + [False] * N
C_prev = C
self.cells = nn.ModuleList()
for C_curr, reduction in zip(layer_channels, layer_reductions):
if reduction:
cell = ResNetBasicblock(C_prev, C_curr, 2)
else:
ops: Dict[str, Callable[[int, int], nn.Module]] = {
prim: lambda C_in, C_out: OPS_WITH_STRIDE[prim](C_in, C_out, 1) for prim in PRIMITIVES
}
cell = NasBench201Cell(ops, C_prev, C_curr, label='cell')
self.cells.append(cell)
C_prev = C_curr
self.lastact = nn.Sequential(
nn.BatchNorm2d(C_prev),
nn.ReLU(inplace=True)
)
self.global_pooling = nn.AdaptiveAvgPool2d(1)
self.classifier = nn.Linear(C_prev, self.num_labels)
def forward(self, inputs):
feature = self.stem(inputs)
for cell in self.cells:
feature = cell(feature)
out = self.lastact(feature)
out = self.global_pooling(out)
out = out.view(out.size(0), -1)
logits = self.classifier(out)
return logits
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""File containing NASNet-series search space.
The implementation is based on NDS.
It's called ``nasnet.py`` simply because NASNet was the first to propose such a structure.
"""
from collections import OrderedDict
from functools import partial
from typing import Tuple, List, Union, Iterable, Dict, Callable, Optional, cast
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from nni.nas.oneshot.pytorch.supermodule.sampling import PathSamplingRepeat
from nni.nas.oneshot.pytorch.supermodule.differentiable import DifferentiableMixedRepeat
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
# the following are NAS operations from
# https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/operations.py
OPS = {
'none': lambda C, stride, affine:
Zero(stride),
'avg_pool_2x2': lambda C, stride, affine:
nn.AvgPool2d(2, stride=stride, padding=0, count_include_pad=False),
'avg_pool_3x3': lambda C, stride, affine:
nn.AvgPool2d(3, stride=stride, padding=1, count_include_pad=False),
'avg_pool_5x5': lambda C, stride, affine:
nn.AvgPool2d(5, stride=stride, padding=2, count_include_pad=False),
'max_pool_2x2': lambda C, stride, affine:
nn.MaxPool2d(2, stride=stride, padding=0),
'max_pool_3x3': lambda C, stride, affine:
nn.MaxPool2d(3, stride=stride, padding=1),
'max_pool_5x5': lambda C, stride, affine:
nn.MaxPool2d(5, stride=stride, padding=2),
'max_pool_7x7': lambda C, stride, affine:
nn.MaxPool2d(7, stride=stride, padding=3),
'skip_connect': lambda C, stride, affine:
nn.Identity() if stride == 1 else FactorizedReduce(C, C, affine=affine),
'conv_1x1': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_3x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, 3, stride=stride, padding=1, bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'sep_conv_3x3': lambda C, stride, affine:
SepConv(C, C, 3, stride, 1, affine=affine),
'sep_conv_5x5': lambda C, stride, affine:
SepConv(C, C, 5, stride, 2, affine=affine),
'sep_conv_7x7': lambda C, stride, affine:
SepConv(C, C, 7, stride, 3, affine=affine),
'dil_conv_3x3': lambda C, stride, affine:
DilConv(C, C, 3, stride, 2, 2, affine=affine),
'dil_conv_5x5': lambda C, stride, affine:
DilConv(C, C, 5, stride, 4, 2, affine=affine),
'dil_sep_conv_3x3': lambda C, stride, affine:
DilSepConv(C, C, 3, stride, 2, 2, affine=affine),
'conv_3x1_1x3': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 3), stride=(1, stride), padding=(0, 1), bias=False),
nn.Conv2d(C, C, (3, 1), stride=(stride, 1), padding=(1, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
'conv_7x1_1x7': lambda C, stride, affine:
nn.Sequential(
nn.ReLU(inplace=False),
nn.Conv2d(C, C, (1, 7), stride=(1, stride), padding=(0, 3), bias=False),
nn.Conv2d(C, C, (7, 1), stride=(stride, 1), padding=(3, 0), bias=False),
nn.BatchNorm2d(C, affine=affine)
),
}
class ReLUConvBN(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_out, kernel_size, stride=stride,
padding=padding, bias=False
),
nn.BatchNorm2d(C_out, affine=affine)
)
class DilConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class SepConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class DilSepConv(nn.Sequential):
def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, affine=True):
super().__init__(
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=stride,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_in, affine=affine),
nn.ReLU(inplace=False),
nn.Conv2d(
C_in, C_in, kernel_size=kernel_size, stride=1,
padding=padding, dilation=dilation, groups=C_in, bias=False
),
nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
nn.BatchNorm2d(C_out, affine=affine),
)
class Zero(nn.Module):
def __init__(self, stride):
super().__init__()
self.stride = stride
def forward(self, x):
if self.stride == 1:
return x.mul(0.)
return x[:, :, ::self.stride, ::self.stride].mul(0.)
class FactorizedReduce(nn.Module):
def __init__(self, C_in, C_out, affine=True):
super().__init__()
if isinstance(C_out, int):
assert C_out % 2 == 0
else: # is a value choice
assert all(c % 2 == 0 for c in C_out.all_options())
self.relu = nn.ReLU(inplace=False)
self.conv_1 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.conv_2 = nn.Conv2d(C_in, C_out // 2, 1, stride=2, padding=0, bias=False)
self.bn = nn.BatchNorm2d(C_out, affine=affine)
self.pad = nn.ConstantPad2d((0, 1, 0, 1), 0)
def forward(self, x):
x = self.relu(x)
y = self.pad(x)
out = torch.cat([self.conv_1(x), self.conv_2(y[:, :, 1:, 1:])], dim=1)
out = self.bn(out)
return out
class DropPath_(nn.Module):
# https://github.com/khanrc/pt.darts/blob/0.1/models/ops.py
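# Drop-path: during training, zero the whole branch for a random subset of samples and
# rescale the remaining ones by 1 / keep_prob so the expected value stays unchanged.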
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
if self.training and self.drop_prob > 0.:
keep_prob = 1. - self.drop_prob
mask = torch.zeros((x.size(0), 1, 1, 1), dtype=torch.float, device=x.device).bernoulli_(keep_prob)
return x.div(keep_prob).mul(mask)
return x
class AuxiliaryHead(nn.Module):
def __init__(self, C: int, num_labels: int, dataset: Literal['imagenet', 'cifar']):
super().__init__()
if dataset == 'imagenet':
# assuming input size 14x14
stride = 2
elif dataset == 'cifar':
stride = 3
self.features = nn.Sequential(
nn.ReLU(inplace=True),
nn.AvgPool2d(5, stride=stride, padding=0, count_include_pad=False),
nn.Conv2d(C, 128, 1, bias=False),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 768, 2, bias=False),
nn.BatchNorm2d(768),
nn.ReLU(inplace=True)
)
self.classifier = nn.Linear(768, num_labels)
def forward(self, x):
x = self.features(x)
x = self.classifier(x.view(x.size(0), -1))
return x
class SequentialBreakdown(nn.Sequential):
"""Return all layers of a sequential."""
def __init__(self, sequential: nn.Sequential):
super().__init__(OrderedDict(sequential.named_children()))
def forward(self, inputs):
result = []
for module in self:
inputs = module(inputs)
result.append(inputs)
return result
class CellPreprocessor(nn.Module):
"""
Aligning the shape of predecessors.
If the last cell is a reduction cell, ``pre0`` should be ``FactorizedReduce`` instead of ``ReLUConvBN``.
See :class:`CellBuilder` on how to calculate those channel numbers.
"""
def __init__(self, C_pprev: nn.MaybeChoice[int], C_prev: nn.MaybeChoice[int], C: nn.MaybeChoice[int], last_cell_reduce: bool) -> None:
super().__init__()
if last_cell_reduce:
self.pre0 = FactorizedReduce(cast(int, C_pprev), cast(int, C))
else:
self.pre0 = ReLUConvBN(cast(int, C_pprev), cast(int, C), 1, 1, 0)
self.pre1 = ReLUConvBN(cast(int, C_prev), cast(int, C), 1, 1, 0)
def forward(self, cells):
assert len(cells) == 2
pprev, prev = cells
pprev = self.pre0(pprev)
prev = self.pre1(prev)
return [pprev, prev]
class CellPostprocessor(nn.Module):
"""
The cell outputs previous cell + this cell, so that cells can be directly chained.
"""
def forward(self, this_cell, previous_cells):
return [previous_cells[-1], this_cell]
class CellBuilder:
"""The cell builder is used in Repeat.
It builds a cell each time it is called.
Note that the builder is ephemeral: it can only be called once for every index.
"""
def __init__(self, op_candidates: List[str],
C_prev_in: nn.MaybeChoice[int],
C_in: nn.MaybeChoice[int],
C: nn.MaybeChoice[int],
num_nodes: int,
merge_op: Literal['all', 'loose_end'],
first_cell_reduce: bool, last_cell_reduce: bool):
self.C_prev_in = C_prev_in # This is the output channel count of the cell before the last cell.
self.C_in = C_in # This is the output channel count of the last cell.
self.C = C # This is NOT C_out of this stage; instead, C_out = C * len(cell.output_node_indices).
self.op_candidates = op_candidates
self.num_nodes = num_nodes
self.merge_op: Literal['all', 'loose_end'] = merge_op
self.first_cell_reduce = first_cell_reduce
self.last_cell_reduce = last_cell_reduce
self._expect_idx = 0
# __call__ takes an index, which is the index of the cell within the repeat.
# Number of predecessors for each cell is fixed to 2.
self.num_predecessors = 2
# Number of ops per node is fixed to 2.
self.num_ops_per_node = 2
def op_factory(self, node_index: int, op_index: int, input_index: Optional[int], *,
op: str, channels: int, is_reduction_cell: bool):
if is_reduction_cell and (
input_index is None or input_index < self.num_predecessors
): # could be None when constructing the search space
stride = 2
else:
stride = 1
return OPS[op](channels, stride, True)
def __call__(self, repeat_idx: int):
if self._expect_idx != repeat_idx:
raise ValueError(f'Expect index {self._expect_idx}, found {repeat_idx}')
# Reduction cell means stride = 2 and channel multiplied by 2.
is_reduction_cell = repeat_idx == 0 and self.first_cell_reduce
# self.C_prev_in, self.C_in, self.last_cell_reduce are updated after each cell is built.
preprocessor = CellPreprocessor(self.C_prev_in, self.C_in, self.C, self.last_cell_reduce)
ops_factory: Dict[str, Callable[[int, int, Optional[int]], nn.Module]] = {}
for op in self.op_candidates:
ops_factory[op] = partial(self.op_factory, op=op, channels=cast(int, self.C), is_reduction_cell=is_reduction_cell)
cell = nn.Cell(ops_factory, self.num_nodes, self.num_ops_per_node, self.num_predecessors, self.merge_op,
preprocessor=preprocessor, postprocessor=CellPostprocessor(),
label='reduce' if is_reduction_cell else 'normal')
# update state
self.C_prev_in = self.C_in
self.C_in = self.C * len(cell.output_node_indices)
self.last_cell_reduce = is_reduction_cell
self._expect_idx += 1
return cell
class NDSStage(nn.Repeat):
"""This class defines NDSStage, a special type of Repeat, for isinstance check, and shape alignment.
In NDS, we can't simply use Repeat to stack the blocks,
because the output shape of each stacked block can be different.
This is a problem for one-shot strategies, because they assume every possible candidate
returns values of the same shape.
Therefore, we need :class:`NDSStagePathSampling` and :class:`NDSStageDifferentiable`
to manually align the shapes -- specifically, to transform the first block in each stage.
This is not required, though, when the depth is not changing, or when the mutable depth causes no problem
(e.g., when the minimum depth is large enough).
.. attention::
Assumption: Loose end is treated as all in ``merge_op`` (the case in one-shot),
which enforces reduction cell and normal cells in the same stage to have the exact same output shape.
"""
estimated_out_channels_prev: int
"""Output channels of cells in last stage."""
estimated_out_channels: int
"""Output channels of this stage. It's **estimated** because it assumes ``all`` as ``merge_op``."""
downsampling: bool
"""This stage has downsampling"""
def first_cell_transformation_factory(self) -> Optional[nn.Module]:
"""To make the "previous cell" in first cell's output have the same shape as cells in this stage."""
if self.downsampling:
return FactorizedReduce(self.estimated_out_channels_prev, self.estimated_out_channels)
elif self.estimated_out_channels_prev is not self.estimated_out_channels:
# Can't use != here; ValueChoice doesn't support it
return ReLUConvBN(self.estimated_out_channels_prev, self.estimated_out_channels, 1, 1, 0)
return None
class NDSStagePathSampling(PathSamplingRepeat):
"""The path-sampling implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
return cls(
module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks),
module.depth_choice
)
def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
super().__init__(*args, **kwargs)
self.first_cell_transformation = first_cell_transformation
def reduction(self, items: List[Tuple[torch.Tensor, torch.Tensor]], sampled: List[int]) -> Tuple[torch.Tensor, torch.Tensor]:
if 1 not in sampled or self.first_cell_transformation is None:
return super().reduction(items, sampled)
# items[0] must be the result of first cell
assert len(items[0]) == 2
# Only apply the transformation on "prev" output.
items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
return super().reduction(items, sampled)
class NDSStageDifferentiable(DifferentiableMixedRepeat):
"""The differentiable implementation (for one-shot) of each NDS stage if depth is mutating."""
@classmethod
def mutate(cls, module, name, memo, mutate_kwargs):
if isinstance(module, NDSStage) and isinstance(module.depth_choice, nn.choice.ValueChoiceX):
# Only interesting when depth is mutable
softmax = mutate_kwargs.get('softmax', nn.Softmax(-1))
return cls(
module.first_cell_transformation_factory(),
cast(List[nn.Module], module.blocks),
module.depth_choice,
softmax,
memo
)
def __init__(self, first_cell_transformation: Optional[nn.Module], *args, **kwargs):
super().__init__(*args, **kwargs)
self.first_cell_transformation = first_cell_transformation
def reduction(
self, items: List[Tuple[torch.Tensor, torch.Tensor]], weights: List[float], depths: List[int]
) -> Tuple[torch.Tensor, torch.Tensor]:
if 1 not in depths or self.first_cell_transformation is None:
return super().reduction(items, weights, depths)
# Same as NDSStagePathSampling
assert len(items[0]) == 2
items[0] = (self.first_cell_transformation(items[0][0]), items[0][1])
return super().reduction(items, weights, depths)
_INIT_PARAMETER_DOCS = """
Parameters
----------
width : int or tuple of int
A fixed initial width or a tuple of widths to choose from.
num_cells : int or tuple of int
A fixed number of cells (depths) to stack, or a tuple of depths to choose from.
dataset : "cifar" | "imagenet"
The essential differences are in "stem" cells, i.e., how they process the raw image input.
Choosing "imagenet" means more downsampling at the beginning of the network.
auxiliary_loss : bool
If true, an auxiliary classification head will produce another prediction.
The network then outputs two logits in the training phase.
"""
class NDS(nn.Module):
__doc__ = """
The unified version of NASNet search space.
We follow the implementation in
`unnas <https://github.com/facebookresearch/unnas/blob/main/pycls/models/nas/nas.py>`__.
See `On Network Design Spaces for Visual Recognition <https://arxiv.org/abs/1905.13214>`__ for details.
Different NAS papers usually differ in the way that they specify ``op_candidates`` and ``merge_op``.
``dataset`` here gives a hint about the input resolution, so that reasonable stem and auxiliary heads can be created.
A speciality of NDS is that depths and widths are mutable.
This is implemented by accepting a list of int as ``num_cells`` / ``width``.
""" + _INIT_PARAMETER_DOCS + """
op_candidates : list of str
List of operator candidates. Must be from ``OPS``.
merge_op : ``all`` or ``loose_end``
See :class:`~nni.retiarii.nn.pytorch.Cell`.
num_nodes_per_cell : int
See :class:`~nni.retiarii.nn.pytorch.Cell`.
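Examples
--------
A hedged sketch of constructing the space directly (in practice, a preset subclass such as
:class:`DARTS` is usually used instead):

.. code-block:: python

    space = NDS(
        op_candidates=['skip_connect', 'sep_conv_3x3', 'sep_conv_5x5', 'max_pool_3x3'],
        merge_op='all',
        num_nodes_per_cell=4,
        width=(16, 24, 32),
        num_cells=(4, 8, 12),
        dataset='cifar',
    )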
"""
def __init__(self,
op_candidates: List[str],
merge_op: Literal['all', 'loose_end'] = 'all',
num_nodes_per_cell: int = 4,
width: Union[Tuple[int, ...], int] = 16,
num_cells: Union[Tuple[int, ...], int] = 20,
dataset: Literal['cifar', 'imagenet'] = 'imagenet',
auxiliary_loss: bool = False):
super().__init__()
self.dataset = dataset
self.num_labels = 10 if dataset == 'cifar' else 1000
self.auxiliary_loss = auxiliary_loss
# preprocess the specified width and depth
if isinstance(width, Iterable):
C = nn.ValueChoice(list(width), label='width')
else:
C = width
self.num_cells: nn.MaybeChoice[int] = cast(int, num_cells)
if isinstance(num_cells, Iterable):
self.num_cells = nn.ValueChoice(list(num_cells), label='depth')
num_cells_per_stage = [(i + 1) * self.num_cells // 3 - i * self.num_cells // 3 for i in range(3)]
# The auxiliary head is different for networks targeted at different datasets.
if dataset == 'imagenet':
self.stem0 = nn.Sequential(
nn.Conv2d(3, cast(int, C // 2), kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(cast(int, C // 2)),
nn.ReLU(inplace=True),
nn.Conv2d(cast(int, C // 2), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
self.stem1 = nn.Sequential(
nn.ReLU(inplace=True),
nn.Conv2d(cast(int, C), cast(int, C), 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(C),
)
C_pprev = C_prev = C_curr = C
last_cell_reduce = True
elif dataset == 'cifar':
self.stem = nn.Sequential(
nn.Conv2d(3, cast(int, 3 * C), 3, padding=1, bias=False),
nn.BatchNorm2d(cast(int, 3 * C))
)
C_pprev = C_prev = 3 * C
C_curr = C
last_cell_reduce = False
else:
raise ValueError(f'Unsupported dataset: {dataset}')
self.stages = nn.ModuleList()
for stage_idx in range(3):
if stage_idx > 0:
C_curr *= 2
# For a stage, we get C_in, C_curr, and C_out.
# C_in is only used in the first cell.
# C_curr is the number of channels for each operator in the current stage.
# C_out is usually `C * num_nodes_per_cell` because of the concat operator.
cell_builder = CellBuilder(op_candidates, C_pprev, C_prev, C_curr, num_nodes_per_cell,
merge_op, stage_idx > 0, last_cell_reduce)
stage: Union[NDSStage, nn.Sequential] = NDSStage(cell_builder, num_cells_per_stage[stage_idx])
if isinstance(stage, NDSStage):
stage.estimated_out_channels_prev = cast(int, C_prev)
stage.estimated_out_channels = cast(int, C_curr * num_nodes_per_cell)
stage.downsampling = stage_idx > 0
self.stages.append(stage)
# NOTE: output_node_indices will be computed on-the-fly in trial code.
# When constructing model space, it's just all the nodes in the cell,
# which happens to be the case of one-shot supernet.
# C_pprev is the output channel count of the second-to-last cell among all cells built so far.
if len(stage) > 1:
# Contains more than one cell
C_pprev = len(cast(nn.Cell, stage[-2]).output_node_indices) * C_curr
else:
# Look up in the out channels of last stage.
C_pprev = C_prev
# This was originally,
# C_prev = num_nodes_per_cell * C_curr.
# but due to loose end, it becomes,
C_prev = len(cast(nn.Cell, stage[-1]).output_node_indices) * C_curr
# Useful in aligning the pprev and prev cell.
last_cell_reduce = cell_builder.last_cell_reduce
if stage_idx == 2:
C_to_auxiliary = C_prev
if auxiliary_loss:
assert isinstance(self.stages[2], nn.Sequential), 'Auxiliary loss can only be enabled in retrain mode.'
self.stages[2] = SequentialBreakdown(cast(nn.Sequential, self.stages[2]))
self.auxiliary_head = AuxiliaryHead(C_to_auxiliary, self.num_labels, dataset=self.dataset) # type: ignore
self.global_pooling = nn.AdaptiveAvgPool2d((1, 1))
self.classifier = nn.Linear(cast(int, C_prev), self.num_labels)
def forward(self, inputs):
if self.dataset == 'imagenet':
s0 = self.stem0(inputs)
s1 = self.stem1(s0)
else:
s0 = s1 = self.stem(inputs)
for stage_idx, stage in enumerate(self.stages):
if stage_idx == 2 and self.auxiliary_loss:
s = stage([s0, s1])  # SequentialBreakdown returns a list with every layer's output
s0, s1 = s[-1]
if self.training:
# auxiliary loss is attached to the first cell of the last stage.
logits_aux = self.auxiliary_head(s[0][1])
else:
s0, s1 = stage([s0, s1])
out = self.global_pooling(s1)
logits = self.classifier(out.view(out.size(0), -1))
if self.training and self.auxiliary_loss:
return logits, logits_aux # type: ignore
else:
return logits
def set_drop_path_prob(self, drop_prob):
"""
Set the drop probability of Drop-path in the network.
Reference: `FractalNet: Ultra-Deep Neural Networks without Residuals <https://arxiv.org/pdf/1605.07648v4.pdf>`__.
"""
for module in self.modules():
if isinstance(module, DropPath_):
module.drop_prob = drop_prob
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@model_wrapper
class NASNet(NDS):
__doc__ = """
Search space proposed in `Learning Transferable Architectures for Scalable Image Recognition <https://arxiv.org/abs/1707.07012>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~NASNet.NASNET_OPS`.
It has 5 nodes per cell, and the output is the concatenation of nodes not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
NASNET_OPS = [
'skip_connect',
'conv_3x1_1x3',
'conv_7x1_1x7',
'dil_conv_3x3',
'avg_pool_3x3',
'max_pool_3x3',
'max_pool_5x5',
'max_pool_7x7',
'conv_1x1',
'conv_3x3',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.NASNET_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class ENAS(NDS):
__doc__ = """Search space proposed in `Efficient neural architecture search via parameter sharing <https://arxiv.org/abs/1802.03268>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~ENAS.ENAS_OPS`.
It has 5 nodes per cell, and the output is the concatenation of the nodes that are not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
ENAS_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'avg_pool_3x3',
'max_pool_3x3',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.ENAS_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class AmoebaNet(NDS):
__doc__ = """Search space proposed in
`Regularized evolution for image classifier architecture search <https://arxiv.org/abs/1802.01548>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~AmoebaNet.AMOEBA_OPS`.
It has 5 nodes per cell, and the output is the concatenation of the nodes that are not used as input to other nodes.
""" + _INIT_PARAMETER_DOCS
AMOEBA_OPS = [
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'avg_pool_3x3',
'max_pool_3x3',
'dil_sep_conv_3x3',
'conv_7x1_1x7',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.AMOEBA_OPS,
merge_op='loose_end',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class PNAS(NDS):
__doc__ = """Search space proposed in
`Progressive neural architecture search <https://arxiv.org/abs/1712.00559>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~PNAS.PNAS_OPS`.
It has 5 nodes per cell, and the output is the concatenation of all the nodes in the cell.
""" + _INIT_PARAMETER_DOCS
PNAS_OPS = [
'sep_conv_3x3',
'sep_conv_5x5',
'sep_conv_7x7',
'conv_7x1_1x7',
'skip_connect',
'avg_pool_3x3',
'max_pool_3x3',
'dil_conv_3x3',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.PNAS_OPS,
merge_op='all',
num_nodes_per_cell=5,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@model_wrapper
class DARTS(NDS):
__doc__ = """Search space proposed in `Darts: Differentiable architecture search <https://arxiv.org/abs/1806.09055>`__.
It is built upon :class:`~nni.retiarii.nn.pytorch.Cell`, and implemented based on :class:`~NDS`.
Its operator candidates are :attr:`~DARTS.DARTS_OPS`.
It has 4 nodes per cell, and the output is the concatenation of all the nodes in the cell.
""" + _INIT_PARAMETER_DOCS
DARTS_OPS = [
'none',
'max_pool_3x3',
'avg_pool_3x3',
'skip_connect',
'sep_conv_3x3',
'sep_conv_5x5',
'dil_conv_3x3',
'dil_conv_5x5',
]
def __init__(self,
width: Union[Tuple[int, ...], int] = (16, 24, 32),
num_cells: Union[Tuple[int, ...], int] = (4, 8, 12, 16, 20),
dataset: Literal['cifar', 'imagenet'] = 'cifar',
auxiliary_loss: bool = False):
super().__init__(self.DARTS_OPS,
merge_op='all',
num_nodes_per_cell=4,
width=width,
num_cells=num_cells,
dataset=dataset,
auxiliary_loss=auxiliary_loss)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'darts-v2':
init_kwargs.update(
num_cells=20,
width=36,
)
arch = {
'normal/op_2_0': 'sep_conv_3x3',
'normal/op_2_1': 'sep_conv_3x3',
'normal/input_2_0': 0,
'normal/input_2_1': 1,
'normal/op_3_0': 'sep_conv_3x3',
'normal/op_3_1': 'sep_conv_3x3',
'normal/input_3_0': 0,
'normal/input_3_1': 1,
'normal/op_4_0': 'sep_conv_3x3',
'normal/op_4_1': 'skip_connect',
'normal/input_4_0': 1,
'normal/input_4_1': 0,
'normal/op_5_0': 'skip_connect',
'normal/op_5_1': 'dil_conv_3x3',
'normal/input_5_0': 0,
'normal/input_5_1': 2,
'reduce/op_2_0': 'max_pool_3x3',
'reduce/op_2_1': 'max_pool_3x3',
'reduce/input_2_0': 0,
'reduce/input_2_1': 1,
'reduce/op_3_0': 'skip_connect',
'reduce/op_3_1': 'max_pool_3x3',
'reduce/input_3_0': 2,
'reduce/input_3_1': 1,
'reduce/op_4_0': 'max_pool_3x3',
'reduce/op_4_1': 'skip_connect',
'reduce/input_4_0': 0,
'reduce/input_4_1': 2,
'reduce/op_5_0': 'skip_connect',
'reduce/op_5_1': 'max_pool_3x3',
'reduce/input_5_0': 2,
'reduce/input_5_1': 1
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
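# Usage sketch (illustrative, added for clarity; not part of the original file).
# Assuming this module is importable from the NNI model hub (the exact import path
# may differ across NNI versions), the searched DARTS-V2 architecture could be
# rebuilt and run roughly like this:
#
#     model = DARTS.load_searched_model('darts-v2', pretrained=False)
#     logits = model(torch.randn(1, 3, 32, 32))      # CIFAR-sized input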
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import math
from typing import Optional, Callable, List, Tuple, Iterator, Union, cast, overload
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
@overload
def make_divisible(v: Union[int, float], divisor, min_val=None) -> int:
...
@overload
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float]], divisor, min_val=None) -> nn.ChoiceOf[int]:
...
def make_divisible(v: Union[nn.ChoiceOf[int], nn.ChoiceOf[float], int, float], divisor, min_val=None) -> nn.MaybeChoice[int]:
"""
This function is taken from the original TensorFlow repo.
It ensures that all layers have a channel number that is divisible by 8.
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
"""
if min_val is None:
min_val = divisor
# This should work for both value choices and constants.
new_v = nn.ValueChoice.max(min_val, round(v + divisor // 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
return nn.ValueChoice.condition(new_v < 0.9 * v, new_v + divisor, new_v)
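# Illustrative sketch (added for clarity, not in the original file): the same rounding
# rule written with plain Python ints, to show what the ValueChoice expression above
# computes for constant inputs. `_make_divisible_plain` is a hypothetical helper.
def _make_divisible_plain(v, divisor, min_val=None):
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    # Make sure that rounding down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
# For example, _make_divisible_plain(37, 8) == 40 and _make_divisible_plain(22, 8) == 24.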
def simplify_sequential(sequentials: List[nn.Module]) -> Iterator[nn.Module]:
"""
Flatten the sequential blocks so that the hierarchy looks better.
Eliminate identity modules automatically.
"""
for module in sequentials:
if isinstance(module, nn.Sequential):
for submodule in module.children():
# no recursive expansion
if not isinstance(submodule, nn.Identity):
yield submodule
else:
if not isinstance(module, nn.Identity):
yield module
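# Quick illustration (added for clarity, not in the original file): nested Sequentials
# are unpacked one level and Identity placeholders are dropped, e.g.
#
#     list(simplify_sequential([nn.Sequential(nn.ReLU(), nn.Identity()), nn.Identity(), nn.ReLU6()]))
#
# yields a two-element list containing only the ReLU and the ReLU6 modules.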
class ConvBNReLU(nn.Sequential):
"""
The template for a conv-bn-relu block.
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
groups: nn.MaybeChoice[int] = 1,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
dilation: int = 1,
) -> None:
padding = (kernel_size - 1) // 2 * dilation
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if activation_layer is None:
activation_layer = nn.ReLU6
# If no normalization is used, set bias to True
# https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L194
norm = norm_layer(cast(int, out_channels))
no_normalization = isinstance(norm, nn.Identity)
blocks: List[nn.Module] = [
nn.Conv2d(
cast(int, in_channels),
cast(int, out_channels),
cast(int, kernel_size),
stride,
cast(int, padding),
dilation=dilation,
groups=cast(int, groups),
bias=no_normalization
),
# Normalization, regardless of batchnorm or identity
norm,
# One PyTorch implementation puts an SE block here, to faithfully reproduce the paper.
# We follow the more widely accepted approach and put SE outside this block.
# Reference: https://github.com/d-li14/mobilenetv3.pytorch/issues/18
activation_layer(inplace=True)
]
super().__init__(*simplify_sequential(blocks))
class DepthwiseSeparableConv(nn.Sequential):
"""
In the original MobileNetV2 implementation, this is InvertedResidual when expand ratio = 1.
Residual connection is added if input and output shape are the same.
References:
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L90
- https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L433
- https://github.com/ultmaster/AceNAS/blob/46c8895f/searchspace/proxylessnas/utils.py#L100
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
blocks = [
# dw
ConvBNReLU(in_channels, in_channels, stride=stride, kernel_size=kernel_size, groups=in_channels,
norm_layer=norm_layer, activation_layer=activation_layer),
# optional se
squeeze_excite(in_channels, in_channels) if squeeze_excite else nn.Identity(),
# pw-linear
ConvBNReLU(in_channels, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity)
]
super().__init__(*simplify_sequential(blocks))
# NOTE: "is" is used here instead of "==" to avoid creating a new value choice.
self.has_skip = stride == 1 and in_channels is out_channels
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip:
return x + super().forward(x)
else:
return super().forward(x)
class InvertedResidual(nn.Sequential):
"""
An Inverted Residual Block, sometimes called an MBConv Block, is a type of residual block used for image models
that uses an inverted structure for efficiency reasons.
It was originally proposed for the `MobileNetV2 <https://arxiv.org/abs/1801.04381>`__ CNN architecture.
It has since been reused for several mobile-optimized CNNs.
It follows a narrow -> wide -> narrow approach, hence the inversion.
It first widens with a 1x1 convolution, then uses a 3x3 depthwise convolution (which greatly reduces the number of parameters),
then a 1x1 convolution is used to reduce the number of channels so input and output can be added.
This implementation is sort of a mixture between:
- https://github.com/google-research/google-research/blob/20736344/tunas/rematlib/mobile_model_v3.py#L453
- https://github.com/rwightman/pytorch-image-models/blob/b7cb8d03/timm/models/efficientnet_blocks.py#L134
"""
def __init__(
self,
in_channels: nn.MaybeChoice[int],
out_channels: nn.MaybeChoice[int],
expand_ratio: nn.MaybeChoice[float],
kernel_size: nn.MaybeChoice[int] = 3,
stride: int = 1,
squeeze_excite: Optional[Callable[[nn.MaybeChoice[int], nn.MaybeChoice[int]], nn.Module]] = None,
norm_layer: Optional[Callable[[int], nn.Module]] = None,
activation_layer: Optional[Callable[..., nn.Module]] = None,
) -> None:
super().__init__()
self.stride = stride
self.out_channels = out_channels
assert stride in [1, 2]
hidden_ch = cast(int, make_divisible(in_channels * expand_ratio, 8))
# NOTE: this equivalence check (==) does NOT work for ValueChoice, need to use "is"
self.has_skip = stride == 1 and in_channels is out_channels
layers: List[nn.Module] = [
# point-wise convolution
# NOTE: some papers omit this point-wise convolution when stride = 1.
# In our implementation, if this pw convolution is intended to be omitted,
# please use SepConv instead.
ConvBNReLU(in_channels, hidden_ch, kernel_size=1,
norm_layer=norm_layer, activation_layer=activation_layer),
# depth-wise
ConvBNReLU(hidden_ch, hidden_ch, stride=stride, kernel_size=kernel_size, groups=hidden_ch,
norm_layer=norm_layer, activation_layer=activation_layer),
# SE
squeeze_excite(
cast(int, hidden_ch),
cast(int, in_channels)
) if squeeze_excite is not None else nn.Identity(),
# pw-linear
ConvBNReLU(hidden_ch, out_channels, kernel_size=1, norm_layer=norm_layer, activation_layer=nn.Identity),
]
super().__init__(*simplify_sequential(layers))
def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.has_skip:
return x + super().forward(x)
else:
return super().forward(x)
def inverted_residual_choice_builder(
expand_ratios: List[int],
kernel_sizes: List[int],
downsample: bool,
stage_input_width: int,
stage_output_width: int,
label: str
):
def builder(index):
stride = 1
inp = stage_output_width
if index == 0:
# first layer in stage
# do downsample and width reshape
inp = stage_input_width
if downsample:
stride = 2
oup = stage_output_width
op_choices = {}
for exp_ratio in expand_ratios:
for kernel_size in kernel_sizes:
op_choices[f'k{kernel_size}e{exp_ratio}'] = InvertedResidual(inp, oup, exp_ratio, kernel_size, stride)
# It can be implemented with ValueChoice, but we use LayerChoice here
# to be aligned with the intention of the original ProxylessNAS.
return nn.LayerChoice(op_choices, label=f'{label}_i{index}')
return builder
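# Illustrative sketch (not part of the original file): the returned ``builder`` is meant
# to be handed to ``nn.Repeat`` (as done in ProxylessNAS below), which calls it with the
# block index within the stage. The numbers here are made-up example widths.
#
#     builder = inverted_residual_choice_builder(
#         [3, 6], [3, 5, 7], downsample=True,
#         stage_input_width=32, stage_output_width=40, label='s3')
#     first_block = builder(0)   # LayerChoice over k{3,5,7} x e{3,6} InvertedResiduals, stride 2
#     later_block = builder(1)   # same candidates, stride 1, 40 -> 40 channels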
@model_wrapper
class ProxylessNAS(nn.Module):
"""
The search space proposed by `ProxylessNAS <https://arxiv.org/abs/1812.00332>`__.
Following the official implementation, the inverted residual with kernel size / expand ratio variations in each layer
is implemented with a :class:`nn.LayerChoice` with all-combination candidates. That means,
when used in weight sharing, these candidates will be treated as separate layers whose weights are not shared in a fine-grained manner.
We note that :class:`MobileNetV3Space` differs in this respect.
This space could be implemented as part of :class:`MobileNetV3Space`, but we keep them separate, following convention.
"""
def __init__(self, num_labels: int = 1000,
base_widths: Tuple[int, ...] = (32, 16, 32, 40, 80, 96, 192, 320, 1280),
dropout_rate: float = 0.,
width_mult: float = 1.0,
bn_eps: float = 1e-3,
bn_momentum: float = 0.1):
super().__init__()
assert len(base_widths) == 9
# include the last stage info widths here
widths = [make_divisible(width * width_mult, 8) for width in base_widths]
downsamples = [True, False, True, True, True, False, True, False]
self.num_labels = num_labels
self.dropout_rate = dropout_rate
self.bn_eps = bn_eps
self.bn_momentum = bn_momentum
self.stem = ConvBNReLU(3, widths[0], stride=2, norm_layer=nn.BatchNorm2d)
blocks: List[nn.Module] = [
# first stage is fixed
DepthwiseSeparableConv(widths[0], widths[1], kernel_size=3, stride=1)
]
# https://github.com/ultmaster/AceNAS/blob/46c8895fd8a05ffbc61a6b44f1e813f64b4f66b7/searchspace/proxylessnas/__init__.py#L21
for stage in range(2, 8):
# Rather than returning a fixed module here,
# we return a builder that dynamically creates module for different `repeat_idx`.
builder = inverted_residual_choice_builder(
[3, 6], [3, 5, 7], downsamples[stage], widths[stage - 1], widths[stage], f's{stage}')
if stage < 7:
blocks.append(nn.Repeat(builder, (1, 4), label=f's{stage}_depth'))
else:
# No mutation for depth in the last stage.
# Directly call builder to initiate one block
blocks.append(builder(0))
self.blocks = nn.Sequential(*blocks)
# final layers
self.feature_mix_layer = ConvBNReLU(widths[7], widths[8], kernel_size=1, norm_layer=nn.BatchNorm2d)
self.global_avg_pooling = nn.AdaptiveAvgPool2d(1)
self.dropout_layer = nn.Dropout(dropout_rate)
self.classifier = nn.Linear(widths[-1], num_labels)
reset_parameters(self, bn_momentum=bn_momentum, bn_eps=bn_eps)
def forward(self, x):
x = self.stem(x)
x = self.blocks(x)
x = self.feature_mix_layer(x)
x = self.global_avg_pooling(x)
x = x.view(x.size(0), -1) # flatten
x = self.dropout_layer(x)
x = self.classifier(x)
return x
def no_weight_decay(self):
# this is useful for timm optimizer
# no regularizer to linear layer
if hasattr(self, 'classifier'):
return {'classifier.weight', 'classifier.bias'}
return set()
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
init_kwargs = {} # all default
if name == 'acenas-m1':
arch = {
's2_depth': 2,
's2_i0': 'k3e6',
's2_i1': 'k3e3',
's3_depth': 3,
's3_i0': 'k5e3',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's4_depth': 2,
's4_i0': 'k3e6',
's4_i1': 'k5e3',
's5_depth': 4,
's5_i0': 'k7e6',
's5_i1': 'k3e6',
's5_i2': 'k3e6',
's5_i3': 'k7e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e3',
's6_i3': 'k7e3',
's7_depth': 1,
's7_i0': 'k7e6'
}
elif name == 'acenas-m2':
arch = {
's2_depth': 1,
's2_i0': 'k5e3',
's3_depth': 3,
's3_i0': 'k3e6',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's4_depth': 2,
's4_i0': 'k7e6',
's4_i1': 'k5e6',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k5e3',
's5_i2': 'k5e6',
's5_i3': 'k3e6',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k5e6',
's6_i2': 'k5e3',
's6_i3': 'k5e6',
's7_depth': 1,
's7_i0': 'k7e6'
}
elif name == 'acenas-m3':
arch = {
's2_depth': 2,
's2_i0': 'k3e3',
's2_i1': 'k3e6',
's3_depth': 2,
's3_i0': 'k5e3',
's3_i1': 'k3e3',
's4_depth': 3,
's4_i0': 'k5e6',
's4_i1': 'k7e6',
's4_i2': 'k3e6',
's5_depth': 4,
's5_i0': 'k7e6',
's5_i1': 'k7e3',
's5_i2': 'k7e3',
's5_i3': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e3',
's6_i2': 'k7e6',
's6_i3': 'k3e3',
's7_depth': 1,
's7_i0': 'k5e6'
}
elif name == 'proxyless-cpu':
arch = {
's2_depth': 4,
's2_i0': 'k3e6',
's2_i1': 'k3e3',
's2_i2': 'k3e3',
's2_i3': 'k3e3',
's3_depth': 4,
's3_i0': 'k3e6',
's3_i1': 'k3e3',
's3_i2': 'k3e3',
's3_i3': 'k5e3',
's4_depth': 2,
's4_i0': 'k3e6',
's4_i1': 'k3e3',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k3e3',
's5_i2': 'k3e3',
's5_i3': 'k3e3',
's6_depth': 4,
's6_i0': 'k5e6',
's6_i1': 'k5e3',
's6_i2': 'k5e3',
's6_i3': 'k3e3',
's7_depth': 1,
's7_i0': 'k5e6'
}
init_kwargs['base_widths'] = [40, 24, 32, 48, 88, 104, 216, 360, 1432]
elif name == 'proxyless-gpu':
arch = {
's2_depth': 1,
's2_i0': 'k5e3',
's3_depth': 2,
's3_i0': 'k7e3',
's3_i1': 'k3e3',
's4_depth': 2,
's4_i0': 'k7e6',
's4_i1': 'k5e3',
's5_depth': 3,
's5_i0': 'k5e6',
's5_i1': 'k3e3',
's5_i2': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e6',
's6_i3': 'k5e6',
's7_depth': 1,
's7_i0': 'k7e6'
}
init_kwargs['base_widths'] = [40, 24, 32, 56, 112, 128, 256, 432, 1728]
elif name == 'proxyless-mobile':
arch = {
's2_depth': 2,
's2_i0': 'k5e3',
's2_i1': 'k3e3',
's3_depth': 4,
's3_i0': 'k7e3',
's3_i1': 'k3e3',
's3_i2': 'k5e3',
's3_i3': 'k5e3',
's4_depth': 4,
's4_i0': 'k7e6',
's4_i1': 'k5e3',
's4_i2': 'k5e3',
's4_i3': 'k5e3',
's5_depth': 4,
's5_i0': 'k5e6',
's5_i1': 'k5e3',
's5_i2': 'k5e3',
's5_i3': 'k5e3',
's6_depth': 4,
's6_i0': 'k7e6',
's6_i1': 'k7e6',
's6_i2': 'k7e3',
's6_i3': 'k7e3',
's7_depth': 1,
's7_i0': 'k7e6'
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory(**init_kwargs)
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
def reset_parameters(model, model_init='he_fout', init_div_groups=False,
bn_momentum=0.1, bn_eps=1e-5):
for m in model.modules():
if isinstance(m, nn.Conv2d):
if model_init == 'he_fout':
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
elif model_init == 'he_fin':
n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
if init_div_groups:
n /= m.groups
m.weight.data.normal_(0, math.sqrt(2. / n))
else:
raise NotImplementedError
elif isinstance(m, nn.BatchNorm2d):
m.weight.data.fill_(1)
m.bias.data.zero_()
m.momentum = bn_momentum
m.eps = bn_eps
elif isinstance(m, nn.Linear):
m.weight.data.normal_(0, 0.01)
if m.bias is not None:
m.bias.data.zero_()
elif isinstance(m, nn.BatchNorm1d):
m.weight.data.fill_(1)
m.bias.data.zero_()
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from typing import cast
import torch
import nni.nas.nn.pytorch as nn
from nni.nas import model_wrapper
from .utils.fixed import FixedFactory
from .utils.pretrained import load_pretrained_weight
class ShuffleNetBlock(nn.Module):
"""
Describe the basic building block of shuffle net, as described in
`ShuffleNet: An Extremely Efficient Convolutional Neural Network for Mobile Devices <https://arxiv.org/pdf/1707.01083.pdf>`__.
When stride = 1, the block expects an input with ``2 * input channels``; otherwise, it expects ``input channels``.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *,
kernel_size: int, stride: int, sequence: str = "pdp", affine: bool = True):
super().__init__()
assert stride in [1, 2]
assert kernel_size in [3, 5, 7]
self.channels = in_channels // 2 if stride == 1 else in_channels
self.in_channels = in_channels
self.out_channels = out_channels
self.mid_channels = mid_channels
self.kernel_size = kernel_size
self.stride = stride
self.pad = kernel_size // 2
self.oup_main = out_channels - self.channels
self.affine = affine
assert self.oup_main > 0
self.branch_main = nn.Sequential(*self._decode_point_depth_conv(sequence))
if stride == 2:
self.branch_proj = nn.Sequential(
# dw
nn.Conv2d(self.channels, self.channels, kernel_size, stride, self.pad,
groups=self.channels, bias=False),
nn.BatchNorm2d(self.channels, affine=affine),
# pw-linear
nn.Conv2d(self.channels, self.channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(self.channels, affine=affine),
nn.ReLU(inplace=True)
)
else:
# empty block to be compatible with torchscript
self.branch_proj = nn.Sequential()
def forward(self, x):
if self.stride == 2:
x_proj, x = self.branch_proj(x), x
else:
x_proj, x = self._channel_shuffle(x)
return torch.cat((x_proj, self.branch_main(x)), 1)
def _decode_point_depth_conv(self, sequence):
result = []
first_depth = first_point = True
pc: int = self.channels
c: int = self.channels
for i, token in enumerate(sequence):
# compute output channels of this conv
if i + 1 == len(sequence):
assert token == "p", "Last conv must be point-wise conv."
c = self.oup_main
elif token == "p" and first_point:
c = cast(int, self.mid_channels)
if token == "d":
# depth-wise conv
if isinstance(pc, int) and isinstance(c, int):
# check can only be done for static channels
assert pc == c, "Depth-wise conv must not change channels."
result.append(nn.Conv2d(pc, c, self.kernel_size, self.stride if first_depth else 1, self.pad,
groups=c, bias=False))
result.append(nn.BatchNorm2d(c, affine=self.affine))
first_depth = False
elif token == "p":
# point-wise conv
result.append(nn.Conv2d(pc, c, 1, 1, 0, bias=False))
result.append(nn.BatchNorm2d(c, affine=self.affine))
result.append(nn.ReLU(inplace=True))
first_point = False
else:
raise ValueError("Conv sequence must be d and p.")
pc = c
return result
def _channel_shuffle(self, x):
bs, num_channels, height, width = x.size()
# NOTE: this line is commented for torchscript
# assert (num_channels % 4 == 0)
x = x.reshape(bs * num_channels // 2, 2, height * width)
x = x.permute(1, 0, 2)
x = x.reshape(2, -1, num_channels // 2, height, width)
return x[0], x[1]
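# Shape note (added for clarity, not in the original file): for an input of shape
# (bs, C, H, W), ``_channel_shuffle`` interleaves the channels and returns two tensors
# of shape (bs, C // 2, H, W). In ``forward`` with stride 1, one half becomes ``x_proj``
# (passed through unchanged) and the other half is fed to ``branch_main`` before the
# two halves are concatenated back along the channel dimension.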
class ShuffleXceptionBlock(ShuffleNetBlock):
"""
The ``choice_x`` version of shuffle net block, described in
`Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
"""
def __init__(self, in_channels: int, out_channels: int, mid_channels: nn.MaybeChoice[int], *, stride: int, affine: bool = True):
super().__init__(in_channels, out_channels, mid_channels,
kernel_size=3, stride=stride, sequence="dpdpdp", affine=affine)
@model_wrapper
class ShuffleNetSpace(nn.Module):
"""
The search space proposed in `Single Path One-shot <https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123610528.pdf>`__.
The basic building block design is inspired by a state-of-the-art manually-designed network --
`ShuffleNetV2 <https://openaccess.thecvf.com/content_ECCV_2018/html/Ningning_Light-weight_CNN_Architecture_ECCV_2018_paper.html>`__.
There are 20 choice blocks in total. Each choice block has 4 candidates, namely ``choice_3``, ``choice_5``,
``choice_7`` and ``choice_x``. They differ in kernel sizes and the number of depthwise convolutions.
The size of the search space is :math:`4^{20}`.
Parameters
----------
num_labels : int
Number of classes for the classification head. Default: 1000.
channel_search : bool
If true, for each building block, the number of ``mid_channels``
(output channels of the first 1x1 conv in each building block) varies from 0.2x to 1.6x (quantized to multiples of 0.2).
Here, "k-x" means k times the number of default channels.
Otherwise, 1.0x is used by default. Default: false.
affine : bool
Apply affine to all batch norm. Default: true.
"""
def __init__(self,
num_labels: int = 1000,
channel_search: bool = False,
affine: bool = True):
super().__init__()
self.num_labels = num_labels
self.channel_search = channel_search
self.affine = affine
# the block number in each stage. 4 stages in total. 20 blocks in total.
self.stage_repeats = [4, 4, 8, 4]
# output channels for all stages, including the very first layer and the very last layer
self.stage_out_channels = [-1, 16, 64, 160, 320, 640, 1024]
# building first layer
out_channels = self.stage_out_channels[1]
self.first_conv = nn.Sequential(
nn.Conv2d(3, out_channels, 3, 2, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
)
feature_blocks = []
global_block_idx = 0
for stage_idx, num_repeat in enumerate(self.stage_repeats):
for block_idx in range(num_repeat):
# count global index to give names to choices
global_block_idx += 1
# get ready for input and output
in_channels = out_channels
out_channels = self.stage_out_channels[stage_idx + 2]
stride = 2 if block_idx == 0 else 1
# mid channels can be searched
base_mid_channels = out_channels // 2
if self.channel_search:
k_choice_list = [int(base_mid_channels * (.2 * k)) for k in range(1, 9)]
mid_channels = nn.ValueChoice(k_choice_list, label=f'channel_{global_block_idx}')
else:
mid_channels = int(base_mid_channels)
mid_channels = cast(nn.MaybeChoice[int], mid_channels)
choice_block = nn.LayerChoice(dict(
k3=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=3, stride=stride, affine=affine),
k5=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=5, stride=stride, affine=affine),
k7=ShuffleNetBlock(in_channels, out_channels, mid_channels=mid_channels, kernel_size=7, stride=stride, affine=affine),
xcep=ShuffleXceptionBlock(in_channels, out_channels, mid_channels=mid_channels, stride=stride, affine=affine)
), label=f'layer_{global_block_idx}')
feature_blocks.append(choice_block)
self.features = nn.Sequential(*feature_blocks)
# final layers
last_conv_channels = self.stage_out_channels[-1]
self.conv_last = nn.Sequential(
nn.Conv2d(out_channels, last_conv_channels, 1, 1, 0, bias=False),
nn.BatchNorm2d(last_conv_channels, affine=affine),
nn.ReLU(inplace=True),
)
self.globalpool = nn.AdaptiveAvgPool2d((1, 1))
self.dropout = nn.Dropout(0.1)
self.classifier = nn.Sequential(
nn.Linear(last_conv_channels, num_labels, bias=False),
)
self._initialize_weights()
def forward(self, x):
x = self.first_conv(x)
x = self.features(x)
x = self.conv_last(x)
x = self.globalpool(x)
x = self.dropout(x)
x = x.contiguous().view(-1, self.stage_out_channels[-1])
x = self.classifier(x)
return x
def _initialize_weights(self):
for name, m in self.named_modules():
if isinstance(m, nn.Conv2d):
if 'first' in name:
torch.nn.init.normal_(m.weight, 0, 0.01)
else:
torch.nn.init.normal_(m.weight, 0, 1.0 / m.weight.shape[1])
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.BatchNorm1d):
if m.weight is not None:
torch.nn.init.constant_(m.weight, 1)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0.0001)
if m.running_mean is not None:
torch.nn.init.constant_(m.running_mean, 0)
elif isinstance(m, nn.Linear):
torch.nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
torch.nn.init.constant_(m.bias, 0)
@classmethod
def fixed_arch(cls, arch: dict) -> FixedFactory:
return FixedFactory(cls, arch)
@classmethod
def load_searched_model(
cls, name: str,
pretrained: bool = False, download: bool = False, progress: bool = True
) -> nn.Module:
if name == 'spos':
# NOTE: Need BGR tensor, with no normalization
# https://github.com/ultmaster/spacehub-conversion/blob/371a4fd6646b4e11eda3f61187f7c9a1d484b1ca/cutils.py#L63
arch = {
'layer_1': 'k7',
'layer_2': 'k5',
'layer_3': 'k3',
'layer_4': 'k5',
'layer_5': 'k7',
'layer_6': 'k3',
'layer_7': 'k7',
'layer_8': 'k3',
'layer_9': 'k7',
'layer_10': 'k3',
'layer_11': 'k7',
'layer_12': 'xcep',
'layer_13': 'k3',
'layer_14': 'k3',
'layer_15': 'k3',
'layer_16': 'k3',
'layer_17': 'xcep',
'layer_18': 'k7',
'layer_19': 'xcep',
'layer_20': 'xcep'
}
else:
raise ValueError(f'Unsupported architecture with name: {name}')
model_factory = cls.fixed_arch(arch)
model = model_factory()
if pretrained:
weight_file = load_pretrained_weight(name, download=download, progress=progress)
pretrained_weights = torch.load(weight_file)
model.load_state_dict(pretrained_weights)
return model
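# Usage sketch (illustrative, not part of the original file):
#
#     space = ShuffleNetSpace(num_labels=1000, channel_search=True)
#
# Each of the 20 blocks then exposes a LayerChoice over ``k3`` / ``k5`` / ``k7`` / ``xcep``,
# and, because ``channel_search=True``, a ValueChoice over mid-channel multipliers from
# 0.2x to 1.6x of the default.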
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""This file should be merged to nni/nas/fixed.py"""
from typing import Type
from nni.nas.utils import ContextStack
class FixedFactory:
"""Make a model space ready to create a fixed model.
Examples
--------
>>> factory = FixedFactory(ModelSpaceClass, {"choice1": 3})
>>> model = factory(channels=16, classes=10)
"""
# TODO: mutations on ``init_args`` and ``init_kwargs`` themselves are not supported.
def __init__(self, cls: Type, arch: dict):
self.cls = cls
self.arch = arch
def __call__(self, *init_args, **init_kwargs):
with ContextStack('fixed', self.arch):
return self.cls(*init_args, **init_kwargs)
def __repr__(self):
return f'FixedFactory(class={self.cls}, arch={self.arch})'
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
"""
Weights available in this file are processed with scripts in https://github.com/ultmaster/spacehub-conversion,
and uploaded with :func:`nni.common.blob_utils.upload_file`.
"""
import os
from nni.common.blob_utils import NNI_BLOB, nni_cache_home, load_or_download_file
PRETRAINED_WEIGHT_URLS = {
# proxylessnas
'acenas-m1': f'{NNI_BLOB}/nashub/acenas-m1-e215f1b8.pth',
'acenas-m2': f'{NNI_BLOB}/nashub/acenas-m2-a8ee9e8f.pth',
'acenas-m3': f'{NNI_BLOB}/nashub/acenas-m3-66a5ed7b.pth',
'proxyless-cpu': f'{NNI_BLOB}/nashub/proxyless-cpu-2df03430.pth',
'proxyless-gpu': f'{NNI_BLOB}/nashub/proxyless-gpu-dbe6dd15.pth',
'proxyless-mobile': f'{NNI_BLOB}/nashub/proxyless-mobile-8668a978.pth',
# mobilenetv3
'mobilenetv3-large-100': f'{NNI_BLOB}/nashub/mobilenetv3-large-100-420e040a.pth',
'mobilenetv3-small-050': f'{NNI_BLOB}/nashub/mobilenetv3-small-050-05cb7a80.pth',
'mobilenetv3-small-075': f'{NNI_BLOB}/nashub/mobilenetv3-small-075-c87d8acb.pth',
'mobilenetv3-small-100': f'{NNI_BLOB}/nashub/mobilenetv3-small-100-8332faac.pth',
'cream-014': f'{NNI_BLOB}/nashub/cream-014-060aea24.pth',
'cream-043': f'{NNI_BLOB}/nashub/cream-043-bec949e1.pth',
'cream-114': f'{NNI_BLOB}/nashub/cream-114-fc272590.pth',
'cream-287': f'{NNI_BLOB}/nashub/cream-287-a0fcba33.pth',
'cream-481': f'{NNI_BLOB}/nashub/cream-481-d85779b6.pth',
'cream-604': f'{NNI_BLOB}/nashub/cream-604-9ee425f7.pth',
# nasnet
'darts-v2': f'{NNI_BLOB}/nashub/darts-v2-5465b0d2.pth',
# spos
'spos': f'{NNI_BLOB}/nashub/spos-0b17f6fc.pth',
# autoformer
'autoformer-tiny': f'{NNI_BLOB}/nashub/autoformer-searched-tiny-1e90ebc1.pth',
'autoformer-small': f'{NNI_BLOB}/nashub/autoformer-searched-small-4bc5d4e5.pth',
'autoformer-base': f'{NNI_BLOB}/nashub/autoformer-searched-base-c417590a.pth'
}
def load_pretrained_weight(name: str, **kwargs) -> str:
if name not in PRETRAINED_WEIGHT_URLS:
raise ValueError(f'"{name}" do not have a valid pretrained weight file.')
url = PRETRAINED_WEIGHT_URLS[name]
local_path = os.path.join(nni_cache_home(), 'nashub', url.split('/')[-1])
load_or_download_file(local_path, url, **kwargs)
return local_path
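# Usage sketch (illustrative, not part of the original file): the hub classes above call
# this helper from their ``load_searched_model`` methods; called directly, it resolves a
# cached local path, downloading on demand, e.g.
#
#     weight_file = load_pretrained_weight('spos', download=True, progress=True)
#     state_dict = torch.load(weight_file)   # assumes torch is imported by the caller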
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .mutator import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import warnings
from typing import (Any, Iterable, List, Optional, Tuple, cast)
from nni.nas.execution import Model, Mutation, ModelStatus
__all__ = ['Sampler', 'Mutator', 'InvalidMutation']
Choice = Any
class Sampler:
"""
Handles `Mutator.choice()` calls.
"""
def choice(self, candidates: List[Choice], mutator: 'Mutator', model: Model, index: int) -> Choice:
raise NotImplementedError()
def mutation_start(self, mutator: 'Mutator', model: Model) -> None:
pass
def mutation_end(self, mutator: 'Mutator', model: Model) -> None:
pass
class Mutator:
"""
Mutates graphs in a model to generate new models.
`Mutator` class will be used in two places:
1. Inherit `Mutator` to implement graph mutation logic.
2. Use `Mutator` subclass to implement NAS strategy.
In scenario 1, the subclass should implement `Mutator.mutate()` interface with `Mutator.choice()`.
In scenario 2, strategy should use constructor or `Mutator.bind_sampler()` to initialize subclass,
and then use `Mutator.apply()` to mutate model.
For certain mutator subclasses, strategy or sampler can use `Mutator.dry_run()` to predict choice candidates.
# Method names are open for discussion.
If mutator has a label, in most cases, it means that this mutator is applied to nodes with this label.
"""
def __init__(self, sampler: Optional[Sampler] = None, label: str = cast(str, None)):
self.sampler: Optional[Sampler] = sampler
if label is None:
warnings.warn('Each mutator should have an explicit label. Mutator without label is deprecated.', DeprecationWarning)
self.label: str = label
self._cur_model: Optional[Model] = None
self._cur_choice_idx: Optional[int] = None
def bind_sampler(self, sampler: Sampler) -> 'Mutator':
"""
Set the sampler which will handle `Mutator.choice` calls.
"""
self.sampler = sampler
return self
def apply(self, model: Model) -> Model:
"""
Apply this mutator on a model.
Returns mutated model.
The model will be copied before mutation and the original model will not be modified.
"""
assert self.sampler is not None
copy = model.fork()
self._cur_model = copy
self._cur_choice_idx = 0
self._cur_samples = []
self.sampler.mutation_start(self, copy)
self.mutate(copy)
self.sampler.mutation_end(self, copy)
copy.history.append(Mutation(self, self._cur_samples, model, copy))
copy.status = ModelStatus.Frozen
self._cur_model = None
self._cur_choice_idx = None
return copy
def dry_run(self, model: Model) -> Tuple[List[List[Choice]], Model]:
"""
Dry run mutator on a model to collect choice candidates.
If you invoke this method multiple times on same or different models,
it may or may not return identical results, depending on how the subclass implements `Mutator.mutate()`.
"""
sampler_backup = self.sampler
recorder = _RecorderSampler()
self.sampler = recorder
new_model = self.apply(model)
self.sampler = sampler_backup
return recorder.recorded_candidates, new_model
def mutate(self, model: Model) -> None:
"""
Abstract method to be implemented by subclass.
Mutate a model in place.
"""
raise NotImplementedError()
def choice(self, candidates: Iterable[Choice]) -> Choice:
"""
Ask sampler to make a choice.
"""
assert self.sampler is not None and self._cur_model is not None and self._cur_choice_idx is not None
ret = self.sampler.choice(list(candidates), self, self._cur_model, self._cur_choice_idx)
self._cur_samples.append(ret)
self._cur_choice_idx += 1
return ret
class _RecorderSampler(Sampler):
def __init__(self):
self.recorded_candidates: List[List[Choice]] = []
def choice(self, candidates: List[Choice], *args) -> Choice:
self.recorded_candidates.append(candidates)
return candidates[0]
class InvalidMutation(Exception):
pass
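# Minimal sketch (illustrative, not part of the original file) of the two roles described
# in the ``Mutator`` docstring: a ``Sampler`` that always picks the first candidate, and a
# toy ``Mutator`` subclass that exercises the ``choice()`` protocol inside ``mutate()``.
# ``FirstChoiceSampler`` and ``ToyMutator`` are hypothetical names; a real mutator would
# edit the graph ``Model`` it receives.
#
#     class FirstChoiceSampler(Sampler):
#         def choice(self, candidates, mutator, model, index):
#             return candidates[0]
#
#     class ToyMutator(Mutator):
#         def mutate(self, model):
#             _ = self.choice(['conv3x3', 'conv5x5'])   # recorded in the mutation history
#
#     mutator = ToyMutator(label='toy').bind_sampler(FirstChoiceSampler())
#     new_model = mutator.apply(existing_model)          # `existing_model` is a graph Model
#     candidates, _ = ToyMutator(label='toy').dry_run(existing_model)   # [['conv3x3', 'conv5x5']]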
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from nni.common.framework import shortcut_framework
shortcut_framework(__name__)
del shortcut_framework
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from .choice import *
from .repeat import *
from .cell import *
from .layers import *
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from typing import Callable, Dict, List, Union, Optional, Tuple, Sequence, cast
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
from .choice import ChosenInputs, LayerChoice, InputChoice
from .layers import ModuleList # pylint: disable=no-name-in-module
from .mutation_utils import generate_new_label
class _ListIdentity(nn.Identity):
# workaround for torchscript
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return x
class _DefaultPostprocessor(nn.Module):
# this is also a workaround for torchscript
def forward(self, this_cell: torch.Tensor, prev_cell: List[torch.Tensor]) -> torch.Tensor:
return this_cell
CellOpFactory = Callable[[int, int, Optional[int]], nn.Module]
def create_cell_op_candidates(
op_candidates, node_index, op_index, chosen
) -> Tuple[Dict[str, nn.Module], bool]:
has_factory = False
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
nonlocal has_factory
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# Yes! It's using factory to create operations now.
has_factory = True
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
res = {str(i): convert_single_op(op) for i, op in enumerate(op_candidates)}
elif isinstance(op_candidates, dict):
res = {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn(f'Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
res = op_candidates()
has_factory = True
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
return res, has_factory
def preprocess_cell_inputs(num_predecessors: int, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> List[torch.Tensor]:
if len(inputs) == 1 and isinstance(inputs[0], list):
processed_inputs = list(inputs[0]) # shallow copy
else:
processed_inputs = cast(List[torch.Tensor], list(inputs))
assert len(processed_inputs) == num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
return processed_inputs
class Cell(nn.Module):
"""
Cell structure that is popularly used in NAS literature.
Find the details in:
* `Neural Architecture Search with Reinforcement Learning <https://arxiv.org/abs/1611.01578>`__.
* `Learning Transferable Architectures for Scalable Image Recognition <https://arxiv.org/abs/1707.07012>`__.
* `DARTS: Differentiable Architecture Search <https://arxiv.org/abs/1806.09055>`__
`On Network Design Spaces for Visual Recognition <https://arxiv.org/abs/1905.13214>`__
is a good summary of how this structure works in practice.
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. A predecessor is an input of the cell.
The output of cell is the concatenation of some of the nodes in the cell (by default all the nodes).
Two examples of searched cells are illustrated in the figure below.
In these two cells, ``op_candidates`` are series of convolutions and pooling operations.
``num_ops_per_node`` is set to 2. ``num_nodes`` is set to 5. ``merge_op`` is ``loose_end``.
Assuming nodes are enumerated from bottom to top, left to right,
``output_node_indices`` for the normal cell is ``[2, 3, 4, 5, 6]``.
For the reduction cell, it's ``[4, 5, 6]``.
Please take a look at this
`review article <https://sh-tsang.medium.com/review-nasnet-neural-architecture-search-network-image-classification-23139ea0425d>`__
if you are interested in details.
.. image:: ../../../img/nasnet_cell.png
:width: 900
:align: center
Here is a glossary table, which could help better understand the terms used above:
.. list-table::
:widths: 25 75
:header-rows: 1
* - Name
- Brief Description
* - Cell
- A cell consists of ``num_nodes`` nodes.
* - Node
- A node is the **sum** of ``num_ops_per_node`` operators.
* - Operator
- Each operator is independently chosen from a list of user-specified candidate operators.
* - Operator's input
- Each operator has one input, chosen from previous nodes as well as predecessors.
* - Predecessors
- Input of cell. A cell can have multiple predecessors. Predecessors are sent to *preprocessor* for preprocessing.
* - Cell's output
- Output of cell. Usually concatenation of some nodes (possibly all nodes) in the cell. Cell's output,
along with predecessors, are sent to *postprocessor* for postprocessing.
* - Preprocessor
- Extra preprocessing to predecessors. Usually used in shape alignment (e.g., predecessors have different shapes).
By default, do nothing.
* - Postprocessor
- Extra postprocessing for cell's output. Usually used to chain cells with multiple Predecessors
(e.g., the next cell wants to have the outputs of both this cell and previous cell as its input).
By default, directly use this cell's output.
.. tip::
It's highly recommended to make the candidate operators have an output of the same shape as input.
This is because there can be dynamic connections within the cell. If an operation changes the shape,
the input shape of the subsequent operation becomes unknown.
In addition, the final concatenation could have shape mismatch issues.
Parameters
----------
op_candidates : list of module or function, or dict
A list of modules to choose from, or a function that accepts the node index, the op index within the node,
and optionally the input index, and returns a module.
For example, ``(2, 3, 0)`` means the 3rd op in the 2nd node, which takes the 0th node as input.
The indices are enumerated over all nodes (including predecessors), starting from 0.
When first created, the input index is ``None``, meaning unknown.
Note that in graph execution engine, support of function in ``op_candidates`` is limited.
Please also note that, to make :class:`Cell` work with one-shot strategy,
``op_candidates``, in case it's a callable, should not depend on the second input argument,
i.e., ``op_index`` in current node.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node: int
Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : "all", or "loose_end"
If "all", all the nodes (except predecessors) will be concatenated as the cell's output, in which case, ``output_node_indices``
will be ``list(range(num_predecessors, num_predecessors + num_nodes))``.
If "loose_end", only the nodes that have never been used as other nodes' inputs will be concatenated to the output.
Predecessors are not considered when calculating unused nodes.
Details can be found in `NDS paper <https://arxiv.org/abs/1905.13214>`__. Default: all.
preprocessor : callable
Override this if some extra transformation on cell's input is intended.
It should be a callable (``nn.Module`` is also acceptable) that takes a list of tensors which are predecessors,
and outputs a list of tensors, with the same length as input.
By default, it does nothing to the input.
postprocessor : callable
Override this if customization on the output of the cell is intended.
It should be a callable that takes the output of this cell, and a list which are predecessors.
Its return type should be either one tensor, or a tuple of tensors.
The return value of postprocessor is the return value of the cell's forward.
By default, it returns only the output of the current cell.
concat_dim : int
The result will be a concatenation of several nodes on this dim. Default: 1.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Examples
--------
Choose between conv2d and maxpool2d.
The cell has 4 nodes, 1 op per node, and 2 predecessors.
>>> cell = nn.Cell([nn.Conv2d(32, 32, 3, padding=1), nn.MaxPool2d(3, padding=1)], 4, 1, 2)
In forward:
>>> cell([input1, input2])
The "list bracket" can be omitted:
>>> cell(only_input) # only one input
>>> cell(tensor1, tensor2, tensor3) # multiple inputs
Use ``merge_op`` to specify how to construct the output.
The output will then have dynamic shape, depending on which input has been used in the cell.
>>> cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, merge_op='loose_end')
>>> cell_out_channels = len(cell.output_node_indices) * 32
The op candidates can be callable that accepts node index in cell, op index in node, and input index.
>>> cell = nn.Cell([
... lambda node_index, op_index, input_index: nn.Conv2d(32, 32, 3, stride=2 if input_index < 1 else 1),
... ], 4, 1, 2)
Predecessor example: ::
class Preprocessor:
def __init__(self):
self.conv1 = nn.Conv2d(16, 32, 1)
self.conv2 = nn.Conv2d(64, 32, 1)
def forward(self, x):
return [self.conv1(x[0]), self.conv2(x[1])]
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, preprocessor=Preprocessor())
cell([torch.randn(1, 16, 48, 48), torch.randn(1, 64, 48, 48)]) # the two inputs will be sent to conv1 and conv2 respectively
Warnings
--------
:class:`Cell` is not supported in :ref:`graph-based execution engine <graph-based-execution-engine>`.
Attributes
----------
output_node_indices : list of int
An attribute that contains indices of the nodes concatenated to the output (a list of integers).
When the cell is first instantiated in the base model, or when ``merge_op`` is ``all``,
``output_node_indices`` must be ``range(num_predecessors, num_predecessors + num_nodes)``.
When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
op_candidates_factory : CellOpFactory or None
If the operations are created with a factory (callable), this is to be set with the factory.
One-shot algorithms will use this to make each node a cartesian product of operations and inputs.
"""
def __init__(self,
op_candidates: Union[
Callable[[], List[nn.Module]],
List[nn.Module],
List[CellOpFactory],
Dict[str, nn.Module],
Dict[str, CellOpFactory]
],
num_nodes: int,
num_ops_per_node: int = 1,
num_predecessors: int = 1,
merge_op: Literal['all', 'loose_end'] = 'all',
preprocessor: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None,
postprocessor: Optional[Callable[[torch.Tensor, List[torch.Tensor]],
Union[Tuple[torch.Tensor, ...], torch.Tensor]]] = None,
concat_dim: int = 1,
*,
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
# modules are created in "natural" order
# first create preprocessor
self.preprocessor = preprocessor or _ListIdentity()
# then create intermediate ops
self.ops = ModuleList()
self.inputs = ModuleList()
# finally postprocessor
self.postprocessor = postprocessor or _DefaultPostprocessor()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
assert merge_op in ['all', 'loose_end']
self.merge_op = merge_op
self.output_node_indices = list(range(num_predecessors, num_predecessors + num_nodes))
self.concat_dim = concat_dim
self.op_candidates_factory: Union[List[CellOpFactory], Dict[str, CellOpFactory], None] = None # set later
# fill-in the missing modules
self._create_modules(op_candidates)
def _create_modules(self, op_candidates):
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
self.ops.append(ModuleList())
self.inputs.append(ModuleList())
for k in range(self.num_ops_per_node):
inp = InputChoice(i, 1, label=f'{self.label}/input_{i}_{k}')
chosen = None
if isinstance(inp, ChosenInputs):
# now we are in the fixed mode
# the length of chosen should be 1
chosen = inp.chosen[0]
if self.merge_op == 'loose_end' and chosen in self.output_node_indices:
# remove it from concat indices
self.output_node_indices.remove(chosen)
# this is needed because op_candidates can be very complex;
# see the type annotation and docs for details
ops, has_factory = create_cell_op_candidates(op_candidates, i, k, chosen)
if has_factory:
self.op_candidates_factory = op_candidates
# though it's layer choice and input choice here, in fixed mode, the chosen module will be created.
cast(ModuleList, self.ops[-1]).append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
cast(ModuleList, self.inputs[-1]).append(inp)
@property
def label(self):
return self._label
def forward(self, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
"""Forward propagation of cell.
Parameters
----------
inputs
Can be a list of tensors, or several tensors.
The length should be equal to ``num_predecessors``.
Returns
-------
Tuple[torch.Tensor] | torch.Tensor
The return type depends on the output of ``postprocessor``.
By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell.
"""
processed_inputs: List[torch.Tensor] = preprocess_cell_inputs(self.num_predecessors, *inputs)
states: List[torch.Tensor] = self.preprocessor(processed_inputs)
for ops, inps in zip(
cast(Sequence[Sequence[LayerChoice]], self.ops),
cast(Sequence[Sequence[InputChoice]], self.inputs)
):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
if self.merge_op == 'all':
# a special case for graph engine
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, processed_inputs)
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import itertools
import math
import operator
import warnings
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
NoReturn, Optional, Sequence, SupportsRound, TypeVar,
Union, cast)
import torch
import torch.nn as nn
from nni.common.hpo_utils import ParameterSpec
from nni.common.serializer import Translatable
from nni.nas.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError, basic_unit
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
__all__ = [
# APIs
'LayerChoice',
'InputChoice',
'ValueChoice',
'ModelParameterChoice',
'Placeholder',
# Fixed module
'ChosenInputs',
# Type utils
'ReductionType',
'MaybeChoice',
'ChoiceOf',
]
class LayerChoice(Mutable):
"""
Layer choice selects one of the ``candidates``, then apply it on inputs and return results.
It allows users to put several candidate operations (e.g., PyTorch modules); one of them is chosen in each explored model.
*New in v2.2:* Layer choice can be nested.
Parameters
----------
candidates : list of nn.Module or OrderedDict
A module list to be selected from.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the layer choice.
Attributes
----------
length : int
Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
names : list of str
Names of candidates.
choices : list of Module
Deprecated. A list of all candidate modules in the layer choice module.
``list(layer_choice)`` is recommended, which will serve the same purpose.
Examples
--------
::
# import nni.retiarii.nn.pytorch as nn
# declared in `__init__` method
self.layer = nn.LayerChoice([
ops.PoolBN('max', channels, 3, stride, 1),
ops.SepConv(channels, channels, 3, stride, 1),
nn.Identity()
])
# invoked in `forward` method
out = self.layer(x)
Notes
-----
``candidates`` can be a list of modules or an ordered dict of named modules, for example,
.. code-block:: python
self.op_choice = LayerChoice(OrderedDict([
("conv3x3", nn.Conv2d(3, 16, 128)),
("conv5x5", nn.Conv2d(5, 16, 128)),
("conv7x7", nn.Conv2d(7, 16, 128))
]))
Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
# FIXME: prior is designed but not supported yet
@classmethod
def create_fixed_module(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
label: Optional[str] = None, **kwargs):
chosen = get_fixed_value(label)
if isinstance(candidates, list):
result = candidates[int(chosen)]
else:
result = candidates[chosen]
# map the named hierarchies to support weight inheritance for python engine
if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
# handle cases where layer choices are nested
# already has a mapping, will merge with it
prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
else:
# "result" needs to know where to map itself.
# Ideally, we should put a _mapping_ in the module where "result" is located,
# but it's impossible to put mapping into parent module here.
setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
return result
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
label = kwargs['key']
if 'return_mask' in kwargs:
warnings.warn(f'"return_mask" is deprecated. Ignoring...')
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
self.names = []
if isinstance(candidates, dict):
for name, module in candidates.items():
assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
"Please don't use a reserved name '{}' for your module.".format(name)
self.add_module(name, module)
self.names.append(name)
elif isinstance(candidates, list):
for i, module in enumerate(candidates):
self.add_module(str(i), module)
self.names.append(str(i))
else:
raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
self._first_module = cast(nn.Module, self._modules[self.names[0]]) # to make the dummy forward meaningful
@property
def label(self):
return self._label
def __getitem__(self, idx: Union[int, str]) -> nn.Module:
if isinstance(idx, str):
return cast(nn.Module, self._modules[idx])
return cast(nn.Module, list(self)[idx])
def __setitem__(self, idx, module):
key = idx if isinstance(idx, str) else self.names[idx]
return setattr(self, key, module)
def __delitem__(self, idx):
if isinstance(idx, slice):
for key in self.names[idx]:
delattr(self, key)
else:
if isinstance(idx, str):
key, idx = idx, self.names.index(idx)
else:
key = self.names[idx]
delattr(self, key)
del self.names[idx]
def __len__(self):
return len(self.names)
def __iter__(self):
return map(lambda name: self._modules[name], self.names)
def forward(self, x):
"""
The forward of layer choice is simply running the first candidate module.
It shouldn't be called directly by users in most cases.
"""
warnings.warn('You should not run forward of this module directly.')
return self._first_module(x)
def __repr__(self):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
ReductionType = Literal['mean', 'concat', 'sum', 'none']
class InputChoice(Mutable):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).
It is mainly for choosing (or trying) different connections. It takes several tensors and chooses ``n_chosen`` tensors from them.
When specific inputs are chosen, ``InputChoice`` will become :class:`ChosenInputs`.
Use ``reduction`` to specify how chosen inputs are reduced into one output. A few options are:
* ``none``: do nothing and return the list directly.
* ``sum``: summing all the chosen inputs.
* ``mean``: taking the average of all chosen inputs.
* ``concat``: concatenate all chosen inputs at dimension 1.
We don't support customizing reduction yet.
Parameters
----------
n_candidates : int
Number of inputs to choose from. It is required.
n_chosen : int
Recommended inputs to choose. If None, mutator is instructed to select any.
reduction : str
``mean``, ``concat``, ``sum`` or ``none``.
prior : list of float
Prior distribution used in random sampling.
label : str
Identifier of the input choice.
Examples
--------
::
# import nni.retiarii.nn.pytorch as nn
# declared in `__init__` method
self.input_switch = nn.InputChoice(n_candidates=3, n_chosen=1)
# invoked in `forward` method, choose one from the three
out = self.input_switch([tensor1, tensor2, tensor3])
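When more than one input is chosen, ``reduction`` decides how they are combined (illustrative)::

    # choose two of the three inputs and sum them up
    self.input_switch = nn.InputChoice(n_candidates=3, n_chosen=2, reduction='sum')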
"""
@classmethod
def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: ReductionType = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
return ChosenInputs(get_fixed_value(label), reduction=reduction)
def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: str = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
label = kwargs['key']
if 'return_mask' in kwargs:
warnings.warn(f'"return_mask" is deprecated. Ignoring...')
if 'choose_from' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.n_candidates = n_candidates
self.n_chosen = n_chosen
self.reduction = reduction
self.prior = prior or [1 / n_candidates for _ in range(n_candidates)]
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = generate_new_label(label)
@property
def label(self):
return self._label
def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
"""
The forward of input choice is simply the first item of ``candidate_inputs``.
It shouldn't be called directly by users in most cases.
"""
warnings.warn('You should not run forward of this module directly.')
return candidate_inputs[0]
def __repr__(self):
return f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, ' \
f'reduction={repr(self.reduction)}, label={repr(self.label)})'
class ChosenInputs(nn.Module):
"""
A module that chooses from a tensor list and outputs a reduced tensor.
The already-chosen version of InputChoice.
During forward, ``chosen`` is used to select inputs from ``candidate_inputs``,
and ``reduction`` is used to combine the selected inputs into a single tensor.
Attributes
----------
chosen : list of int
Indices of chosen inputs.
reduction : ``mean`` | ``concat`` | ``sum`` | ``none``
How to reduce the inputs when multiple are selected.
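Examples
--------
A minimal sketch (illustrative)::

    chosen = ChosenInputs([0, 2], reduction='sum')
    out = chosen([t0, t1, t2])    # computes t0 + t2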
"""
def __init__(self, chosen: Union[List[int], int], reduction: ReductionType):
super().__init__()
self.chosen = chosen if isinstance(chosen, list) else [chosen]
self.reduction = reduction
def forward(self, candidate_inputs):
"""
Compute the reduced input based on ``chosen`` and ``reduction``.
"""
return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'none':
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == 'sum':
return sum(tensor_list)
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
return torch.cat(tensor_list, dim=1)
raise ValueError(f'Unrecognized reduction policy: "{reduction_type}"')
# the code in ValueChoice can be generated with this codegen
# this is not done on the fly because we want type-hint support
# $ python -c "from nni.retiarii.nn.pytorch.api import _valuechoice_codegen; _valuechoice_codegen(_internal=True)"
def _valuechoice_codegen(*, _internal: bool = False):
if not _internal:
raise RuntimeError("This method is set to be internal. Please don't use it directly.")
MAPPING = {
# unary
'neg': '-', 'pos': '+', 'invert': '~',
# binary
'add': '+', 'sub': '-', 'mul': '*', 'matmul': '@',
'truediv': '/', 'floordiv': '//', 'mod': '%',
'lshift': '<<', 'rshift': '>>',
'and': '&', 'xor': '^', 'or': '|',
# no reverse
'lt': '<', 'le': '<=', 'eq': '==',
'ne': '!=', 'ge': '>=', 'gt': '>',
# NOTE
# Currently we don't support operators like __contains__ (b in a),
# Might support them in future when we actually need them.
}
binary_template = """ def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""
binary_r_template = """ def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""
unary_template = """ def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""
for op, sym in MAPPING.items():
if op in ['neg', 'pos', 'invert']:
print(unary_template.format(op=op, sym=sym) + '\n')
else:
opt = op + '_' if op in ['and', 'or'] else op
print(binary_template.format(op=op, opt=opt, sym=sym) + '\n')
if op not in ['lt', 'le', 'eq', 'ne', 'ge', 'gt']:
print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')
_func = TypeVar('_func')
_cand = TypeVar('_cand')
_value = TypeVar('_value')
def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
if orig_func.__doc__ is not None:
orig_func.__doc__ += """
Notes
-----
This function performs lazy evaluation.
Only the expression will be recorded when the function is called.
The real evaluation happens when the inner value choice has determined its final decision.
If no value choice is contained in the parameter list, the evaluation will be immediate."""
return orig_func
class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
"""Internal API. Implementation note:
The transformed (X) version of value choice.
It can be the result of composition (transformation) of one or several value choices. For example,
.. code-block:: python
nn.ValueChoice([1, 2]) + nn.ValueChoice([3, 4]) + 5
Instances of the base class cannot be created directly. They should only be the result of transformations of value choices.
Therefore, there is no need to implement ``create_fixed_module`` in this class, because,
1. For the Python engine, the value choice itself has ``create_fixed_module``. Consequently, the transformation is fixed from the start.
2. For the graph engine, ``evaluate`` is used to calculate the result.
Potentially, we will have to implement the evaluation logic in one-shot algorithms. That discussion can be postponed till then.
This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
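A sketch of how a composed choice is consumed (illustrative)::

    vc = ValueChoice([1, 2], label='x') * 2 + 1
    list(vc.all_options())    # -> [3, 5]
    vc.evaluate([2])          # -> 5; values follow the order of ``inner_choices()``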
"""
def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
repr_template: str = cast(str, None),
arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
dry_run: bool = True):
super().__init__()
if function is None:
# this case is a hack for ValueChoice subclass
# it will reach here only because ``__init__`` in ``nn.Module`` is useful.
return
self.function = function
self.repr_template = repr_template
self.arguments = arguments
assert any(isinstance(arg, ValueChoiceX) for arg in self.arguments)
if dry_run:
# for sanity check
self.dry_run()
def forward(self) -> None:
raise RuntimeError('You should never call forward of the composition of a value-choice.')
def inner_choices(self) -> Iterable['ValueChoice']:
"""
Return a generator of all leaf value choices.
Useful for composition of value choices.
No deduplication on labels. Mutators should take care.
"""
for arg in self.arguments:
if isinstance(arg, ValueChoiceX):
yield from arg.inner_choices()
def dry_run(self) -> _cand:
"""
Dry run the value choice to get one of its possible evaluation results.
"""
# values are not used
return self._evaluate(iter([]), True)
def all_options(self) -> Iterable[_cand]:
"""Explore all possibilities of a value choice.
"""
# Record all inner choices: label -> candidates, no duplicates.
dedup_inner_choices: Dict[str, List[_cand]] = {}
# All labels of leaf nodes on tree, possibly duplicates.
all_labels: List[str] = []
for choice in self.inner_choices():
all_labels.append(choice.label)
if choice.label in dedup_inner_choices:
if choice.candidates != dedup_inner_choices[choice.label]:
# check for choice with the same label
raise ValueError(f'"{choice.candidates}" is not equal to "{dedup_inner_choices[choice.label]}", '
f'but they share the same label: {choice.label}')
else:
dedup_inner_choices[choice.label] = choice.candidates
dedup_labels, dedup_candidates = list(dedup_inner_choices.keys()), list(dedup_inner_choices.values())
for chosen in itertools.product(*dedup_candidates):
chosen = dict(zip(dedup_labels, chosen))
yield self.evaluate([chosen[label] for label in all_labels])
def evaluate(self, values: Iterable[_cand]) -> _cand:
"""
Evaluate the result of this group.
``values`` should be in the same order as ``inner_choices()``.
"""
return self._evaluate(iter(values), False)
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
# "values" iterates in the recursion
eval_args = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX):
# recursive evaluation
eval_args.append(arg._evaluate(values, dry_run))
# the recursion will stop when it hits a leaf node (value choice)
# the implementation is in `ValueChoice`
else:
# constant value
eval_args.append(arg)
return self.function(*eval_args)
def _translate(self):
"""
Try to behave like one of its candidates when used in ``basic_unit``.
"""
return self.dry_run()
def __repr__(self) -> str:
reprs = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
reprs.append('(' + repr(arg) + ')') # add parenthesis for operator priority
else:
reprs.append(repr(arg))
return self.repr_template.format(*reprs)
# the following are a series of methods to create "ValueChoiceX"
# which is a transformed version of value choice
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# Special operators that can be useful in place of built-in conditional operators.
@staticmethod
@_valuechoice_staticmethod_helper
def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
"""
Convert a ``ValueChoice`` to an integer.
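For example (illustrative), ``ValueChoice.to_int(ValueChoice([0.5, 1.0]) * 4)`` is a lazy expression that evaluates to ``2`` or ``4``.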
"""
if isinstance(obj, ValueChoiceX):
return ValueChoiceX(int, 'int({})', [obj])
return int(obj)
@staticmethod
@_valuechoice_staticmethod_helper
def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
"""
Convert a ``ValueChoice`` to a float.
"""
if isinstance(obj, ValueChoiceX):
return ValueChoiceX(float, 'float({})', [obj])
return float(obj)
@staticmethod
@_valuechoice_staticmethod_helper
def condition(pred: 'MaybeChoice[bool]',
true: 'MaybeChoice[_value]',
false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Return ``true`` if the predicate ``pred`` is true else ``false``.
Examples
--------
>>> ValueChoice.condition(ValueChoice([1, 2]) > ValueChoice([0, 3]), 2, 1)
"""
if any(isinstance(obj, ValueChoiceX) for obj in [pred, true, false]):
return ValueChoiceX(lambda t, c, f: t if c else f, '{} if {} else {}', [true, pred, false])
return true if pred else false
@staticmethod
@_valuechoice_staticmethod_helper
def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the maximum value from a list of value choices.
The usage should be similar to Python's built-in ``max`` function,
where the parameters could be an iterable, or at least two arguments.
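A sketch (illustrative)::

    out_channels = ValueChoice.max(ValueChoice([16, 32]), 24)   # lazily evaluates to 24 or 32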
"""
if not args:
if not isinstance(arg0, Iterable):
raise TypeError('Expected more than one item to compare for max')
return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(max, 'max({})', lst)
return max(cast(Any, lst))
@staticmethod
@_valuechoice_staticmethod_helper
def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
*args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
"""
Returns the minimum value from a list of value choices.
The usage should be similar to Python's built-in ``min`` function,
where the parameters could be an iterable, or at least two arguments.
"""
if not args:
if not isinstance(arg0, Iterable):
raise TypeError('Expected more than one item to compare for min')
return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(min, 'min({})', lst)
return min(cast(Any, lst))
def __hash__(self):
# this is required because we have implemented ``__eq__``
return id(self)
# NOTE:
# Write operations are not supported. Reasons follow:
# - Semantics are not clear. It can be applied to "all" the inner candidates, or only the chosen one.
# - Implementation effort is too huge.
# As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.
def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])
# region implement int, float, round, trunc, floor, ceil
# because I believe sometimes we need them to calculate #channels
# `__int__` and `__float__` are not supported because `__int__` is required to return int.
def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
if ndigits is not None:
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))
def __trunc__(self) -> NoReturn:
raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")
def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.floor, 'math.floor({})', [self])
def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
return ValueChoiceX(math.ceil, 'math.ceil({})', [self])
def __index__(self) -> NoReturn:
# https://docs.python.org/3/reference/datamodel.html#object.__index__
raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't "
"use int(), float(), complex(), range() on a ValueChoice. "
"To cast the type of ValueChoice, please try `ValueChoice.to_int()` or `ValueChoice.to_float()`.")
def __bool__(self) -> NoReturn:
raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. '
'Please try methods like `ValueChoice.max(a, b)` to see whether that meets your needs.')
# endregion
# region the following code is generated with codegen (see above)
# Annotated with "region" because I want to collapse them in vscode
def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))
def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))
def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))
def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [self, other])
def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.add, '{} + {}', [other, self])
def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [self, other])
def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.sub, '{} - {}', [other, self])
def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [self, other])
def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mul, '{} * {}', [other, self])
def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])
def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])
def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} / {}', [self, other])
def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.truediv, '{} / {}', [other, self])
def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} // {}', [self, other])
def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.floordiv, '{} // {}', [other, self])
def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [self, other])
def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.mod, '{} % {}', [other, self])
def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [self, other])
def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lshift, '{} << {}', [other, self])
def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])
def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])
def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [self, other])
def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.and_, '{} & {}', [other, self])
def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])
def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])
def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [self, other])
def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.or_, '{} | {}', [other, self])
def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.lt, '{} < {}', [self, other])
def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.le, '{} <= {}', [self, other])
def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.eq, '{} == {}', [self, other])
def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ne, '{} != {}', [self, other])
def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.ge, '{} >= {}', [self, other])
def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(operator.gt, '{} > {}', [self, other])
# endregion
# __pow__, __divmod__, __abs__ are special ones.
# Not easy to cover those cases with codegen.
def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])
def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])
def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])
def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])
def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
return ValueChoiceX(abs, 'abs({})', [self])
ChoiceOf = ValueChoiceX
MaybeChoice = Union[ValueChoiceX[_cand], _cand]
class ValueChoice(ValueChoiceX[_cand], Mutable):
"""
ValueChoice is to choose one from ``candidates``. The most common use cases are:
* Used as input arguments of :class:`~nni.retiarii.basic_unit`
(i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``).
* Used as input arguments of evaluator (*new in v2.7*).
It can be used in parameters of operators (i.e., a sub-module of the model): ::
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))
def forward(self, x):
return self.conv(x)
Or evaluator (only if the evaluator is :doc:`traceable </nas/serialization>`, e.g.,
:class:`FunctionalEvaluator <nni.retiarii.evaluator.FunctionalEvaluator>`): ::
def train_and_evaluate(model_cls, learning_rate):
...
self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))
Value choices support arithmetic operators, which is particularly useful when searching for a network width multiplier: ::
# init
scale = nn.ValueChoice([1.0, 1.5, 2.0])
self.conv1 = nn.Conv2d(3, round(scale * 16))
self.conv2 = nn.Conv2d(round(scale * 16), round(scale * 64))
self.conv3 = nn.Conv2d(round(scale * 64), round(scale * 256))
# forward
return self.conv3(self.conv2(self.conv1(x)))
Or when kernel size and padding are coupled so as to keep the output size constant: ::
# init
ks = nn.ValueChoice([3, 5, 7])
self.conv = nn.Conv2d(3, 16, kernel_size=ks, padding=(ks - 1) // 2)
# forward
return self.conv(x)
Or when several layers are concatenated for a final layer. ::
# init
self.linear1 = nn.Linear(3, nn.ValueChoice([1, 2, 3], label='a'))
self.linear2 = nn.Linear(3, nn.ValueChoice([4, 5, 6], label='b'))
self.final = nn.Linear(nn.ValueChoice([1, 2, 3], label='a') + nn.ValueChoice([4, 5, 6], label='b'), 2)
# forward
return self.final(torch.cat([self.linear1(x), self.linear2(x)], 1))
Some advanced operators are also provided, such as :meth:`ValueChoice.max` and :meth:`ValueChoice.condition`.
.. tip::
All the APIs have an optional argument called ``label``;
mutations with the same label will share the same choice. A typical example is: ::
self.net = nn.Sequential(
nn.Linear(10, nn.ValueChoice([32, 64, 128], label='hidden_dim')),
nn.Linear(nn.ValueChoice([32, 64, 128], label='hidden_dim'), 3)
)
Sharing the same value choice instance has a similar effect. ::
class Net(nn.Module):
def __init__(self):
super().__init__()
hidden_dim = nn.ValueChoice([128, 512])
self.fc = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.Linear(hidden_dim, 10)
)
.. warning::
It may look as if a specific candidate has been chosen (e.g., when you put a ``ValueChoice``
as a parameter of ``nn.Conv2d``), but in fact it's syntactic sugar: the basic units and evaluators
do all the underlying work. That means you cannot assume that ``ValueChoice`` can be used in the same way
as its candidates. For example, the following usage will NOT work: ::
self.blocks = []
for i in range(nn.ValueChoice([1, 2, 3])):
self.blocks.append(Block())
# NOTE: instead you should probably write
# self.blocks = nn.Repeat(Block(), (1, 3))
Another use case is to initialize the values to choose from in ``__init__`` and call the module in ``forward`` to get the chosen value.
Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``.
For example, ::
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])
def forward(self, x):
return F.dropout(x, self.dropout_rate())
Parameters
----------
candidates : list
List of values to choose from.
prior : list of float
Prior distribution to sample from.
label : str
Identifier of the value choice.
"""
# FIXME: prior is designed but not supported yet
@classmethod
def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
value = get_fixed_value(label)
if value not in candidates:
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
return value
def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__() # type: ignore
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
self._label = generate_new_label(label)
@property
def label(self):
return self._label
def forward(self):
"""
The forward of value choice simply returns the first value of ``candidates``.
It shouldn't be called directly by users in most cases.
"""
warnings.warn('You should not run forward of this module directly.')
return self.candidates[0]
def inner_choices(self) -> Iterable['ValueChoice']:
# yield self because self is the only value choice here
yield self
def dry_run(self) -> _cand:
return self.candidates[0]
def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
if dry_run:
return self.candidates[0]
try:
value = next(values)
except StopIteration:
raise ValueError(f'Value list {values} is exhausted when trying to get a chosen value of {self}.')
if value not in self.candidates:
raise ValueError(f'Value {value} does not belong to the candidates of {self}.')
return value
def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
ValueType = TypeVar('ValueType')
class ModelParameterChoice:
"""ModelParameterChoice chooses one hyper-parameter from ``candidates``.
.. attention::
This API is internal, and does not guarantee forward-compatibility.
It's quite similar to :class:`ValueChoice`, but unlike :class:`ValueChoice`,
it always returns a fixed value, even at the construction of base model.
This makes it highly flexible (e.g., can be used in for-loop, if-condition, as argument of any function). For example: ::
self.has_auxiliary_head = ModelParameterChoice([False, True])
# this will raise error if you use `ValueChoice`
if self.has_auxiliary_head is True: # or self.has_auxiliary_head
self.auxiliary_head = Head()
else:
self.auxiliary_head = None
print(type(self.has_auxiliary_head)) # <class 'bool'>
The working mechanism of :class:`ModelParameterChoice` is that it registers itself
in the ``model_wrapper`` as a hyper-parameter of the model, and then returns the value specified by ``default``.
At base model construction, the default value will be used (as a mocked hyper-parameter).
In trial, the hyper-parameter selected by strategy will be used.
Although flexible, we still recommend using :class:`ValueChoice` in favor of :class:`ModelParameterChoice`,
because information is lost when using :class:`ModelParameterChoice` in exchange for its flexibility,
making it incompatible with one-shot strategies and non-python execution engines.
.. warning::
:class:`ModelParameterChoice` can NOT be nested.
.. tip::
Although called :class:`ModelParameterChoice`, it's meant to tune hyper-parameters of the architecture.
It's NOT used to tune model-training hyper-parameters like ``learning_rate``.
If you need to tune ``learning_rate``, please use :class:`ValueChoice` on arguments of :class:`nni.retiarii.Evaluator`.
Parameters
----------
candidates : list of any
List of values to choose from.
prior : list of float
Prior distribution to sample from. Currently has no effect.
default : Callable[[List[Any]], Any] or Any
Function that selects one from ``candidates``, or a candidate.
Use :meth:`ModelParameterChoice.FIRST` or :meth:`ModelParameterChoice.LAST` to take the first or last item.
Default: :meth:`ModelParameterChoice.FIRST`
label : str
Identifier of the value choice.
Warnings
--------
:class:`ModelParameterChoice` is incompatible with one-shot strategies and non-python execution engines.
Sometimes, the same search space implemented **without** :class:`ModelParameterChoice` can be simpler, and explored
with more types of search strategies. For example, the following usages are equivalent: ::
# with ModelParameterChoice
depth = nn.ModelParameterChoice(list(range(3, 10)))
blocks = []
for i in range(depth):
blocks.append(Block())
# without ModelParameterChoice
blocks = Repeat(Block(), (3, 9))
Examples
--------
Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.
>>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
>>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
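Use ``default`` to pick which candidate is used at base-model construction (illustrative; ``depth`` is just an example name).

>>> depth = nn.ModelParameterChoice([2, 3, 4], default=nn.ModelParameterChoice.LAST)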
"""
# FIXME: fix signature in docs
# FIXME: prior is designed but not supported yet
def __new__(cls, candidates: List[ValueType], *,
prior: Optional[List[float]] = None,
default: Union[Callable[[List[ValueType]], ValueType], ValueType] = None,
label: Optional[str] = None) -> ValueType:
# Actually, creating a `ModelParameterChoice` never creates an instance of this class.
# It always returns a fixed value, and registers a ParameterSpec.
if default is None:
default = cls.FIRST
try:
return cls.create_fixed_module(candidates, label=label)
except NoContextError:
return cls.create_default(candidates, default, label)
@staticmethod
def create_default(candidates: List[ValueType],
default: Union[Callable[[List[ValueType]], ValueType], ValueType],
label: Optional[str]) -> ValueType:
if default not in candidates:
# could be callable
try:
default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)
except TypeError as e:
if 'not callable' in str(e):
raise TypeError("`default` is not in `candidates`, and it's also not callable.")
raise
default = cast(ValueType, default)
label = generate_new_label(label)
parameter_spec = ParameterSpec(
label, # name
'choice', # TODO: support more types
candidates, # value
(label,), # we don't have nested now
True, # yes, categorical
)
# there could be duplicates. Dedup is done in mutator
ModelNamespace.current_context().parameter_specs.append(parameter_spec)
return default
@classmethod
def create_fixed_module(cls, candidates: List[ValueType], *, label: Optional[str] = None, **kwargs) -> ValueType:
# same as ValueChoice
value = get_fixed_value(label)
if value not in candidates:
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
return value
@staticmethod
def FIRST(sequence: Sequence[ValueType]) -> ValueType:
"""Get the first item of sequence. Useful in ``default`` argument."""
return sequence[0]
@staticmethod
def LAST(sequence: Sequence[ValueType]) -> ValueType:
"""Get the last item of sequence. Useful in ``default`` argument."""
return sequence[-1]
@basic_unit
class Placeholder(nn.Module):
"""
The API that creates an empty module for later mutations.
For advanced usages only.
"""
def __init__(self, label, **related_info):
self.label = label
self.related_info = related_info
super().__init__()
def forward(self, x):
"""
Forward of placeholder is not meaningful.
It returns input directly.
"""
return x
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
from pathlib import Path
# To make auto-completion happy, we generate a _layers.py that lists out all the classes.
nn_cache_file_path = Path(__file__).parent / '_layers.py'
# Update this when cache format changes, to enforce an update.
cache_version = 3
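# The generated file starts with header comments such as (illustrative):
#     # _torch_version = 1.12.0
#     # _torch_nn_cache_version = 3
# ``validate_cache`` below compares both against the current environment before reusing the file.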
def validate_cache() -> bool:
import torch
cache_valid = []
if nn_cache_file_path.exists():
lines = nn_cache_file_path.read_text().splitlines()
for line in lines:
if line.startswith('# _torch_version'):
_cached_torch_version = line[line.find('=') + 1:].strip()
if _cached_torch_version == torch.__version__:
cache_valid.append(True)
if line.startswith('# _torch_nn_cache_version'):
_cached_cache_version = int(line[line.find('=') + 1:].strip())
if _cached_cache_version == cache_version:
cache_valid.append(True)
return len(cache_valid) >= 2 and all(cache_valid)
def generate_stub_file() -> str:
import inspect
import warnings
import torch
import torch.nn as nn
_NO_WRAP_CLASSES = [
# not an nn.Module
'Parameter',
'ParameterList',
'UninitializedBuffer',
'UninitializedParameter',
# arguments are special
'Module',
'Sequential',
# utilities
'Container',
'DataParallel',
]
_WRAP_WITHOUT_TAG_CLASSES = [
# special support on graph engine
'ModuleList',
'ModuleDict',
]
code = [
'# Copyright (c) Microsoft Corporation.',
'# Licensed under the MIT license.',
'# This file is auto-generated to make auto-completion work.',
'# When pytorch version does not match, it will get automatically updated.',
'# pylint: skip-file',
'# pyright: reportGeneralTypeIssues=false',
f'# _torch_version = {torch.__version__}',
f'# _torch_nn_cache_version = {cache_version}',
'import typing',
'import torch.nn as nn',
'from nni.nas.utils import basic_unit',
]
all_names = []
# Add modules, classes, functions in torch.nn into this module.
for name, obj in inspect.getmembers(torch.nn):
if inspect.isclass(obj):
if name in _NO_WRAP_CLASSES:
code.append(f'{name} = nn.{name}')
elif not issubclass(obj, nn.Module):
# It should never reach here.
# We do this to play safe.
warnings.warn(f'{obj} is found to be not a nn.Module, which is unexpected. '
'It means your PyTorch version might not be supported.', RuntimeWarning)
code.append(f'{name} = nn.{name}')
elif name in _WRAP_WITHOUT_TAG_CLASSES:
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}, basic_unit_tag=False))')
else:
code.append(f'{name} = typing.cast(typing.Type[nn.{name}], basic_unit(nn.{name}))')
all_names.append(name)
elif inspect.isfunction(obj) or inspect.ismodule(obj):
code.append(f'{name} = nn.{name}') # no modification
all_names.append(name)
code.append(f'__all__ = {all_names}')
return '\n'.join(code)
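# Each wrapped entry in the generated stub looks roughly like (illustrative):
#     Conv2d = typing.cast(typing.Type[nn.Conv2d], basic_unit(nn.Conv2d))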
def write_cache(code: str) -> None:
with nn_cache_file_path.open('w') as fp:
fp.write(code)
code = generate_stub_file()
if not validate_cache():
write_cache(code)
del Path, validate_cache, write_cache, cache_version, nn_cache_file_path, code
from ._layers import * # pylint: disable=import-error, wildcard-import, unused-wildcard-import