Unverified Commit 5df75c33 authored by Yuge Zhang, committed by GitHub

[Retiarii] New API: Repeat and Cell (#3481)


Co-authored-by: quzha <Quanlu.Zhang@microsoft.com>
parent 2f3c3951
......@@ -18,6 +18,12 @@ Inline Mutation APIs
.. autoclass:: nni.retiarii.nn.pytorch.ChosenInputs
:members:
.. autoclass:: nni.retiarii.nn.pytorch.Repeat
:members:
.. autoclass:: nni.retiarii.nn.pytorch.Cell
:members:
Graph Mutation APIs
-------------------
......
......@@ -642,6 +642,16 @@ class GraphConverter:
ir_graph._register()
# add mutation signal for special modules
if original_type_name == OpTypeName.Repeat:
attrs = {
'mutation': 'repeat',
'label': module.label,
'min_depth': module.min_depth,
'max_depth': module.max_depth
}
return ir_graph, attrs
return ir_graph, {}
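        # Example (illustrative, not part of this diff): for a module created as
        # ``Repeat(block, (3, 5), label='rep')``, the attrs returned above would be
        # {'mutation': 'repeat', 'label': 'rep', 'min_depth': 3, 'max_depth': 5},
        # which ``RepeatMutator`` later reads to pick the sampled depth.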
......
......@@ -17,3 +17,5 @@ class OpTypeName(str, Enum):
ValueChoice = 'ValueChoice'
Placeholder = 'Placeholder'
MergedSlice = 'MergedSlice'
Repeat = 'Repeat'
Cell = 'Cell'
from .api import *
from .component import *
from .nn import *
......@@ -10,26 +10,12 @@ import torch
import torch.nn as nn
from ...serializer import Translatable, basic_unit
from ...utils import uid, get_current_context
from .utils import generate_new_label, get_fixed_value
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
def _generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label
def _get_fixed_value(label: str):
ret = get_current_context('fixed')
try:
return ret[_generate_new_label(label)]
except KeyError:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
class LayerChoice(nn.Module):
"""
    Layer choice selects one of the ``candidates``, then applies it to inputs and returns the results.
......@@ -69,9 +55,9 @@ class LayerChoice(nn.Module):
``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
"""
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
def __new__(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
try:
chosen = _get_fixed_value(label)
chosen = get_fixed_value(label)
if isinstance(candidates, list):
return candidates[int(chosen)]
else:
......@@ -79,7 +65,7 @@ class LayerChoice(nn.Module):
except AssertionError:
return super().__new__(cls)
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: str = None, **kwargs):
def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], label: Optional[str] = None, **kwargs):
super(LayerChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
......@@ -89,7 +75,7 @@ class LayerChoice(nn.Module):
if 'reduction' in kwargs:
warnings.warn(f'"reduction" is deprecated. Ignoring...')
self.candidates = candidates
self._label = _generate_new_label(label)
self._label = generate_new_label(label)
self.names = []
if isinstance(candidates, OrderedDict):
......@@ -187,13 +173,13 @@ class InputChoice(nn.Module):
Identifier of the input choice.
"""
def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
def __new__(cls, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
try:
return ChosenInputs(_get_fixed_value(label), reduction=reduction)
return ChosenInputs(get_fixed_value(label), reduction=reduction)
except AssertionError:
return super().__new__(cls)
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: str = None, **kwargs):
def __init__(self, n_candidates: int, n_chosen: int = 1, reduction: str = 'sum', label: Optional[str] = None, **kwargs):
super(InputChoice, self).__init__()
if 'key' in kwargs:
warnings.warn(f'"key" is deprecated. Assuming label.')
......@@ -206,7 +192,7 @@ class InputChoice(nn.Module):
self.n_chosen = n_chosen
self.reduction = reduction
assert self.reduction in ['mean', 'concat', 'sum', 'none']
self._label = _generate_new_label(label)
self._label = generate_new_label(label)
@property
def key(self):
......@@ -295,16 +281,16 @@ class ValueChoice(Translatable, nn.Module):
Identifier of the value choice.
"""
def __new__(cls, candidates: List[Any], label: str = None):
def __new__(cls, candidates: List[Any], label: Optional[str] = None):
try:
return _get_fixed_value(label)
return get_fixed_value(label)
except AssertionError:
return super().__new__(cls)
def __init__(self, candidates: List[Any], label: str = None):
def __init__(self, candidates: List[Any], label: Optional[str] = None):
super().__init__()
self.candidates = candidates
self._label = _generate_new_label(label)
self._label = generate_new_label(label)
self._accessor = []
@property
......
import copy
from typing import Callable, List, Union, Tuple, Optional
import torch
import torch.nn as nn
from .api import LayerChoice, InputChoice
from .nn import ModuleList
from .utils import generate_new_label, get_fixed_value
__all__ = ['Repeat', 'Cell']
class Repeat(nn.Module):
"""
Repeat a block by a variable number of times.
Parameters
----------
blocks : function, list of function, module or list of module
        The block to be repeated. If not a list, it will be replicated into a list.
        If a list, it should be of length ``max_depth``; the modules will be instantiated in order and a prefix of them will be used.
        If a function, it will be called to instantiate a module; otherwise, the module will be deep-copied.
depth : int or tuple of int
        If an integer, the block will be repeated a fixed number of times. If a tuple, it should be ``(min, max)``,
        meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
    label : str
        Identifier of the repeat. Repeats sharing the same label will semantically share the same choice.
    """
def __new__(cls, blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
try:
repeat = get_fixed_value(label)
return nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
except AssertionError:
return super().__new__(cls)
def __init__(self,
blocks: Union[Callable[[], nn.Module], List[Callable[[], nn.Module]], nn.Module, List[nn.Module]],
depth: Union[int, Tuple[int, int]], label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.min_depth = depth if isinstance(depth, int) else depth[0]
self.max_depth = depth if isinstance(depth, int) else depth[1]
assert self.max_depth >= self.min_depth > 0
self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))
@property
def label(self):
return self._label
def forward(self, x):
for block in self.blocks:
x = block(x)
return x
@staticmethod
def _replicate_and_instantiate(blocks, repeat):
if not isinstance(blocks, list):
if isinstance(blocks, nn.Module):
blocks = [blocks] + [copy.deepcopy(blocks) for _ in range(repeat - 1)]
else:
blocks = [blocks for _ in range(repeat)]
assert len(blocks) > 0
assert repeat <= len(blocks), f'Not enough blocks to be used. {repeat} expected, only found {len(blocks)}.'
blocks = blocks[:repeat]
if not isinstance(blocks[0], nn.Module):
blocks = [b() for b in blocks]
return blocks
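# Usage sketch (hypothetical class, illustrative only, not part of this diff):
# a search space whose depth is searched in [3, 5]. A factory callable
# re-instantiates the block for every repetition, so the repeats do not share weights.
class _RepeatExample(nn.Module):
    def __init__(self):
        super().__init__()
        # depth is a search decision between 3 and 5 (inclusive)
        self.blocks = Repeat(lambda: nn.Linear(16, 16), (3, 5))

    def forward(self, x):
        return self.blocks(x)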
class Cell(nn.Module):
"""
Cell structure [1]_ [2]_ that is popularly used in NAS literature.
A cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from
    ``op_candidates``, and takes one input from previous nodes or predecessors. A predecessor is an input of the cell.
    The output of the cell is the concatenation of some of the nodes (currently, all of them).
Parameters
----------
op_candidates : function or list of module
A list of modules to choose from, or a function that returns a list of modules.
num_nodes : int
Number of nodes in the cell.
    num_ops_per_node : int
        Number of operators in each node. The output of each node is the sum of the outputs of all its operators. Default: 1.
num_predecessors : int
Number of inputs of the cell. The input to forward should be a list of tensors. Default: 1.
merge_op : str
        Currently only ``all`` is supported, which differs slightly from what is described in the references. Default: all.
label : str
        Identifier of the cell. Cells sharing the same label will semantically share the same choices.
References
----------
.. [1] Barret Zoph, Quoc V. Le, "Neural Architecture Search with Reinforcement Learning". https://arxiv.org/abs/1611.01578
.. [2] Barret Zoph, Vijay Vasudevan, Jonathon Shlens, Quoc V. Le,
"Learning Transferable Architectures for Scalable Image Recognition". https://arxiv.org/abs/1707.07012
"""
# TODO:
# Support loose end concat (shape inference on the following cells)
# How to dynamically create convolution with stride as the first node
def __init__(self,
op_candidates: Union[Callable, List[nn.Module]],
num_nodes: int,
num_ops_per_node: int = 1,
num_predecessors: int = 1,
merge_op: str = 'all',
                 label: Optional[str] = None):
super().__init__()
self._label = generate_new_label(label)
self.ops = ModuleList()
self.inputs = ModuleList()
self.num_nodes = num_nodes
self.num_ops_per_node = num_ops_per_node
self.num_predecessors = num_predecessors
for i in range(num_nodes):
self.ops.append(ModuleList())
self.inputs.append(ModuleList())
for k in range(num_ops_per_node):
if isinstance(op_candidates, list):
assert len(op_candidates) > 0 and isinstance(op_candidates[0], nn.Module)
ops = copy.deepcopy(op_candidates)
else:
ops = op_candidates()
                self.ops[-1].append(LayerChoice(ops, label=f'{self.label}/op_{i}_{k}'))
self.inputs[-1].append(InputChoice(i + num_predecessors, 1, label=f'{self.label}/input_{i}_{k}'))
assert merge_op in ['all'] # TODO: loose_end
self.merge_op = merge_op
@property
def label(self):
return self._label
def forward(self, x: List[torch.Tensor]):
states = x
for ops, inps in zip(self.ops, self.inputs):
current_state = []
for op, inp in zip(ops, inps):
current_state.append(op(inp(states)))
current_state = torch.sum(torch.stack(current_state), 0)
states.append(current_state)
return torch.cat(states[self.num_predecessors:], 1)
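# Usage sketch (hypothetical class, illustrative only, not part of this diff):
# a cell with 4 nodes, 2 ops per node and 2 predecessors. With the default
# ``merge_op='all'`` the output concatenates all 4 nodes, so two 16-dim inputs
# yield a 64-dim output (16 * 4).
class _CellExample(nn.Module):
    def __init__(self):
        super().__init__()
        self.cell = Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
                         num_nodes=4, num_ops_per_node=2, num_predecessors=2)

    def forward(self, x, y):
        # the cell consumes a list of predecessor tensors
        return self.cell([x, y])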
......@@ -8,8 +8,9 @@ import torch.nn as nn
from ...mutator import Mutator
from ...graph import Cell, Graph, Model, ModelStatus, Node
from ...utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .component import Repeat
from ...utils import uid
class LayerChoiceMutator(Mutator):
......@@ -80,6 +81,42 @@ class ParameterChoiceMutator(Mutator):
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
class RepeatMutator(Mutator):
def __init__(self, nodes: List[Node]):
        # ``nodes`` is a list of repeat nodes sharing the same label; each refers to a cell graph that chains the repeated blocks.
super().__init__()
self.nodes = nodes
def _retrieve_chain_from_graph(self, graph: Graph) -> List[Node]:
u = graph.input_node
chain = []
while u != graph.output_node:
if u != graph.input_node:
chain.append(u)
            assert len(u.successors) == 1, f'This graph is not a legal chain: {u} has successors {u.successors}.'
u = u.successors[0]
return chain
def mutate(self, model):
min_depth = self.nodes[0].operation.parameters['min_depth']
max_depth = self.nodes[0].operation.parameters['max_depth']
if min_depth < max_depth:
chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
for node in self.nodes:
                # The logic here is similar to layer choice: find the cell graph attached to each node.
target: Graph = model.graphs[node.operation.cell_name]
chain = self._retrieve_chain_from_graph(target)
for edge in chain[chosen_depth - 1].outgoing_edges:
edge.remove()
target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
for rm_node in chain[chosen_depth:]:
for edge in rm_node.outgoing_edges:
edge.remove()
rm_node.remove()
                # update the operation to drop the now-unused parameters
model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
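    # Worked example (illustrative): suppose the repeated blocks form the chain
    # input -> b1 -> b2 -> b3 -> b4 -> b5 -> output and the sampled depth is 3.
    # Then the edge b3 -> b4 is removed, b3 is wired directly to the output node,
    # and b4 and b5 (along with their outgoing edges) are deleted from the graph.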
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []
......@@ -120,6 +157,15 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator)
repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
model.get_nodes_by_type('_cell')))
for node_list in repeat_nodes:
        assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
            _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
            'Repeat with the same label must have the same min and max depth.'
mutator = RepeatMutator(node_list)
applied_mutators.append(mutator)
if applied_mutators:
return applied_mutators
return None
......@@ -190,6 +236,11 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
if isinstance(module, ValueChoice):
node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
node.label = module.label
if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
node = graph.add_node(name, 'Repeat', {
'candidates': list(range(module.min_depth, module.max_depth + 1))
})
node.label = module.label
if isinstance(module, Placeholder):
raise NotImplementedError('Placeholder is not supported in python execution mode.')
......
from typing import Optional
from ...utils import uid, get_current_context
def generate_new_label(label: Optional[str]):
if label is None:
return '_mutation_' + str(uid('mutation'))
return label
def get_fixed_value(label: str):
ret = get_current_context('fixed')
try:
return ret[generate_new_label(label)]
except KeyError:
raise KeyError(f'Fixed context with {label} not found. Existing values are: {ret}')
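# Usage sketch (illustrative, not part of this diff): ``get_fixed_value`` looks up
# the decision recorded for a label in the current "fixed" context. Assuming the
# ``ContextStack`` helper from ``...utils`` is what establishes that context (as
# the fixed-architecture loading path does), the behavior is roughly:
#
#     with ContextStack('fixed', {'depth_choice': 4}):
#         get_fixed_value('depth_choice')  # -> 4
#     get_fixed_value('depth_choice')      # raises outside the context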
......@@ -379,7 +379,7 @@ class GraphIR(unittest.TestCase):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])
self.dropout_rate = nn.ValueChoice([[0., ], [1., ]])
def forward(self, x):
return F.dropout(x, self.dropout_rate()[0])
......@@ -398,7 +398,7 @@ class GraphIR(unittest.TestCase):
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])
self.dropout_rate = nn.ValueChoice([[1.05, ], [1.1, ]])
def forward(self, x):
# if expression failed, the exception would be:
......@@ -414,6 +414,67 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_repeat(self):
class AddOne(nn.Module):
def forward(self, x):
return x + 1
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.block = nn.Repeat(AddOne(), (3, 5))
def forward(self, x):
return self.block(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
model3 = mutator.apply(model)
self.assertTrue((self._get_converted_pytorch_model(model1)(torch.zeros(1, 16)) == 3).all())
self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())
def test_cell(self):
@self.get_serializer()
class Net(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)],
num_nodes=4, num_ops_per_node=2, num_predecessors=2, merge_op='all')
def forward(self, x, y):
return self.cell([x, y])
raw_model, mutators = self._get_model_with_mutators(Net())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(
torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))
@self.get_serializer()
class Net2(nn.Module):
def __init__(self):
super().__init__()
self.cell = nn.Cell([nn.Linear(16, 16), nn.Linear(16, 16, bias=False)], num_nodes=4)
def forward(self, x):
return self.cell([x])
raw_model, mutators = self._get_model_with_mutators(Net2())
for _ in range(10):
sampler = EnumerateSampler()
model = raw_model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))
class Python(GraphIR):
def _get_converted_pytorch_model(self, model_ir):
......