Unverified Commit b8b7ed0e authored by Yuge Zhang, committed by GitHub

Add loose-end support for Cell API (#4411)

parent 53ae92cc
Mutation Primitives
===================
.. TODO: this file will be merged with the API reference in the future.
To help users easily express a model space within their PyTorch/TensorFlow model, NNI provides the inline mutation APIs shown below.
We show the most common use cases here. For advanced usage, please see the `reference <./ApiReference.rst>`__.
@@ -48,7 +50,7 @@ API reference: :class:`nni.retiarii.nn.pytorch.ValueChoice`
It chooses one value from a set of candidate values. The most common use cases are:
* Used as input arguments of `basic units <LINK_TBD>` (i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``).
* Used as input arguments of :class:`nni.retiarii.basic_unit` (i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``).
* Used as input arguments of evaluator (*new in v2.7*).
Examples are as follows:
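For instance, a minimal sketch of the basic-unit use case (the candidate values here are illustrative):

.. code-block:: python

    # import nni.retiarii.nn.pytorch as nn
    # used in `__init__` method
    # the number of output channels is chosen from three candidates
    self.conv = nn.Conv2d(3, nn.ValueChoice([16, 32, 64]), kernel_size=3)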
@@ -108,9 +110,6 @@ Repeat a block by a variable number of times.
# Block() will be repeated 1, 2, or 3 times
self.blocks = nn.Repeat(Block(), (1, 3))
# FIXME
# The following use cases have known issues and will be fixed in the current release
# Can be used together with layer choice.
# With deep copy, the 3 layers will have the same label and thus share the choice.
self.blocks = nn.Repeat(nn.LayerChoice([...]), (1, 3))
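# A fixed repetition count is also possible (a minimal sketch, assuming `depth` accepts a plain int)
# Block() will be repeated exactly 3 times
self.blocks = nn.Repeat(Block(), 3)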
@@ -124,4 +123,60 @@ Repeat a block by a variable number of times.
API reference: :class:`nni.retiarii.nn.pytorch.Cell`
This cell structure is popularly used in `NAS literature <https://arxiv.org/abs/1611.01578>`__. Specifically, the cell consists of multiple "nodes". Each node is the sum of multiple operators. Each operator is chosen from user-specified candidates, and takes one input from previous nodes and predecessors. A predecessor is an input of the cell. The output of the cell is the concatenation of some of the nodes in the cell (currently all the nodes).
This cell structure is popularly used in `NAS literature <https://arxiv.org/abs/1611.01578>`__. At a high level, the literature often uses the following terms.
.. list-table::
:widths: 25 75
* - Cell
- A cell consists of several nodes.
* - Node
- A node is the **sum** of several operators.
* - Operator
- Each operator is independently chosen from a list of user-specified candidate operators.
* - Operator's input
- Each operator has one input, chosen from previous nodes as well as predecessors.
* - Predecessors
- Input of cell. A cell can have multiple predecessors. Predecessors are sent to *preprocessor* for preprocessing.
* - Cell's output
- Output of cell. Usually a concatenation of several nodes (possibly all nodes) in the cell. The cell's output, along with the predecessors, is sent to the *postprocessor* for postprocessing.
* - Preprocessor
- Extra preprocessing applied to the predecessors. Usually used for shape alignment (e.g., when predecessors have different shapes). By default, it does nothing.
* - Postprocessor
- Extra postprocessing applied to the cell's output. Usually used to chain cells with multiple predecessors
(e.g., when the next cell wants the outputs of both this cell and the previous cell as its input). By default, it directly uses this cell's output.
Example usages:
.. code-block:: python
# import nni.retiarii.nn.pytorch as nn
# used in `__init__` method
# Choose between conv2d and maxpool2d.
# The cell has 4 nodes, 1 op per node, and 2 predecessors.
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2)
# forward
cell([input1, input2])
# Use `merge_op` to specify how to construct the output.
# The output will then have a dynamic shape, depending on which nodes have been used as inputs within the cell.
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, merge_op='loose_end')
# The op candidates can be callables that accept the node index in the cell, the op index in the node, and the input index.
cell = nn.Cell([
lambda node_index, op_index, input_index: nn.Conv2d(32, 32, 3, stride=2 if input_index < 1 else 1),
...
], 4, 1, 2)
# predecessor example
class Preprocessor(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(16, 32, 1)
        self.conv2 = nn.Conv2d(64, 32, 1)

    def forward(self, x):
        return [self.conv1(x[0]), self.conv2(x[1])]
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, preprocessor=Preprocessor())
cell([torch.randn(1, 16, 48, 48), torch.randn(1, 64, 48, 48)]) # the two inputs will be sent to conv1 and conv2 respectively
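A postprocessor can be plugged in similarly. The following sketch (illustrative; it mirrors the glossary above) forwards both the last predecessor and this cell's output, so that the next cell can take two inputs:

.. code-block:: python

    # postprocessor example
    class Postprocessor(nn.Module):
        def forward(self, this_cell, predecessors):
            # return the previous cell's output together with this cell's output
            return predecessors[-1], this_cell

    cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, postprocessor=Postprocessor())
    prev_out, this_out = cell([input1, input2])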
@@ -2,6 +2,7 @@
# Licensed under the MIT license.
import logging
import re
from typing import Dict, List, Tuple, Any
from nni.retiarii.operation_def.torch_op_def import ToDevice
@@ -38,14 +39,17 @@ def _sorted_incoming_edges(node: Node) -> List[Edge]:
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
def _format_inputs(node: Node, graph_name: str) -> Tuple[List[str], List[Any]]:
"""
Format the inputs of a given node
Format the inputs of a given node.
Inputs will be formatted with ``_format_variable_name``
Parameters
----------
node : Node
a graph node, get and format its inputs
graph_name : str
name of the (sub)graph, used to format variable names
Returns
-------
@@ -63,7 +67,7 @@ def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
assert isinstance(edge.head_slot, int)
if edge.head.operation.io_names is not None:
# when input has names, e.g., forward(self, tensor1, tensor2, another_one)
inputs.append(edge.head.operation.io_names[edge.head_slot])
inputs.append(_format_variable_name(edge.head.operation.io_names[edge.head_slot], graph_name))
else:
# when input has no name, e.g., forward(*_inputs)
inputs.append('_inputs[{}]'.format(edge.head_slot))
@@ -71,7 +75,7 @@ def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
else:
if edge.head_slot is None:
# when the input comes from a single-output operator
inputs.append('{}'.format(edge.head.name))
inputs.append(_format_variable_name(edge.head.name, graph_name))
if edge.head.operation.type in ('prim::Constant', 'prim::GetAttr') and \
'value' in edge.head.operation.parameters:
inputs_value.append(edge.head.operation.parameters['value'])
@@ -79,26 +83,21 @@ def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
inputs_value.append(None)
else:
# when the input comes from a multi-output operator: needs to know which one it comes from
inputs.append('{}[{}]'.format(edge.head.name, edge.head_slot))
inputs.append('{}[{}]'.format(_format_variable_name(edge.head.name, graph_name), edge.head_slot))
inputs_value.append(None)
return inputs, inputs_value
def _remove_prefix(names, graph_name):
def _format_variable_name(name: str, graph_name: str) -> str:
"""
variables name (full name space) is too long,
shorten the name by removing the prefix ```graph_name```
1. Replace invalid characters in the node name.
2. The variable name (the full namespace) is too long; shorten it by removing the ``graph_name`` prefix.
"""
if isinstance(names, list):
converted_names = []
for name in names:
if name.startswith(graph_name):
converted_names.append(name[len(graph_name):])
else:
converted_names.append(name)
return converted_names
else:
return names[len(graph_name):] if names.startswith(graph_name) else names
name = name[len(graph_name):] if name.startswith(graph_name) else name
name = name.replace('/', '__')
# https://stackoverflow.com/questions/3303312/how-do-i-convert-a-string-to-a-valid-variable-name-in-python
return re.sub(r'\W|^(?=\d)', '_', name)
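# Illustrative sketch of the transformation: with graph_name '_model', the name
# '_model/cell/op.1' is first stripped to '/cell/op.1', then '/' is rewritten to
# '__' giving '__cell__op.1', and the remaining invalid character is sanitized, so
#   _format_variable_name('_model/cell/op.1', '_model') == '__cell__op_1'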
def generate_cuda_mapping(placement: Dict[Node, Device]) -> Dict[Device, int]:
@@ -139,7 +138,7 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
pkg_name = node.operation.get_import_pkg()
if pkg_name is not None:
import_pkgs.add(pkg_name)
node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
node_code = node.operation.to_init_code(_format_variable_name(node.name, graph_name))
if node_code is not None:
if placement and node in placement and len(node_code) > 0:
if isinstance(placement[node], GPUDevice):
@@ -161,16 +160,14 @@ def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> str
sorted_nodes = graph.topo_sort()
for node in sorted_nodes:
if node.operation:
inputs, inputs_value = _format_inputs(node)
inputs = _remove_prefix(inputs, graph_name)
node_name = _remove_prefix(node.name, graph_name)
inputs, inputs_value = _format_inputs(node, graph_name)
node_name = _format_variable_name(node.name, graph_name)
submodule_name = node_name
if node.operation.type == 'shared':
submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
submodule_name = _format_variable_name(node.operation.parameters['reference'], graph_name)
edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))
output_names, _ = _format_inputs(graph.output_node)
output_names = _remove_prefix(output_names, graph_name)
output_names, _ = _format_inputs(graph.output_node, graph_name)
if not output_names:
raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
output_code = ', '.join(output_names)
......
@@ -374,6 +374,11 @@ class ChosenInputs(nn.Module):
"""
A module that chooses from a tensor list and outputs a reduced tensor.
The already-chosen version of InputChoice.
Attributes
----------
chosen : list of int
Indices of chosen candidates.
"""
def __init__(self, chosen: Union[List[int], int], reduction: str):
......
import copy
import warnings
from typing import Callable, Dict, List, Union, Optional, Tuple
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
import torch
import torch.nn as nn
from .api import ChosenInputs, LayerChoice, InputChoice
from .nn import ModuleList
from .utils import generate_new_label
class _ListIdentity(nn.Identity):
# workaround for torchscript
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return x
class _DefaultPostprocessor(nn.Module):
# this is also a workaround for torchscript
def forward(self, this_cell: torch.Tensor, prev_cell: List[torch.Tensor]) -> torch.Tensor:
return this_cell
_cell_op_factory_type = Callable[[int, int, Optional[int]], nn.Module]
class Cell(nn.Module):
"""
Cell structure [zophnas]_ [zophnasnet]_ that is popularly used in NAS literature.
[nds]_ is a good summary of how this structure works in practice.
A cell consists of multiple "nodes". Each node is the sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. A predecessor is an input of the cell.
The output of the cell is the concatenation of some of the nodes in the cell (which nodes are selected is controlled by ``merge_op``).
Parameters
----------
op_candidates : list of module or function, or dict
A list of modules to choose from, or a function that accepts the node index, the op index and (optionally) the input index, and returns a module.
For example, ``(2, 3, 0)`` means the op with index 3 in the node with index 2 takes the node with index 0 as its input.
The indices are enumerated over all nodes (including predecessors), starting from 0.
When first created, the input index is ``None``, meaning unknown.
Note that in the graph execution engine, support for functions in ``op_candidates`` is limited.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node : int
Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : "all", or "loose_end"
If "all", all the nodes (except predecessors) will be concatenated as the cell's output, in which case, ``output_node_indices``
will be ``list(range(num_predecessors, num_predecessors + num_nodes))``.
If "loose_end", only the nodes that have never been used as other nodes' inputs will be concatenated to the output.
Predecessors are not considered when calculating unused nodes.
Details can be found in reference [nds]_. Default: all.
preprocessor : callable
Override this if some extra transformation on cell's input is intended.
It should be a callable (``nn.Module`` is also acceptable) that takes a list of tensors which are predecessors,
and outputs a list of tensors, with the same length as input.
By default, it does nothing to the input.
postprocessor : callable
Override this if customization on the output of the cell is intended.
It should be a callable that takes the output of this cell and a list of tensors which are the predecessors.
Its return type should be either one tensor, or a tuple of tensors.
The return value of postprocessor is the return value of the cell's forward.
By default, it returns only the output of the current cell.
label : str
Identifier of the cell. Cells sharing the same label will semantically share the same choice.
Attributes
----------
output_node_indices : list of int
Indices of the nodes concatenated to the output. For example, if the operation following the cell is a 2d convolution,
its number of input channels should be ``len(output_node_indices) * channels``.
Examples
--------
>>> cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2)
>>> output = cell([input1, input2])
References
----------
.. [zophnas] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [zophnasnet] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
.. [nds] Radosavovic, Ilija and Johnson, Justin and Xie, Saining and Lo, Wan-Yen and Dollar, Piotr,
"On Network Design Spaces for Visual Recognition". https://arxiv.org/abs/1905.13214
"""
def __init__(self,
op_candidates: Union[
Callable[[], List[nn.Module]],
List[Union[nn.Module, _cell_op_factory_type]],
Dict[str, Union[nn.Module, _cell_op_factory_type]]
],
num_nodes: int,
num_ops_per_node: int = 1,
num_predecessors: int = 1,
merge_op: Literal['all', 'loose_end'] = 'all',
preprocessor: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None,
postprocessor: Optional[Callable[[torch.Tensor, List[torch.Tensor]],
Union[Tuple[torch.Tensor, ...], torch.Tensor]]] = None,
*,
label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
# modules are created in "natural" order
# first create preprocessor
self.preprocessor = preprocessor or _ListIdentity()
# then create intermediate ops
self.ops = ModuleList()
self.inputs = ModuleList()
# finally postprocessor
self.postprocessor = postprocessor or _DefaultPostprocessor()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
assert merge_op in ['all', 'loose_end']
self.merge_op = merge_op
self.output_node_indices = list(range(num_predecessors, num_predecessors + num_nodes))
# fill-in the missing modules
self._create_modules(op_candidates)
def _create_modules(self, op_candidates):
for i in range(self.num_predecessors, self.num_nodes + self.num_predecessors):
self.ops.append(ModuleList())
self.inputs.append(ModuleList())
for k in range(self.num_ops_per_node):
inp = InputChoice(i, 1, label=f'{self.label}/input_{i}_{k}')
chosen = None
if isinstance(inp, ChosenInputs):
# now we are in the fixed mode
# the length of chosen should be 1
chosen = inp.chosen[0]
if self.merge_op == 'loose_end' and chosen in self.output_node_indices:
# remove it from concat indices
self.output_node_indices.remove(chosen)
# this is needed because op_candidates can be very complex
# see the type annotation and docs for details
ops = self._convert_op_candidates(op_candidates, i, k, chosen)
# though it's a layer choice and an input choice here, in fixed mode, only the chosen module will be created.
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
self.inputs[-1].append(inp)
@property
def label(self):
return self._label
def forward(self, x: List[torch.Tensor]):
# The return type should be 'Union[Tuple[torch.Tensor, ...], torch.Tensor]'.
# Cannot decorate it as annotation. Otherwise torchscript will complain.
assert isinstance(x, list), 'We currently only support input of cell as a list, even if you have only one predecessor.'
states = self.preprocessor(x)
for ops, inps in zip(self.ops, self.inputs):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
if self.merge_op == 'all':
# a special case for graph engine
this_cell = torch.cat(states[self.num_predecessors:], 1)
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], 1)
return self.postprocessor(this_cell, x)
@staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
# convert the complex type into the type that is acceptable to LayerChoice
def convert_single_op(op):
if isinstance(op, nn.Module):
return copy.deepcopy(op)
elif callable(op):
# FIXME: I don't know how to check whether we are in graph engine.
return op(node_index, op_index, chosen)
else:
raise TypeError(f'Unrecognized type {type(op)} for op {op}')
if isinstance(op_candidates, list):
return [convert_single_op(op) for op in op_candidates]
elif isinstance(op_candidates, dict):
return {key: convert_single_op(op) for key, op in op_candidates.items()}
elif callable(op_candidates):
warnings.warn('Directly passing a callable into Cell is deprecated. Please consider migrating to list or dict.',
DeprecationWarning)
return op_candidates()
else:
raise TypeError(f'Unrecognized type {type(op_candidates)} for {op_candidates}')
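# Illustrative usage sketch (mirrors the class docstring; 32-channel inputs are assumed):
#
#   cell = Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)],
#               num_nodes=4, num_ops_per_node=1, num_predecessors=2, merge_op='loose_end')
#   out = cell([x1, x2])  # channels of `out`: len(cell.output_node_indices) * 32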
@@ -5,9 +5,8 @@ from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from .api import LayerChoice, InputChoice
from .nn import ModuleList
from .api import LayerChoice
from .cell import Cell
from .nasbench101 import NasBench101Cell, NasBench101Mutator
from .utils import Mutable, generate_new_label, get_fixed_value
@@ -78,83 +77,6 @@ class Repeat(Mutable):
return blocks
class Cell(nn.Module):
"""
Cell structure [zophnas]_ [zophnasnet]_ that is popularly used in NAS literature.
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
``op_candidates``, and takes one input from previous nodes and predecessors. Predecessor means the input of cell.
The output of cell is the concatenation of some of the nodes in the cell (currently all the nodes).
Parameters
----------
op_candidates : function or list of module
A list of modules to choose from, or a function that returns a list of modules.
num_nodes : int
Number of nodes in the cell.
num_ops_per_node: int
Number of operators in each node. The output of each node is the sum of all operators in the node. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : str
Currently only ``all`` is supported, which differs slightly from what is described in the reference. Default: all.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
References
----------
.. [zophnas] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [zophnasnet] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""
# TODO:
# Support loose end concat (shape inference on the following cells)
# How to dynamically create convolution with stride as the first node
def __init__(self,
op_candidates: Union[Callable, List[nn.Module]],
num_nodes: int,
num_ops_per_node: int = 1,
num_predecessors: int = 1,
merge_op: str = 'all',
label: str = None):
super().__init__()
self._label = generate_new_label(label)
self.ops = ModuleList()
self.inputs = ModuleList()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
for i in range(num_nodes):
self.ops.append(ModuleList())
self.inputs.append(ModuleList())
for k in range(num_ops_per_node):
if isinstance(op_candidates, list):
assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module)
ops = copy.deepcopy(op_candidates)
else:
ops = op_candidates()
self.ops[-1].append(LayerChoice(ops, label=f'{self.label}__op_{i}_{k}'))
self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}'))
assert merge_op in ['all'] # TODO: loose_end
self.merge_op = merge_op
@property
def label(self):
return self._label
def forward(self, x: List[torch.Tensor]):
states = x
for ops, inps in zip(self.ops, self.inputs):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
return torch.cat(states[self.num_predecessors:], 1)
class NasBench201Cell(nn.Module):
"""
Cell structure that is proposed in NAS-Bench-201 [nasbench201]_ .
......
@@ -124,6 +124,13 @@ class PrimGetAttr(PyTorchOperation):
return f"{output} = {self.parameters['input']}.{self.parameters['name']}"
class PrimUncheckedCast(PyTorchOperation):
_ori_type_name = ['prim::unchecked_cast']
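    # prim::unchecked_cast has no runtime effect, so the generated forward code simply aliases the input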
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str:
return f'{output} = {inputs[0]}'
class SimpleMember(PyTorchOperation):
_ori_type_name = ['prim::is_cuda', 'prim::data']
......
@@ -565,6 +565,45 @@ class GraphIR(unittest.TestCase):
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
from typing import List, Tuple
class Preprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class Postprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
result = self._get_converted_pytorch_model(model)(
torch.randn(1, 3), torch.randn(1, 16))
self.assertTrue(result[0].size() == torch.Size([1, 16]))
self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_nasbench201_cell(self):
@model_wrapper
class Net(nn.Module):
@@ -627,6 +666,54 @@ class Python(GraphIR):
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
def test_cell_loose_end(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='loose_end')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
any_not_all = False
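        # with merge_op='loose_end', at least one sampled architecture is expected to drop some nodes from the output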
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
model = self._get_converted_pytorch_model(model)
indices = model.cell.output_node_indices
assert all(i > 2 for i in indices)
self.assertTrue(model(torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 16 * len(indices)]))
if len(indices) < 4:
any_not_all = True
self.assertTrue(any_not_all)
def test_cell_complex(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16),
'second': lambda _, __, chosen: nn.Linear(3 if chosen == 0 else 16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 3), torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_nasbench101_cell(self):
# this is only supported in python engine for now.
@model_wrapper
......