Unverified Commit 6808708d authored by Yuge Zhang, committed by GitHub

[Retiarii] Nest `ValueChoice` in `LayerChoice` and dict/list in `ValueChoice` (#3508)

parent b7062b5d
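
In user code, this change enables the two patterns exercised by the new tests below (a minimal sketch drawn from those tests; the layer shapes are illustrative):

    import nni.retiarii.nn.pytorch as nn

    # A ValueChoice nested inside LayerChoice candidates: each Linear's
    # output width is itself mutable.
    linear = nn.LayerChoice([
        nn.Linear(3, nn.ValueChoice([10, 20])),
        nn.Linear(3, nn.ValueChoice([30, 40]))
    ])

    # A dict/list nested inside ValueChoice: __getitem__ records an accessor
    # path ('b', then 0) that is resolved once a candidate is sampled.
    choices = [{'b': [3], 'bp': [6]}, {'b': [6], 'bp': [12]}]
    conv = nn.Conv2d(3, nn.ValueChoice(choices, label='a')['b'][0], 1)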
@@ -6,7 +6,7 @@ import re
 import torch

 from ..graph import Graph, Model, Node
-from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
+from ..nn.pytorch import InputChoice, Placeholder
 from ..operation import Cell, Operation
 from ..serializer import get_init_parameters_or_fail
 from ..utils import get_importable_name
@@ -343,7 +343,7 @@ class GraphConverter:
         subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
         if isinstance(submodule_obj, Placeholder):
             subcell.update_label(submodule_obj.label)
-        elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
+        elif isinstance(submodule_obj, InputChoice):
             subcell.update_label(sub_m_attrs['label'])
         else:
             # Graph already created, create Cell for it
@@ -536,16 +536,6 @@ class GraphConverter:
         self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
         self.merge_aten_slices(ir_graph)

-    def _handle_layerchoice(self, module):
-        choices = []
-        for cand in list(module):
-            cand_type = '__torch__.' + get_importable_name(cand.__class__)
-            choices.append({'type': cand_type, 'parameters': get_init_parameters_or_fail(cand)})
-        return {
-            'candidates': choices,
-            'label': module.label
-        }
-
     def _handle_inputchoice(self, module):
         return {
             'n_candidates': module.n_candidates,
@@ -557,7 +547,8 @@ class GraphConverter:
     def _handle_valuechoice(self, module):
         return {
             'candidates': module.candidates,
-            'label': module.label
+            'label': module.label,
+            'accessor': module._accessor
         }

     def convert_module(self, script_module, module, module_name, ir_model):
@@ -590,7 +581,13 @@ class GraphConverter:
         if original_type_name in MODULE_EXCEPT_LIST:
             pass  # do nothing
         elif original_type_name == OpTypeName.LayerChoice:
-            m_attrs = self._handle_layerchoice(module)
+            graph = Graph(ir_model, -100, module_name, _internal=True)  # graph_id is not used now
+            candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names]
+            for cand_name, cand in zip(candidate_name_list, module):
+                cand_type = '__torch__.' + get_importable_name(cand.__class__)
+                graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
+            graph._register()
+            return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
         elif original_type_name == OpTypeName.InputChoice:
             m_attrs = self._handle_inputchoice(module)
         elif original_type_name == OpTypeName.ValueChoice:
...
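The hunk above replaces attribute-based serialization of LayerChoice with a dedicated (initially unconnected) graph of candidate nodes. For a LayerChoice labeled 'fc1' whose candidates are named '0' and '1', the returned attributes would look roughly like this (a sketch following the layerchoice_{label}_{name} pattern in the code; actual names come from module.names):

    {
        'mutation': 'layerchoice',
        'label': 'fc1',
        'candidates': ['layerchoice_fc1_0', 'layerchoice_fc1_1']
    }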
@@ -144,11 +144,13 @@ class Model:
         for graph_name, graph_data in ir.items():
             if graph_name != '_evaluator':
                 Graph._load(model, graph_name, graph_data)._register()
-        model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
+        if '_evaluator' in ir:
+            model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
         return model

     def _dump(self) -> Any:
         ret = {name: graph._dump() for name, graph in self.graphs.items()}
-        ret['_evaluator'] = {
-            '__type__': get_importable_name(self.evaluator.__class__),
-            **self.evaluator._dump()
+        if self.evaluator is not None:
+            ret['_evaluator'] = {
+                '__type__': get_importable_name(self.evaluator.__class__),
+                **self.evaluator._dump()
...
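Both guards make the '_evaluator' entry optional, so a model with no evaluator now round-trips through dump and load (a sketch, assuming _load is invoked with the dumped IR dict as in the hunk above):

    ir = model._dump()        # no '_evaluator' key when model.evaluator is None
    model2 = Model._load(ir)  # tolerates the missing key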
 # Copyright (c) Microsoft Corporation.
 # Licensed under the MIT license.

+import copy
+import warnings
 from collections import OrderedDict
 from typing import Any, List, Union, Dict
-import warnings

 import torch
 import torch.nn as nn
@@ -268,6 +269,7 @@ class ValueChoice(Translatable, nn.Module):
         super().__init__()
         self.candidates = candidates
         self._label = label if label is not None else f'valuechoice_{uid()}'
+        self._accessor = []

     @property
     def label(self):
@@ -279,11 +281,36 @@ class ValueChoice(Translatable, nn.Module):
     def _translate(self):
         # Will function as a value when used in serializer.
-        return self.candidates[0]
+        return self.access(self.candidates[0])

     def __repr__(self):
         return f'ValueChoice({self.candidates}, label={repr(self.label)})'

+    def access(self, value):
+        if not self._accessor:
+            return value
+        try:
+            v = value
+            for a in self._accessor:
+                v = v[a]
+        except KeyError:
+            raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
+        return v
+
+    def __getitem__(self, item):
+        """
+        Get a sub-element of value choice.
+
+        The underlying implementation is to clone the current instance, and append item to "accessor", which records all
+        the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
+        where ``vc`` is the original value choice.
+        """
+        access = copy.deepcopy(self)
+        access._accessor.append(item)
+        for candidate in self.candidates:
+            access.access(candidate)
+        return access
+

 @basic_unit
 class Placeholder(nn.Module):
...
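A short sketch of the accessor mechanism documented in __getitem__ above: indexing clones the choice and records the index, and access() replays the recorded path on a concrete candidate (behavior taken directly from this diff):

    vc = nn.ValueChoice([(6, 3), (8, 5)])
    out_channels = vc[0]          # deep copy with _accessor == [0]
    kernel_size = vc[1]           # independent copy with _accessor == [1]
    out_channels.access((6, 3))   # -> 6
    kernel_size.access((8, 5))    # -> 5

Note that __getitem__ eagerly calls access() on every candidate, so an index that does not apply to all candidates fails when the choice is defined rather than when the model is mutated.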
@@ -4,7 +4,7 @@
 from typing import Any, List, Optional, Tuple

 from ...mutator import Mutator
-from ...graph import Model, Node
+from ...graph import Cell, Model, Node
 from .api import ValueChoice
@@ -14,13 +14,23 @@ class LayerChoiceMutator(Mutator):
         self.nodes = nodes

     def mutate(self, model):
-        n_candidates = len(self.nodes[0].operation.parameters['candidates'])
-        indices = list(range(n_candidates))
-        chosen_index = self.choice(indices)
+        candidates = self.nodes[0].operation.parameters['candidates']
+        chosen = self.choice(candidates)
         for node in self.nodes:
-            target = model.get_node_by_name(node.name)
-            chosen_cand = node.operation.parameters['candidates'][chosen_index]
-            target.update_operation(chosen_cand['type'], chosen_cand['parameters'])
+            # Each layer choice corresponds to a cell, which is unconnected in the base graph.
+            # We add the connections here in the mutation logic.
+            # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
+            target = model.graphs[node.operation.cell_name]
+            chosen_node = target.get_node_by_name(chosen)
+            assert chosen_node is not None
+            target.add_edge((target.input_node, 0), (chosen_node, None))
+            target.add_edge((chosen_node, None), (target.output_node, None))
+            model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
+            # remove redundant nodes
+            for rm_node in target.hidden_nodes:
+                if rm_node.name != chosen_node.name:
+                    rm_node.remove()


 class InputChoiceMutator(Mutator):
@@ -61,20 +71,14 @@ class ParameterChoiceMutator(Mutator):
     def mutate(self, model):
         chosen = self.choice(self.candidates)
         for node, argname in self.nodes:
+            chosen_value = node.operation.parameters[argname].access(chosen)
             target = model.get_node_by_name(node.name)
-            target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen})
+            target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})


 def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     applied_mutators = []

-    lc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.LayerChoice'))
-    for node_list in lc_nodes:
-        assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
-            'Layer choice with the same label must have the same number of candidates.'
-        mutator = LayerChoiceMutator(node_list)
-        applied_mutators.append(mutator)
-
     ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.InputChoice'))
     for node_list in ic_nodes:
         assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
@@ -99,9 +103,20 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
     for node_list in pc_nodes:
         assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
             'Value choice with the same label must have the same candidates.'
-        mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates)
+        first_node, first_argname = node_list[0]
+        mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
         applied_mutators.append(mutator)

+    # apply layer choice at last as it will delete some nodes
+    lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
+                                      model.get_nodes_by_type('_cell')))
+    for node_list in lc_nodes:
+        assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
+            'Layer choice with the same label must have the same number of candidates.'
+        mutator = LayerChoiceMutator(node_list)
+        applied_mutators.append(mutator)
+
     if applied_mutators:
         return applied_mutators
     return None
...
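The ordering in process_inline_mutation matters: value choices are resolved first via access(), and layer choices are applied last because LayerChoiceMutator.mutate deletes the unchosen hidden nodes. As the comment in mutate() warns, every round of sampling must start from the original base graph; the updated tests therefore apply the mutator chain to a fresh handle each time (a sketch mirroring the test code):

    mutators = process_inline_mutation(model)
    sampler = RandomSampler()
    model_new = model
    for mutator in mutators:
        model_new = mutator.bind_sampler(sampler).apply(model_new)
    # 'model' itself is untouched and can be sampled again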
@@ -69,6 +69,9 @@ class PrimConstant(PyTorchOperation):
         elif self.parameters['type'] == 'Device':
             value = self.parameters['value']
             return f'{output} = torch.device("{value}")'
+        elif self.parameters['type'] in ('dict', 'list', 'tuple'):
+            # TODO: prim::TupleIndex is not supported yet
+            return f'{output} = {repr(self.parameters["value"])}'
         else:
             raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
...
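The new branch emits the constant's repr directly into the generated code. For a dict constant such as the nested ValueChoice candidates in the tests, the generated line would read roughly (illustrative output name):

    out1 = {'b': [3], 'bp': [6]}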
 import random
 import unittest
+from collections import Counter

 import nni.retiarii.nn.pytorch as nn
 import torch
@@ -252,6 +253,30 @@ class TestHighLevelAPI(unittest.TestCase):
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

+    def test_value_choice_in_layer_choice(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.linear = nn.LayerChoice([
+                    nn.Linear(3, nn.ValueChoice([10, 20])),
+                    nn.Linear(3, nn.ValueChoice([30, 40]))
+                ])
+
+            def forward(self, x):
+                return self.linear(x)
+
+        model = self._convert_to_ir(Net())
+        mutators = process_inline_mutation(model)
+        self.assertEqual(len(mutators), 3)
+        sz_counter = Counter()
+        sampler = RandomSampler()
+        for i in range(100):
+            model_new = model
+            for mutator in mutators:
+                model_new = mutator.bind_sampler(sampler).apply(model_new)
+            sz_counter[self._get_converted_pytorch_model(model_new)(torch.randn(1, 3)).size(1)] += 1
+        self.assertEqual(len(sz_counter), 4)
+
     def test_shared(self):
         class Net(nn.Module):
             def __init__(self, shared=True):
@@ -284,12 +309,94 @@ class TestHighLevelAPI(unittest.TestCase):
         # repeat test. Expectation: sometimes succeeds, sometimes fails.
         failed_count = 0
         for i in range(30):
+            model_new = model
             for mutator in mutators:
-                model = mutator.bind_sampler(sampler).apply(model)
+                model_new = mutator.bind_sampler(sampler).apply(model_new)
             self.assertEqual(sampler.counter, 2 * (i + 1))
             try:
-                self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
+                self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3))
             except RuntimeError:
                 failed_count += 1
         self.assertGreater(failed_count, 0)
         self.assertLess(failed_count, 30)

+    def test_valuechoice_access(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                vc = nn.ValueChoice([(6, 3), (8, 5)])
+                self.conv = nn.Conv2d(3, vc[0], kernel_size=vc[1])
+
+            def forward(self, x):
+                return self.conv(x)
+
+        model = self._convert_to_ir(Net())
+        mutators = process_inline_mutation(model)
+        self.assertEqual(len(mutators), 1)
+        mutators[0].bind_sampler(EnumerateSampler())
+        input = torch.randn(1, 3, 5, 5)
+        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
+                         torch.Size([1, 6, 3, 3]))
+        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
+                         torch.Size([1, 8, 1, 1]))
+
+        class Net2(nn.Module):
+            def __init__(self):
+                super().__init__()
+                choices = [
+                    {'b': [3], 'bp': [6]},
+                    {'b': [6], 'bp': [12]}
+                ]
+                self.conv = nn.Conv2d(3, nn.ValueChoice(choices, label='a')['b'][0], 1)
+                self.conv1 = nn.Conv2d(nn.ValueChoice(choices, label='a')['bp'][0], 3, 1)

+            def forward(self, x):
+                x = self.conv(x)
+                return self.conv1(torch.cat((x, x), 1))
+
+        model = self._convert_to_ir(Net2())
+        mutators = process_inline_mutation(model)
+        self.assertEqual(len(mutators), 1)
+        mutators[0].bind_sampler(EnumerateSampler())
+        input = torch.randn(1, 3, 5, 5)
+        self._get_converted_pytorch_model(mutators[0].apply(model))(input)
+
+    def test_valuechoice_access_functional(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])
+
+            def forward(self, x):
+                return F.dropout(x, self.dropout_rate()[0])
+
+        model = self._convert_to_ir(Net())
+        mutators = process_inline_mutation(model)
+        self.assertEqual(len(mutators), 1)
+        mutator = mutators[0].bind_sampler(EnumerateSampler())
+        model1 = mutator.apply(model)
+        model2 = mutator.apply(model)
+        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
+        self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
+
+    def test_valuechoice_access_functional_expression(self):
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])
+
+            def forward(self, x):
+                # if expression failed, the exception would be:
+                # ValueError: dropout probability has to be between 0 and 1, but got 1.05
+                return F.dropout(x, self.dropout_rate()[0] - .1)
+
+        model = self._convert_to_ir(Net())
+        mutators = process_inline_mutation(model)
+        self.assertEqual(len(mutators), 1)
+        mutator = mutators[0].bind_sampler(EnumerateSampler())
+        model1 = mutator.apply(model)
+        model2 = mutator.apply(model)
+        self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
+        self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
...
@@ -62,11 +62,11 @@ class Net(nn.Module):
         self.fc1 = nn.LayerChoice([
             nn.Linear(4*4*50, hidden_size, bias=True),
             nn.Linear(4*4*50, hidden_size, bias=False)
-        ])
+        ], label='fc1')
         self.fc2 = nn.LayerChoice([
             nn.Linear(hidden_size, 10, bias=False),
             nn.Linear(hidden_size, 10, bias=True)
-        ])
+        ], label='fc2')

     def forward(self, x):
         x = F.relu(self.conv1(x))
@@ -97,8 +97,8 @@ def test_grid_search():
     selection = set()
     for model in engine.models:
         selection.add((
-            model.get_node_by_name('_model__fc1').operation.parameters['bias'],
-            model.get_node_by_name('_model__fc2').operation.parameters['bias']
+            model.graphs['_model__fc1'].hidden_nodes[0].operation.parameters['bias'],
+            model.graphs['_model__fc2'].hidden_nodes[0].operation.parameters['bias']
         ))
     assert len(selection) == 4
     _reset_execution_engine()
@@ -113,8 +113,8 @@ def test_random_search():
     selection = set()
     for model in engine.models:
         selection.add((
-            model.get_node_by_name('_model__fc1').operation.parameters['bias'],
-            model.get_node_by_name('_model__fc2').operation.parameters['bias']
+            model.graphs['_model__fc1'].hidden_nodes[0].operation.parameters['bias'],
+            model.graphs['_model__fc2'].hidden_nodes[0].operation.parameters['bias']
         ))
     assert len(selection) == 4
     _reset_execution_engine()
...
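Because each LayerChoice is now a subgraph rather than a flat node, these tests locate the chosen candidate through the cell graph: after mutation only one hidden node survives, and it carries the candidate's parameters (sketch based on the assertions above):

    # before: model.get_node_by_name('_model__fc1').operation.parameters['bias']
    bias = model.graphs['_model__fc1'].hidden_nodes[0].operation.parameters['bias']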