Unverified Commit 055b42c0 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

More improvements on NAS documentation and cell interface (#4752)

parent 2d8f925b
......@@ -20,7 +20,7 @@ class Trainer(pl.Trainer):
default: False
trainer_kwargs : dict
Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details.
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
"""
def __init__(self, use_cgo=False, **trainer_kwargs):
......
......@@ -50,6 +50,9 @@ class Cell(nn.Module):
Two examples of searched cells are illustrated in the figure below.
In these two cells, ``op_candidates`` are series of convolutions and pooling operations.
``num_nodes_per_node`` is set to 2. ``num_nodes`` is set to 5. ``merge_op`` is ``loose_end``.
Assuming nodes are enumerated from bottom to top, left to right,
``output_node_indices`` for the normal cell is ``[2, 3, 4, 5, 6]``.
For the reduction cell, it's ``[4, 5, 6]``.
Please take a look at this
`review article <https://sh-tsang.medium.com/review-nasnet-neural-architecture-search-network-image-classification-23139ea0425d>`__
if you are interested in details.
......@@ -106,7 +109,7 @@ class Cell(nn.Module):
will be ``list(range(num_predecessors, num_predecessors + num_nodes))``.
If "loose_end", only the nodes that have never been used as other nodes' inputs will be concatenated to the output.
Predecessors are not considered when calculating unused nodes.
Details can be found in reference [nds]. Default: all.
Details can be found in `NDS paper <https://arxiv.org/abs/1905.13214>`__. Default: all.
preprocessor : callable
Override this if some extra transformation on cell's input is intended.
It should be a callable (``nn.Module`` is also acceptable) that takes a list of tensors which are predecessors,
......@@ -118,15 +121,11 @@ class Cell(nn.Module):
Its return type should be either one tensor, or a tuple of tensors.
The return value of postprocessor is the return value of the cell's forward.
By default, it returns only the output of the current cell.
concat_dim : int
The result will be a concatenation of several nodes on this dim. Default: 1.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Attributes
----------
output_node_indices : list of int
Indices of the nodes concatenated to the output. For example, if the following operation is a 2d-convolution,
its input channels is ``len(output_node_indices) * channels``.
Examples
--------
Choose between conv2d and maxpool2d.
......@@ -138,10 +137,16 @@ class Cell(nn.Module):
>>> cell([input1, input2])
The "list bracket" can be omitted:
>>> cell(only_input) # only one input
>>> cell(tensor1, tensor2, tensor3) # multiple inputs
Use ``merge_op`` to specify how to construct the output.
The output will then have dynamic shape, depending on which input has been used in the cell.
>>> cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, merge_op='loose_end')
>>> cell_out_channels = len(cell.output_node_indices) * 32
The op candidates can be callable that accepts node index in cell, op index in node, and input index.
......@@ -161,6 +166,21 @@ class Cell(nn.Module):
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, preprocessor=Preprocessor())
cell([torch.randn(1, 16, 48, 48), torch.randn(1, 64, 48, 48)]) # the two inputs will be sent to conv1 and conv2 respectively
Warnings
--------
:class:`Cell` is not supported in :ref:`graph-based execution engine <graph-based-exeuction-engine>`.
Attributes
----------
output_node_indices : list of int
An attribute that contains indices of the nodes concatenated to the output (a list of integers).
When the cell is first instantiated in the base model, or when ``merge_op`` is ``all``,
``output_node_indices`` must be ``range(num_predecessors, num_predecessors + num_nodes)``.
When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
"""
def __init__(self,
......@@ -176,6 +196,7 @@ class Cell(nn.Module):
preprocessor: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None,
postprocessor: Optional[Callable[[torch.Tensor, List[torch.Tensor]],
Union[Tuple[torch.Tensor, ...], torch.Tensor]]] = None,
concat_dim: int = 1,
*,
label: Optional[str] = None):
super().__init__()
......@@ -197,6 +218,8 @@ class Cell(nn.Module):
self.merge_op = merge_op
self.output_node_indices = list(range(num_predecessors, num_predecessors + num_nodes))
self.concat_dim = concat_dim
# fill-in the missing modules
self._create_modules(op_candidates)
......@@ -228,11 +251,28 @@ class Cell(nn.Module):
def label(self):
return self._label
def forward(self, x: List[torch.Tensor]):
# The return type should be 'Union[Tuple[torch.Tensor, ...], torch.Tensor]'.
# Cannot decorate it as annotation. Otherwise torchscript will complain.
assert isinstance(x, list), 'We currently only support input of cell as a list, even if you have only one predecessor.'
states = self.preprocessor(x)
def forward(self, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
"""Forward propagation of cell.
Parameters
----------
inputs
Can be a list of tensors, or several tensors.
The length should be equal to ``num_predecessors``.
Returns
-------
Tuple[torch.Tensor] | torch.Tensor
The return type depends on the output of ``postprocessor``.
By default, it's the output of ``merge_op``, which is a contenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell.
"""
if len(inputs) == 1 and isinstance(inputs[0], list):
inputs = inputs[0]
else:
inputs = list(inputs)
assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states = self.preprocessor(inputs)
for ops, inps in zip(self.ops, self.inputs):
current_state = []
for op, inp in zip(ops, inps):
......@@ -241,10 +281,10 @@ class Cell(nn.Module):
states.append(current_state)
if self.merge_op == 'all':
# a special case for graph engine
this_cell = torch.cat(states[self.num_predecessors:], 1)
this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], 1)
return self.postprocessor(this_cell, x)
this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, inputs)
@staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
......
......@@ -277,6 +277,10 @@ class NasBench101Cell(Mutable):
Maximum number of edges in the cell. Default: 9.
label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Warnings
--------
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-exeuction-engine>`.
"""
@staticmethod
......
......@@ -776,82 +776,6 @@ class GraphIR(unittest.TestCase):
b = model_new(inp)
self.assertLess((a - b).abs().max().item(), 1E-4)
def test_cell(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
@model_wrapper
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell([x])
raw_model, mutators = self._get_model_with_mutators(Net2())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
from typing import List, Tuple
class Preprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class Postprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
result = self._get_converted_pytorch_model(model)(
torch.randn(1, 3), torch.randn(1, 16))
self.assertTrue(result[0].size() == torch.Size([1, 16]))
self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_nasbench201_cell(self):
@model_wrapper
class Net(nn.Module):
......@@ -978,6 +902,82 @@ class Python(GraphIR):
with self.assertRaises(NoContextError):
model = Net()
def test_cell(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell(x, y)
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
@model_wrapper
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell(x)
raw_model, mutators = self._get_model_with_mutators(Net2())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
from typing import List, Tuple
class Preprocessor(nn.Module):
def __init__(self):
super().__init__()
self.linear = nn.Linear(3, 16)
def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
return [self.linear(x[0]), x[1]]
class Postprocessor(nn.Module):
def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
return prev[-1], this
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell({
'first': nn.Linear(16, 16),
'second': nn.Linear(16, 16, bias=False)
}, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
result = self._get_converted_pytorch_model(model)(
torch.randn(1, 3), torch.randn(1, 16))
self.assertTrue(result[0].size() == torch.Size([1, 16]))
self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_cell_loose_end(self):
@model_wrapper
class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment