Unverified commit ba771871 authored by Yuge Zhang, committed by GitHub

Support ValueChoice as depth in Repeat (#4598)

parent c5e3bad9
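
In short: the ``depth`` argument of ``Repeat`` can now be a ``ValueChoice``, so depth candidates no longer have to form a contiguous ``(min, max)`` range. A minimal usage sketch of the new capability (``AddOne`` is a placeholder module invented for illustration; the ``Repeat``/``ValueChoice`` API is taken from the diff below):

    import nni.retiarii.nn.pytorch as nn

    class AddOne(nn.Module):
        # trivial block, for illustration only
        def forward(self, x):
            return x + 1

    # repeat the block 1, 3, or 5 times -- an arbitrary candidate list,
    # not just a contiguous range
    blocks = nn.Repeat(AddOne(), nn.ValueChoice([1, 3, 5]))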
@@ -660,8 +660,9 @@ class GraphConverter:
         attrs = {
             'mutation': 'repeat',
             'label': module.label,
+            'depth': module.depth_choice,
+            'max_depth': module.max_depth,
             'min_depth': module.min_depth,
-            'max_depth': module.max_depth
         }
         return ir_graph, attrs
...
@@ -2,6 +2,7 @@
 # Licensed under the MIT license.

 import math
+import itertools
 import operator
 import warnings
 from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar
@@ -439,6 +440,30 @@ class ValueChoiceX(Translatable):
         # values are not used
         return self._evaluate(iter([]), True)

+    def all_options(self) -> Iterable[Any]:
+        """Explore all possibilities of a value choice.
+        """
+        # Record all inner choices: label -> candidates, no duplicates.
+        dedup_inner_choices: Dict[str, List[Any]] = {}
+        # All labels of leaf nodes on the tree, possibly with duplicates.
+        all_labels: List[str] = []
+        for choice in self.inner_choices():
+            all_labels.append(choice.label)
+            if choice.label in dedup_inner_choices:
+                if choice.candidates != dedup_inner_choices[choice.label]:
+                    # check for choices with the same label
+                    raise ValueError(f'"{choice.candidates}" is not equal to "{dedup_inner_choices[choice.label]}", '
+                                     f'but they share the same label: {choice.label}')
+            else:
+                dedup_inner_choices[choice.label] = choice.candidates
+
+        dedup_labels, dedup_candidates = list(dedup_inner_choices.keys()), list(dedup_inner_choices.values())
+
+        for chosen in itertools.product(*dedup_candidates):
+            chosen = dict(zip(dedup_labels, chosen))
+            yield self.evaluate([chosen[label] for label in all_labels])
+
     def evaluate(self, values: Iterable[Any]) -> Any:
         """
         Evaluate the result of this group.
...
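
For reference, a small sketch of what ``all_options`` yields, assuming the arithmetic composition of value choices exercised in the tests further down. Identical labels are deduplicated, so they do not multiply the search space:

    import nni.retiarii.nn.pytorch as nn

    a = nn.ValueChoice([1, 2], label='a')
    b = nn.ValueChoice([3, 4], label='b')

    # two distinct labels: 2 x 2 = 4 evaluations (repeated values possible)
    assert sorted(set((a + b).all_options())) == [4, 5, 6]

    # the same label twice: still only 2 options, since 'a' is chosen once
    assert list((a + a).all_options()) == [2, 4]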
 import copy
+import warnings
 from collections import OrderedDict
 from typing import Callable, List, Union, Tuple, Optional

 import torch
 import torch.nn as nn

-from nni.retiarii.utils import STATE_DICT_PY_MAPPING_PARTIAL
+from nni.retiarii.utils import NoContextError, STATE_DICT_PY_MAPPING_PARTIAL
-from .api import LayerChoice
+from .api import LayerChoice, ValueChoice, ValueChoiceX
 from .cell import Cell
 from .nasbench101 import NasBench101Cell, NasBench101Mutator
 from .mutation_utils import Mutable, generate_new_label, get_fixed_value
@@ -30,7 +31,7 @@ class Repeat(Mutable):
     depth : int or tuple of int
         If one number, the block will be repeated by a fixed number of times. If a tuple, it should be (min, max),
         meaning that the block will be repeated at least ``min`` times and at most ``max`` times.
+        If a ValueChoice, it should choose from a series of positive integers.

     Examples
     --------
@@ -51,6 +52,10 @@ class Repeat(Mutable):
     we need a factory function that accepts an index (0, 1, 2, ...) and returns the module of the ``index``-th layer. ::

         self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))

+    Depth can also be a ValueChoice, to support an arbitrary list of depth candidates. ::
+
+        self.blocks = nn.Repeat(Block(), nn.ValueChoice([1, 3, 5]))
     """

     @classmethod
@@ -59,17 +64,26 @@ class Repeat(Mutable):
                            List[Callable[[int], nn.Module]],
                            nn.Module,
                            List[nn.Module]],
-                  depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
+                  depth: Union[int, Tuple[int, int], ValueChoice], *, label: Optional[str] = None):
-        repeat = get_fixed_value(label)
-        result = nn.Sequential(*cls._replicate_and_instantiate(blocks, repeat))
-
-        if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
-            # already has a mapping, will merge with it
-            prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
-            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
-        else:
-            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
-
-        return result
+        if isinstance(depth, tuple):
+            # we can't create a value choice here, otherwise we would have two
+            # value choices: one created here, another in __init__
+            depth = get_fixed_value(label)
+
+        if isinstance(depth, int):
+            # if depth was a value choice, it should already be an int by now
+            result = nn.Sequential(*cls._replicate_and_instantiate(blocks, depth))
+
+            if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
+                # already has a mapping, will merge with it
+                prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
+                setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'blocks.{v}' for k, v in prev_mapping.items()})
+            else:
+                setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': 'blocks'})
+
+            return result
+
+        raise NoContextError(f'Not in fixed mode, or {depth} not an integer.')

     def __init__(self,
                  blocks: Union[Callable[[int], nn.Module],
@@ -78,15 +92,32 @@ class Repeat(Mutable):
                                List[nn.Module]],
                  depth: Union[int, Tuple[int, int]], *, label: Optional[str] = None):
         super().__init__()
-        self._label = generate_new_label(label)
-        self.min_depth = depth if isinstance(depth, int) else depth[0]
-        self.max_depth = depth if isinstance(depth, int) else depth[1]
+        if isinstance(depth, ValueChoiceX):
+            if label is not None:
+                warnings.warn(
+                    'In Repeat, `depth` is already a ValueChoice, but `label` is still set. It will be ignored.',
+                    RuntimeWarning
+                )
+            self.depth_choice = depth
+            all_values = list(self.depth_choice.all_options())
+            self.min_depth = min(all_values)
+            self.max_depth = max(all_values)
+        elif isinstance(depth, tuple):
+            self.min_depth = depth[0]
+            self.max_depth = depth[1]
+            self.depth_choice = ValueChoice(list(range(self.min_depth, self.max_depth + 1)), label=label)
+        elif isinstance(depth, int):
+            self.min_depth = self.max_depth = depth
+            self.depth_choice = depth
+        else:
+            raise TypeError(f'Unsupported "depth" type: {type(depth)}')
         assert self.max_depth >= self.min_depth > 0
         self.blocks = nn.ModuleList(self._replicate_and_instantiate(blocks, self.max_depth))

     @property
     def label(self):
-        return self._label
+        return self.depth_choice.label

     def forward(self, x):
         for block in self.blocks:
@@ -107,6 +138,10 @@ class Repeat(Mutable):
         blocks = [b(i) for i, b in enumerate(blocks)]
         return blocks

+    def __getitem__(self, index):
+        # shortcut for self.blocks[index]
+        return self.blocks[index]
+

 class NasBench201Cell(nn.Module):
     """
...
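
To recap the new ``__init__`` contract: every accepted form of ``depth`` is normalized into ``min_depth``, ``max_depth``, and a ``depth_choice``. A hedged sketch of the resulting attributes (``AddOne`` is again a made-up block; attribute names follow the diff above):

    import nni.retiarii.nn.pytorch as nn

    class AddOne(nn.Module):
        def forward(self, x):
            return x + 1

    r1 = nn.Repeat(AddOne(), 4)       # int: min_depth == max_depth == 4, depth_choice is the plain int 4
    r2 = nn.Repeat(AddOne(), (1, 3))  # tuple: becomes ValueChoice([1, 2, 3]) internally
    r3 = nn.Repeat(AddOne(), nn.ValueChoice([1, 3, 5]))  # explicit candidate list
    assert (r3.min_depth, r3.max_depth) == (1, 5)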
@@ -14,7 +14,7 @@
 from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
 from nni.retiarii.utils import uid
 from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
-from .component import Repeat, NasBench101Cell, NasBench101Mutator
+from .component import NasBench101Cell, NasBench101Mutator


 class LayerChoiceMutator(Mutator):
@@ -144,14 +144,15 @@ class RepeatMutator(Mutator):
         return chain

     def mutate(self, model):
-        min_depth = self.nodes[0].operation.parameters['min_depth']
-        max_depth = self.nodes[0].operation.parameters['max_depth']
-        if min_depth < max_depth:
-            chosen_depth = self.choice(list(range(min_depth, max_depth + 1)))
         for node in self.nodes:
             # the logic here is similar to layer choice: we find the cell attached to each node
             target: Graph = model.graphs[node.operation.cell_name]
             chain = self._retrieve_chain_from_graph(target)
+            # and we get the chosen depth (by value choice)
+            node_in_model = model.get_node_by_name(node.name)
+            # depth is a value choice in the base model,
+            # but it has already been mutated by a ParameterChoiceMutator at this point
+            chosen_depth = node_in_model.operation.parameters['depth']
             for edge in chain[chosen_depth - 1].outgoing_edges:
                 edge.remove()
             target.add_edge((chain[chosen_depth - 1], None), (target.output_node, None))
@@ -184,6 +185,8 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     # `pc_nodes` are arguments of basic units. They can be compositions.
     pc_nodes: List[Tuple[Node, str, ValueChoiceX]] = []
     for node in model.get_nodes():
+        # these are arguments used in operators like Conv2d,
+        # as well as the `depth` value choice used in a generated repeat cell
         for name, choice in node.operation.parameters.items():
             if isinstance(choice, ValueChoiceX):
                 # e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
@@ -219,9 +222,10 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     repeat_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'repeat',
                                           model.get_nodes_by_type('_cell')))
     for node_list in repeat_nodes:
+        # this check is not completely reliable, because it only compares max and min depth
         assert _is_all_equal(map(lambda node: node.operation.parameters['max_depth'], node_list)) and \
             _is_all_equal(map(lambda node: node.operation.parameters['min_depth'], node_list)), \
-            'Repeat with the same label must have the same number of candidates.'
+            'Repeat with the same label must have the same candidates.'
         mutator = RepeatMutator(node_list)
         applied_mutators.append(mutator)
@@ -303,11 +307,6 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         if isinstance(module, ValueChoice):
             node = graph.add_node(name, 'ValueChoice', {'candidates': module.candidates})
             node.label = module.label
-        if isinstance(module, Repeat) and module.min_depth <= module.max_depth:
-            node = graph.add_node(name, 'Repeat', {
-                'candidates': list(range(module.min_depth, module.max_depth + 1))
-            })
-            node.label = module.label
         if isinstance(module, NasBench101Cell):
             node = graph.add_node(name, 'NasBench101Cell', {
                 'max_num_edges': module.max_num_edges
...
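
Mutation of a repeat is now two-staged: the ParameterChoiceMutator built from ``pc_nodes`` first fixes the ``depth`` value choice on each repeat cell, then ``RepeatMutator`` reads the already-resolved integer and truncates the chain of blocks. The truncation amounts to "keep the first ``chosen_depth`` blocks and rewire the output"; a toy sketch of that idea on a plain Python list, not the actual Graph API:

    def truncate_chain(chain, chosen_depth):
        # the real mutator removes the outgoing edges of chain[chosen_depth - 1]
        # and adds an edge straight to the graph's output node instead
        assert 1 <= chosen_depth <= len(chain)
        return chain[:chosen_depth]

    assert truncate_chain(['block0', 'block1', 'block2'], 2) == ['block0', 'block1']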
@@ -66,6 +66,8 @@ def _apply_all_mutators(model, mutators, samplers):
 class GraphIR(unittest.TestCase):
     # graph engine will have an extra mutator for parameter choices
     value_choice_incr = 1
+    # graph engine has an extra mutator to apply the depth choice to nodes
+    repeat_incr = 1

     def _convert_to_ir(self, model):
         script_module = torch.jit.script(model)
@@ -578,14 +580,39 @@ class GraphIR(unittest.TestCase):
             return self.block(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 1)
-        mutator = mutators[0].bind_sampler(EnumerateSampler())
-        model1 = mutator.apply(model)
-        model2 = mutator.apply(model)
-        model3 = mutator.apply(model)
-        self.assertTrue((self._get_converted_pytorch_model(model1)(torch.zeros(1, 16)) == 3).all())
-        self.assertTrue((self._get_converted_pytorch_model(model2)(torch.zeros(1, 16)) == 4).all())
-        self.assertTrue((self._get_converted_pytorch_model(model3)(torch.zeros(1, 16)) == 5).all())
+        self.assertEqual(len(mutators), 1 + self.repeat_incr + self.value_choice_incr)
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]
+        for target in [3, 4, 5]:
+            new_model = _apply_all_mutators(model, mutators, samplers)
+            self.assertTrue((self._get_converted_pytorch_model(new_model)(torch.zeros(1, 16)) == target).all())
+
+    def test_repeat_static(self):
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(lambda index: nn.LayerChoice([AddOne(), nn.Identity()]), 4)
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 4)
+        sampler = RandomSampler()
+        result = []
+        for _ in range(50):
+            new_model = model
+            for mutator in mutators:
+                new_model = mutator.bind_sampler(sampler).apply(new_model)
+            result.append(self._get_converted_pytorch_model(new_model)(torch.zeros(1, 1)).item())
+        for x in [1, 2, 3]:
+            self.assertIn(float(x), result)

     def test_repeat_complex(self):
         class AddOne(nn.Module):
@@ -602,8 +629,8 @@ class GraphIR(unittest.TestCase):
             return self.block(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 2)
-        self.assertEqual(set([mutator.label for mutator in mutators]), {'lc', 'rep'})
+        self.assertEqual(len(mutators), 2 + self.repeat_incr + self.value_choice_incr)
+        self.assertEqual(set([mutator.label for mutator in mutators if mutator.label is not None]), {'lc', 'rep'})

         sampler = RandomSampler()
         for _ in range(10):
@@ -624,7 +651,7 @@ class GraphIR(unittest.TestCase):
             return self.block(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 4)
+        self.assertEqual(len(mutators), 4 + self.repeat_incr + self.value_choice_incr)

         result = []
         for _ in range(20):
@@ -635,6 +662,27 @@ class GraphIR(unittest.TestCase):
         self.assertIn(1., result)

+    def test_repeat_valuechoice(self):
+        class AddOne(nn.Module):
+            def forward(self, x):
+                return x + 1
+
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.block = nn.Repeat(AddOne(), nn.ValueChoice([1, 3, 5]))
+
+            def forward(self, x):
+                return self.block(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 1 + self.repeat_incr + self.value_choice_incr)
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]
+        for target in [1, 3, 5]:
+            new_model = _apply_all_mutators(model, mutators, samplers)
+            self.assertTrue((self._get_converted_pytorch_model(new_model)(torch.zeros(1, 16)) == target).all())
+
     def test_repeat_weight_inheritance(self):
         @model_wrapper
         class Net(nn.Module):
@@ -647,11 +695,11 @@ class GraphIR(unittest.TestCase):
         orig_model = Net()
         model, mutators = self._get_model_with_mutators(orig_model)
-        mutator = mutators[0].bind_sampler(EnumerateSampler())
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]

         inp = torch.randn(1, 3, 5, 5)

         for i in range(4):
-            model_new = self._get_converted_pytorch_model(mutator.apply(model))
+            model_new = self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, samplers))
             with original_state_dict_hooks(model_new):
                 model_new.load_state_dict(orig_model.state_dict(), strict=False)
@@ -778,6 +826,7 @@ class GraphIR(unittest.TestCase):

 class Python(GraphIR):
     # Python engine doesn't have the extra mutator
     value_choice_incr = 0
+    repeat_incr = 0

     def _get_converted_pytorch_model(self, model_ir):
         mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
@@ -891,6 +940,8 @@ class Shared(unittest.TestCase):
             elif i == 2:
                 assert choice.candidates == [5, 6]
         assert d.evaluate([2, 3, 5]) == 20
+        expect = [x + y + 3 * z for x in [1, 2] for y in [3, 4] for z in [5, 6]]
+        assert list(d.all_options()) == expect

         a = nn.ValueChoice(['cat', 'dog'])
         b = nn.ValueChoice(['milk', 'coffee'])
@@ -967,6 +1018,9 @@ class Shared(unittest.TestCase):
                 lst = [value if choice.label == 'value' else divisor for choice in result.inner_choices()]
                 assert result.evaluate(lst) == original_make_divisible(value, divisor)

+        assert len(list(result.all_options())) == 30
+        assert max(result.all_options()) == 135
+
     def test_valuechoice_in_evaluator(self):
         def foo():
             pass
...