Unverified Commit 055b42c0 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

More improvements on NAS documentation and cell interface (#4752)

parent 2d8f925b
...@@ -20,7 +20,7 @@ class Trainer(pl.Trainer): ...@@ -20,7 +20,7 @@ class Trainer(pl.Trainer):
default: False default: False
trainer_kwargs : dict trainer_kwargs : dict
Optional keyword arguments passed to trainer. See Optional keyword arguments passed to trainer. See
`Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/trainer.html>`__ for details. `Lightning documentation <https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html>`__ for details.
""" """
def __init__(self, use_cgo=False, **trainer_kwargs): def __init__(self, use_cgo=False, **trainer_kwargs):
......
...@@ -50,6 +50,9 @@ class Cell(nn.Module): ...@@ -50,6 +50,9 @@ class Cell(nn.Module):
Two examples of searched cells are illustrated in the figure below. Two examples of searched cells are illustrated in the figure below.
In these two cells, ``op_candidates`` are series of convolutions and pooling operations. In these two cells, ``op_candidates`` are series of convolutions and pooling operations.
``num_ops_per_node`` is set to 2. ``num_nodes`` is set to 5. ``merge_op`` is ``loose_end``. ``num_ops_per_node`` is set to 2. ``num_nodes`` is set to 5. ``merge_op`` is ``loose_end``.
Assuming nodes are enumerated from bottom to top, left to right,
``output_node_indices`` for the normal cell is ``[2, 3, 4, 5, 6]``.
For the reduction cell, it's ``[4, 5, 6]``.
Please take a look at this Please take a look at this
`review article <https://sh-tsang.medium.com/review-nasnet-neural-architecture-search-network-image-classification-23139ea0425d>`__ `review article <https://sh-tsang.medium.com/review-nasnet-neural-architecture-search-network-image-classification-23139ea0425d>`__
if you are interested in details. if you are interested in details.
...@@ -106,7 +109,7 @@ class Cell(nn.Module): ...@@ -106,7 +109,7 @@ class Cell(nn.Module):
will be ``list(range(num_predecessors, num_predecessors + num_nodes))``. will be ``list(range(num_predecessors, num_predecessors + num_nodes))``.
If "loose_end", only the nodes that have never been used as other nodes' inputs will be concatenated to the output. If "loose_end", only the nodes that have never been used as other nodes' inputs will be concatenated to the output.
Predecessors are not considered when calculating unused nodes. Predecessors are not considered when calculating unused nodes.
Details can be found in reference [nds]. Default: all. Details can be found in `NDS paper <https://arxiv.org/abs/1905.13214>`__. Default: all.
preprocessor : callable preprocessor : callable
Override this if some extra transformation on cell's input is intended. Override this if some extra transformation on cell's input is intended.
It should be a callable (``nn.Module`` is also acceptable) that takes a list of tensors which are predecessors, It should be a callable (``nn.Module`` is also acceptable) that takes a list of tensors which are predecessors,
...@@ -118,15 +121,11 @@ class Cell(nn.Module): ...@@ -118,15 +121,11 @@ class Cell(nn.Module):
Its return type should be either one tensor, or a tuple of tensors. Its return type should be either one tensor, or a tuple of tensors.
The return value of postprocessor is the return value of the cell's forward. The return value of postprocessor is the return value of the cell's forward.
By default, it returns only the output of the current cell. By default, it returns only the output of the current cell.
concat_dim : int
The result will be a concatenation of several nodes on this dim. Default: 1.
label : str label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice. Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Attributes
----------
output_node_indices : list of int
Indices of the nodes concatenated to the output. For example, if the following operation is a 2d-convolution,
its input channels is ``len(output_node_indices) * channels``.
Examples Examples
-------- --------
Choose between conv2d and maxpool2d. Choose between conv2d and maxpool2d.
...@@ -138,10 +137,16 @@ class Cell(nn.Module): ...@@ -138,10 +137,16 @@ class Cell(nn.Module):
>>> cell([input1, input2]) >>> cell([input1, input2])
The "list bracket" can be omitted:
>>> cell(only_input) # only one input
>>> cell(tensor1, tensor2, tensor3) # multiple inputs
Use ``merge_op`` to specify how to construct the output. Use ``merge_op`` to specify how to construct the output.
The output will then have dynamic shape, depending on which input has been used in the cell. The output will then have dynamic shape, depending on which input has been used in the cell.
>>> cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, merge_op='loose_end') >>> cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, merge_op='loose_end')
>>> cell_out_channels = len(cell.output_node_indices) * 32
The op candidates can be callable that accepts node index in cell, op index in node, and input index. The op candidates can be callable that accepts node index in cell, op index in node, and input index.
...@@ -161,6 +166,21 @@ class Cell(nn.Module): ...@@ -161,6 +166,21 @@ class Cell(nn.Module):
cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, preprocessor=Preprocessor()) cell = nn.Cell([nn.Conv2d(32, 32, 3), nn.MaxPool2d(3)], 4, 1, 2, preprocessor=Preprocessor())
cell([torch.randn(1, 16, 48, 48), torch.randn(1, 64, 48, 48)]) # the two inputs will be sent to conv1 and conv2 respectively cell([torch.randn(1, 16, 48, 48), torch.randn(1, 64, 48, 48)]) # the two inputs will be sent to conv1 and conv2 respectively
Warnings
--------
:class:`Cell` is not supported in :ref:`graph-based execution engine <graph-based-exeuction-engine>`.
Attributes
----------
output_node_indices : list of int
An attribute that contains indices of the nodes concatenated to the output (a list of integers).
When the cell is first instantiated in the base model, or when ``merge_op`` is ``all``,
``output_node_indices`` must be ``range(num_predecessors, num_predecessors + num_nodes)``.
When ``merge_op`` is ``loose_end``, ``output_node_indices`` is useful to compute the shape of this cell's output,
because the output shape depends on the connection in the cell, and which nodes are "loose ends" depends on mutation.
""" """
def __init__(self, def __init__(self,
...@@ -176,6 +196,7 @@ class Cell(nn.Module): ...@@ -176,6 +196,7 @@ class Cell(nn.Module):
preprocessor: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None, preprocessor: Optional[Callable[[List[torch.Tensor]], List[torch.Tensor]]] = None,
postprocessor: Optional[Callable[[torch.Tensor, List[torch.Tensor]], postprocessor: Optional[Callable[[torch.Tensor, List[torch.Tensor]],
Union[Tuple[torch.Tensor, ...], torch.Tensor]]] = None, Union[Tuple[torch.Tensor, ...], torch.Tensor]]] = None,
concat_dim: int = 1,
*, *,
label: Optional[str] = None): label: Optional[str] = None):
super().__init__() super().__init__()
...@@ -197,6 +218,8 @@ class Cell(nn.Module): ...@@ -197,6 +218,8 @@ class Cell(nn.Module):
self.merge_op = merge_op self.merge_op = merge_op
self.output_node_indices = list(range(num_predecessors, num_predecessors + num_nodes)) self.output_node_indices = list(range(num_predecessors, num_predecessors + num_nodes))
self.concat_dim = concat_dim
# fill-in the missing modules # fill-in the missing modules
self._create_modules(op_candidates) self._create_modules(op_candidates)
...@@ -228,11 +251,28 @@ class Cell(nn.Module): ...@@ -228,11 +251,28 @@ class Cell(nn.Module):
def label(self): def label(self):
return self._label return self._label
def forward(self, x: List[torch.Tensor]): def forward(self, *inputs: Union[List[torch.Tensor], torch.Tensor]) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# The return type should be 'Union[Tuple[torch.Tensor, ...], torch.Tensor]'. """Forward propagation of cell.
# Cannot decorate it as annotation. Otherwise torchscript will complain.
assert isinstance(x, list), 'We currently only support input of cell as a list, even if you have only one predecessor.' Parameters
states = self.preprocessor(x) ----------
inputs
Can be a list of tensors, or several tensors.
The length should be equal to ``num_predecessors``.
Returns
-------
Tuple[torch.Tensor] | torch.Tensor
The return type depends on the output of ``postprocessor``.
By default, it's the output of ``merge_op``, which is a concatenation (on ``concat_dim``)
of some of (possibly all) the nodes' outputs in the cell.
"""
if len(inputs) == 1 and isinstance(inputs[0], list):
inputs = inputs[0]
else:
inputs = list(inputs)
assert len(inputs) == self.num_predecessors, 'The number of inputs must be equal to `num_predecessors`.'
states = self.preprocessor(inputs)
for ops, inps in zip(self.ops, self.inputs): for ops, inps in zip(self.ops, self.inputs):
current_state = [] current_state = []
for op, inp in zip(ops, inps): for op, inp in zip(ops, inps):
...@@ -241,10 +281,10 @@ class Cell(nn.Module): ...@@ -241,10 +281,10 @@ class Cell(nn.Module):
states.append(current_state) states.append(current_state)
if self.merge_op == 'all': if self.merge_op == 'all':
# a special case for graph engine # a special case for graph engine
this_cell = torch.cat(states[self.num_predecessors:], 1) this_cell = torch.cat(states[self.num_predecessors:], self.concat_dim)
else: else:
this_cell = torch.cat([states[k] for k in self.output_node_indices], 1) this_cell = torch.cat([states[k] for k in self.output_node_indices], self.concat_dim)
return self.postprocessor(this_cell, x) return self.postprocessor(this_cell, inputs)
@staticmethod @staticmethod
def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]: def _convert_op_candidates(op_candidates, node_index, op_index, chosen) -> Union[Dict[str, nn.Module], List[nn.Module]]:
......
...@@ -277,6 +277,10 @@ class NasBench101Cell(Mutable): ...@@ -277,6 +277,10 @@ class NasBench101Cell(Mutable):
Maximum number of edges in the cell. Default: 9. Maximum number of edges in the cell. Default: 9.
label : str label : str
Identifier of the cell. Cell sharing the same label will semantically share the same choice. Identifier of the cell. Cell sharing the same label will semantically share the same choice.
Warnings
--------
:class:`NasBench101Cell` is not supported in :ref:`graph-based execution engine <graph-based-exeuction-engine>`.
""" """
@staticmethod @staticmethod
......
...@@ -776,82 +776,6 @@ class GraphIR(unittest.TestCase): ...@@ -776,82 +776,6 @@ class GraphIR(unittest.TestCase):
b = model_new(inp) b = model_new(inp)
self.assertLess((a - b).abs().max().item(), 1E-4) self.assertLess((a - b).abs().max().item(), 1E-4)
def test_cell(self):
    """Convert a model containing ``nn.Cell`` and check the output width.

    With ``merge_op='all'`` every node's output is concatenated, so the
    output feature width is ``num_nodes * 16`` (4 * 16 = 64) in both cases
    below.  Predecessors are passed to the cell as a list.
    """
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Two op candidates, 4 nodes, 2 ops per node, 2 predecessors.
            self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
                                num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')

        def forward(self, x, y):
            # List-style call: both predecessors wrapped in one list.
            return self.cell([x, y])

    raw_model, mutators = self._get_model_with_mutators(Net())
    for _ in range(10):
        # Fresh sampler each round; mutators are applied in order on top of
        # the raw model to materialize one sampled architecture.
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        # merge_op='all' concatenates all 4 nodes: 4 * 16 = 64 features.
        self.assertTrue(self._get_converted_pytorch_model(model)(
            torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))

    @model_wrapper
    class Net2(nn.Module):
        def __init__(self):
            super().__init__()
            # Only one predecessor is used here (relies on the default
            # ``num_predecessors`` — presumably 1; confirm against Cell's signature).
            self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)

        def forward(self, x):
            return self.cell([x])

    raw_model, mutators = self._get_model_with_mutators(Net2())
    for _ in range(10):
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
    """Convert a model using ``nn.Cell`` with custom pre/postprocessors.

    The preprocessor projects the first predecessor from 3 to 16 features
    so both predecessors match; the postprocessor changes the return value
    to the tuple ``(last_predecessor, cell_output)``.
    """
    from typing import List, Tuple

    class Preprocessor(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(3, 16)

        def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
            # Align the first predecessor's width with the second (3 -> 16).
            return [self.linear(x[0]), x[1]]

    class Postprocessor(nn.Module):
        def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            # Return the last raw predecessor alongside the cell's own output.
            return prev[-1], this

    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Op candidates given as a dict (name -> module) instead of a list.
            self.cell = nn.Cell({
                'first': nn.Linear(16, 16),
                'second': nn.Linear(16, 16, bias=False)
            }, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
            preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')

        def forward(self, x, y):
            return self.cell([x, y])

    raw_model, mutators = self._get_model_with_mutators(Net())
    for _ in range(10):
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        result = self._get_converted_pytorch_model(model)(
            torch.randn(1, 3), torch.randn(1, 16))
        # result[0] is the second input, returned untouched by the postprocessor;
        # result[1] is the 'all'-merge concatenation: 4 nodes * 16 = 64 features.
        self.assertTrue(result[0].size() == torch.Size([1, 16]))
        self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_nasbench201_cell(self): def test_nasbench201_cell(self):
@model_wrapper @model_wrapper
class Net(nn.Module): class Net(nn.Module):
...@@ -978,6 +902,82 @@ class Python(GraphIR): ...@@ -978,6 +902,82 @@ class Python(GraphIR):
with self.assertRaises(NoContextError): with self.assertRaises(NoContextError):
model = Net() model = Net()
def test_cell(self):
    """Convert a model containing ``nn.Cell`` and check the output width.

    Exercises the tensor-argument calling convention (``cell(x, y)`` and
    ``cell(x)``, without the list bracket).  With ``merge_op='all'`` every
    node's output is concatenated, so the output feature width is
    ``num_nodes * 16`` (4 * 16 = 64) in both cases below.
    """
    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Two op candidates, 4 nodes, 2 ops per node, 2 predecessors.
            self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
                                num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')

        def forward(self, x, y):
            # Predecessors passed as separate tensors, not wrapped in a list.
            return self.cell(x, y)

    raw_model, mutators = self._get_model_with_mutators(Net())
    for _ in range(10):
        # Fresh sampler each round; mutators are applied in order on top of
        # the raw model to materialize one sampled architecture.
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        # merge_op='all' concatenates all 4 nodes: 4 * 16 = 64 features.
        self.assertTrue(self._get_converted_pytorch_model(model)(
            torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))

    @model_wrapper
    class Net2(nn.Module):
        def __init__(self):
            super().__init__()
            # Only one predecessor is used here (relies on the default
            # ``num_predecessors`` — presumably 1; confirm against Cell's signature).
            self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)

        def forward(self, x):
            # Single tensor without the list bracket.
            return self.cell(x)

    raw_model, mutators = self._get_model_with_mutators(Net2())
    for _ in range(10):
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
def test_cell_predecessors(self):
    """Convert a model using ``nn.Cell`` with custom pre/postprocessors.

    The preprocessor projects the first predecessor from 3 to 16 features
    so both predecessors match; the postprocessor changes the return value
    to the tuple ``(last_predecessor, cell_output)``.
    """
    from typing import List, Tuple

    class Preprocessor(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(3, 16)

        def forward(self, x: List[torch.Tensor]) -> List[torch.Tensor]:
            # Align the first predecessor's width with the second (3 -> 16).
            return [self.linear(x[0]), x[1]]

    class Postprocessor(nn.Module):
        def forward(self, this: torch.Tensor, prev: List[torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
            # Return the last raw predecessor alongside the cell's own output.
            return prev[-1], this

    @model_wrapper
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Op candidates given as a dict (name -> module) instead of a list.
            self.cell = nn.Cell({
                'first': nn.Linear(16, 16),
                'second': nn.Linear(16, 16, bias=False)
            }, num_nodes=4, num_ops_per_node=2, num_predecessors=2,
            preprocessor=Preprocessor(), postprocessor=Postprocessor(), merge_op='all')

        def forward(self, x, y):
            return self.cell([x, y])

    raw_model, mutators = self._get_model_with_mutators(Net())
    for _ in range(10):
        sampler = EnumerateSampler()
        model = raw_model
        for mutator in mutators:
            model = mutator.bind_sampler(sampler).apply(model)
        result = self._get_converted_pytorch_model(model)(
            torch.randn(1, 3), torch.randn(1, 16))
        # result[0] is the second input, returned untouched by the postprocessor;
        # result[1] is the 'all'-merge concatenation: 4 nodes * 16 = 64 features.
        self.assertTrue(result[0].size() == torch.Size([1, 16]))
        self.assertTrue(result[1].size() == torch.Size([1, 64]))
def test_cell_loose_end(self): def test_cell_loose_end(self):
@model_wrapper @model_wrapper
class Net(nn.Module): class Net(nn.Module):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment