Unverified commit 6808708d, authored by Yuge Zhang and committed via GitHub.
Browse files

[Retiarii] Nest `ValueChoice` in `LayerChoice` and dict/list in `ValueChoice` (#3508)

parent b7062b5d
......@@ -6,7 +6,7 @@ import re
import torch
from ..graph import Graph, Model, Node
from ..nn.pytorch import InputChoice, LayerChoice, Placeholder
from ..nn.pytorch import InputChoice, Placeholder
from ..operation import Cell, Operation
from ..serializer import get_init_parameters_or_fail
from ..utils import get_importable_name
......@@ -343,7 +343,7 @@ class GraphConverter:
subcell = ir_graph.add_node(submodule_full_name, submodule_type_str, sub_m_attrs)
if isinstance(submodule_obj, Placeholder):
subcell.update_label(submodule_obj.label)
elif isinstance(submodule_obj, (LayerChoice, InputChoice)):
elif isinstance(submodule_obj, InputChoice):
subcell.update_label(sub_m_attrs['label'])
else:
# Graph already created, create Cell for it
......@@ -536,16 +536,6 @@ class GraphConverter:
self.remove_unconnected_nodes(ir_graph, targeted_type='prim::GetAttr')
self.merge_aten_slices(ir_graph)
def _handle_layerchoice(self, module):
    """Convert a ``LayerChoice`` module into its IR parameter dict.

    Each candidate is recorded as its importable type name together with the
    init parameters captured by the serializer; the label is kept so that
    layer choices sharing a label mutate jointly.
    """
    candidates = [
        {
            'type': '__torch__.' + get_importable_name(cand.__class__),
            'parameters': get_init_parameters_or_fail(cand)
        }
        for cand in module
    ]
    return {'candidates': candidates, 'label': module.label}
def _handle_inputchoice(self, module):
return {
'n_candidates': module.n_candidates,
......@@ -557,7 +547,8 @@ class GraphConverter:
def _handle_valuechoice(self, module):
return {
'candidates': module.candidates,
'label': module.label
'label': module.label,
'accessor': module._accessor
}
def convert_module(self, script_module, module, module_name, ir_model):
......@@ -590,7 +581,13 @@ class GraphConverter:
if original_type_name in MODULE_EXCEPT_LIST:
pass # do nothing
elif original_type_name == OpTypeName.LayerChoice:
m_attrs = self._handle_layerchoice(module)
graph = Graph(ir_model, -100, module_name, _internal=True) # graph_id is not used now
candidate_name_list = [f'layerchoice_{module.label}_{cand_name}' for cand_name in module.names]
for cand_name, cand in zip(candidate_name_list, module):
cand_type = '__torch__.' + get_importable_name(cand.__class__)
graph.add_node(cand_name, cand_type, get_init_parameters_or_fail(cand))
graph._register()
return graph, {'mutation': 'layerchoice', 'label': module.label, 'candidates': candidate_name_list}
elif original_type_name == OpTypeName.InputChoice:
m_attrs = self._handle_inputchoice(module)
elif original_type_name == OpTypeName.ValueChoice:
......
......@@ -144,15 +144,17 @@ class Model:
for graph_name, graph_data in ir.items():
if graph_name != '_evaluator':
Graph._load(model, graph_name, graph_data)._register()
model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
if '_evaluator' in ir:
model.evaluator = Evaluator._load_with_type(ir['_evaluator']['__type__'], ir['_evaluator'])
return model
def _dump(self) -> Any:
ret = {name: graph._dump() for name, graph in self.graphs.items()}
ret['_evaluator'] = {
'__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump()
}
if self.evaluator is not None:
ret['_evaluator'] = {
'__type__': get_importable_name(self.evaluator.__class__),
**self.evaluator._dump()
}
return ret
def get_nodes(self) -> Iterable['Node']:
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import warnings
from collections import OrderedDict
from typing import Any, List, Union, Dict
import warnings
import torch
import torch.nn as nn
......@@ -268,6 +269,7 @@ class ValueChoice(Translatable, nn.Module):
super().__init__()
self.candidates = candidates
self._label = label if label is not None else f'valuechoice_{uid()}'
self._accessor = []
@property
def label(self):
......@@ -279,11 +281,36 @@ class ValueChoice(Translatable, nn.Module):
def _translate(self):
# Will function as a value when used in serializer.
return self.candidates[0]
return self.access(self.candidates[0])
def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
def access(self, value):
    """Apply the recorded accessor chain (``self._accessor``) to ``value``.

    Returns ``value`` unchanged when no subscript has been recorded.
    Raises ``KeyError`` carrying the full accessor path when any step fails.
    """
    if not self._accessor:
        return value
    v = value
    try:
        for a in self._accessor:
            v = v[a]
    # Candidates may be dicts, lists or tuples; a bad step can raise
    # IndexError or TypeError as well as KeyError. Normalize them all into
    # one descriptive KeyError so candidate validation in `__getitem__`
    # always reports the offending path.
    except (KeyError, IndexError, TypeError):
        raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
    return v
def __getitem__(self, item):
"""
Get a sub-element of value choice.
The underlying implementation is to clone the current instance, and append item to "accessor", which records all
the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
where ``vc`` is the original value choice.
"""
access = copy.deepcopy(self)
access._accessor.append(item)
for candidate in self.candidates:
access.access(candidate)
return access
@basic_unit
class Placeholder(nn.Module):
......
......@@ -4,7 +4,7 @@
from typing import Any, List, Optional, Tuple
from ...mutator import Mutator
from ...graph import Model, Node
from ...graph import Cell, Model, Node
from .api import ValueChoice
......@@ -14,13 +14,23 @@ class LayerChoiceMutator(Mutator):
self.nodes = nodes
def mutate(self, model):
    """Sample one candidate cell for every grouped layer-choice node.

    The previous index-based mutation (sampling an index and rewriting the
    node operation in place) was left interleaved with this version by the
    diff; it is removed here. Now the choice is a candidate *name* within
    the cell graph registered at conversion time.
    """
    candidates = self.nodes[0].operation.parameters['candidates']
    chosen = self.choice(candidates)
    for node in self.nodes:
        # Each layer choice corresponds to a cell, which is unconnected in the base graph.
        # We add the connections here in the mutation logic.
        # Thus, the mutated model should not be mutated again. Everything should be based on the original base graph.
        target = model.graphs[node.operation.cell_name]
        chosen_node = target.get_node_by_name(chosen)
        assert chosen_node is not None
        target.add_edge((target.input_node, 0), (chosen_node, None))
        target.add_edge((chosen_node, None), (target.output_node, None))
        model.get_node_by_name(node.name).update_operation(Cell(node.operation.cell_name))
        # remove redundant nodes: every unchosen candidate is dropped from the cell
        for rm_node in target.hidden_nodes:
            if rm_node.name != chosen_node.name:
                rm_node.remove()
class InputChoiceMutator(Mutator):
......@@ -61,20 +71,14 @@ class ParameterChoiceMutator(Mutator):
def mutate(self, model):
    """Sample one shared value and apply it to every (node, argname) pair.

    The sampled candidate is first passed through the ValueChoice's accessor
    chain (``.access``) so that subscripted choices like ``vc[0]`` resolve to
    the right sub-element. The stale pre-accessor update line from the diff
    residue is removed.
    """
    chosen = self.choice(self.candidates)
    for node, argname in self.nodes:
        # Apply the accessor chain recorded on the original ValueChoice.
        chosen_value = node.operation.parameters[argname].access(chosen)
        target = model.get_node_by_name(node.name)
        target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []
lc_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.LayerChoice'))
for node_list in lc_nodes:
assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
'Layer choice with the same label must have the same number of candidates.'
mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator)
ic_nodes = _group_by_label(model.get_nodes_by_type('__torch__.nni.retiarii.nn.pytorch.api.InputChoice'))
for node_list in ic_nodes:
assert _is_all_equal(map(lambda node: node.operation.parameters['n_candidates'], node_list)) and \
......@@ -99,9 +103,20 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates)
first_node, first_argname = node_list[0]
mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
applied_mutators.append(mutator)
# apply layer choice at last as it will delete some nodes
lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
model.get_nodes_by_type('_cell')))
for node_list in lc_nodes:
assert _is_all_equal(map(lambda node: len(node.operation.parameters['candidates']), node_list)), \
'Layer choice with the same label must have the same number of candidates.'
mutator = LayerChoiceMutator(node_list)
applied_mutators.append(mutator)
if applied_mutators:
return applied_mutators
return None
......
......@@ -69,6 +69,9 @@ class PrimConstant(PyTorchOperation):
elif self.parameters['type'] == 'Device':
value = self.parameters['value']
return f'{output} = torch.device("{value}")'
elif self.parameters['type'] in ('dict', 'list', 'tuple'):
# TODO: prim::TupleIndex is not supported yet
return f'{output} = {repr(self.parameters["value"])}'
else:
raise RuntimeError(f'unsupported type of prim::Constant: {self.parameters["type"]}')
......
import random
import unittest
from collections import Counter
import nni.retiarii.nn.pytorch as nn
import torch
......@@ -252,6 +253,30 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_value_choice_in_layer_choice(self):
    """ValueChoice nested inside LayerChoice yields one mutator per choice."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = nn.LayerChoice([
                nn.Linear(3, nn.ValueChoice([10, 20])),
                nn.Linear(3, nn.ValueChoice([30, 40]))
            ])

        def forward(self, x):
            return self.linear(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    # One mutator for the layer choice plus one per nested value choice.
    self.assertEqual(len(mutators), 3)
    sampler = RandomSampler()
    output_sizes = Counter()
    for _ in range(100):
        sampled = model
        for mutator in mutators:
            sampled = mutator.bind_sampler(sampler).apply(sampled)
        output_sizes[self._get_converted_pytorch_model(sampled)(torch.randn(1, 3)).size(1)] += 1
    # Four distinct widths should appear: 10/20 from branch 0, 30/40 from branch 1.
    self.assertEqual(len(output_sizes), 4)
def test_shared(self):
class Net(nn.Module):
def __init__(self, shared=True):
......@@ -284,12 +309,94 @@ class TestHighLevelAPI(unittest.TestCase):
# repeat test. Expectation: sometimes succeeds, sometimes fails.
failed_count = 0
for i in range(30):
model_new = model
for mutator in mutators:
model = mutator.bind_sampler(sampler).apply(model)
model_new = mutator.bind_sampler(sampler).apply(model_new)
self.assertEqual(sampler.counter, 2 * (i + 1))
try:
self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
self._get_converted_pytorch_model(model_new)(torch.randn(1, 3, 3, 3))
except RuntimeError:
failed_count += 1
self.assertGreater(failed_count, 0)
self.assertLess(failed_count, 30)
def test_valuechoice_access(self):
    """Subscripted ValueChoice (tuple and nested dict/list) mutates correctly."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            vc = nn.ValueChoice([(6, 3), (8, 5)])
            self.conv = nn.Conv2d(3, vc[0], kernel_size=vc[1])

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    # Both subscripts come from one ValueChoice, hence a single mutator.
    self.assertEqual(len(mutators), 1)
    mutators[0].bind_sampler(EnumerateSampler())
    x = torch.randn(1, 3, 5, 5)
    # First sample: out_channels=6, kernel=3 -> 3x3 spatial output.
    self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(x).size(),
                     torch.Size([1, 6, 3, 3]))
    # Second sample: out_channels=8, kernel=5 -> 1x1 spatial output.
    self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(x).size(),
                     torch.Size([1, 8, 1, 1]))

    class Net2(nn.Module):
        def __init__(self):
            super().__init__()
            choices = [
                {'b': [3], 'bp': [6]},
                {'b': [6], 'bp': [12]}
            ]
            self.conv = nn.Conv2d(3, nn.ValueChoice(choices, label='a')['b'][0], 1)
            self.conv1 = nn.Conv2d(nn.ValueChoice(choices, label='a')['bp'][0], 3, 1)

        def forward(self, x):
            x = self.conv(x)
            return self.conv1(torch.cat((x, x), 1))

    model = self._convert_to_ir(Net2())
    mutators = process_inline_mutation(model)
    # Shared label 'a' groups the two accesses into a single mutator.
    self.assertEqual(len(mutators), 1)
    mutators[0].bind_sampler(EnumerateSampler())
    x = torch.randn(1, 3, 5, 5)
    # Merely checks the mutated model runs: channel counts stay consistent.
    self._get_converted_pytorch_model(mutators[0].apply(model))(x)
def test_valuechoice_access_functional(self):
    """Subscript access on a ValueChoice used functionally in forward()."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            # Single-element lists so that `self.dropout_rate()[0]`
            # exercises subscripting the translated value.
            self.dropout_rate = nn.ValueChoice([[0.,], [1.,]])

        def forward(self, x):
            return F.dropout(x, self.dropout_rate()[0])

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    model_nodrop = mutator.apply(model)
    model_drop = mutator.apply(model)
    self._get_converted_pytorch_model(model_nodrop)(torch.randn(1, 3, 3, 3))
    # p=0.0: identity dropout, shape preserved.
    self.assertEqual(self._get_converted_pytorch_model(model_nodrop)(torch.randn(1, 3, 3, 3)).size(),
                     torch.Size([1, 3, 3, 3]))
    # p=1.0: every element dropped, output sums to zero.
    self.assertAlmostEqual(self._get_converted_pytorch_model(model_drop)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_access_functional_expression(self):
    """Subscripted ValueChoice combined with an arithmetic expression."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.dropout_rate = nn.ValueChoice([[1.05,], [1.1,]])

        def forward(self, x):
            # if expression failed, the exception would be:
            # ValueError: dropout probability has to be between 0 and 1, but got 1.05
            return F.dropout(x, self.dropout_rate()[0] - .1)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    model_lowdrop = mutator.apply(model)
    model_fulldrop = mutator.apply(model)
    self._get_converted_pytorch_model(model_lowdrop)(torch.randn(1, 3, 3, 3))
    # 1.05 - 0.1 = 0.95 < 1: model runs and keeps its shape.
    self.assertEqual(self._get_converted_pytorch_model(model_lowdrop)(torch.randn(1, 3, 3, 3)).size(),
                     torch.Size([1, 3, 3, 3]))
    # 1.1 - 0.1 = 1.0: everything dropped, output is all zeros.
    self.assertAlmostEqual(self._get_converted_pytorch_model(model_fulldrop)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
......@@ -62,11 +62,11 @@ class Net(nn.Module):
self.fc1 = nn.LayerChoice([
nn.Linear(4*4*50, hidden_size, bias=True),
nn.Linear(4*4*50, hidden_size, bias=False)
])
], label='fc1')
self.fc2 = nn.LayerChoice([
nn.Linear(hidden_size, 10, bias=False),
nn.Linear(hidden_size, 10, bias=True)
])
], label='fc2')
def forward(self, x):
x = F.relu(self.conv1(x))
......@@ -97,8 +97,8 @@ def test_grid_search():
selection = set()
for model in engine.models:
selection.add((
model.get_node_by_name('_model__fc1').operation.parameters['bias'],
model.get_node_by_name('_model__fc2').operation.parameters['bias']
model.graphs['_model__fc1'].hidden_nodes[0].operation.parameters['bias'],
model.graphs['_model__fc2'].hidden_nodes[0].operation.parameters['bias']
))
assert len(selection) == 4
_reset_execution_engine()
......@@ -113,8 +113,8 @@ def test_random_search():
selection = set()
for model in engine.models:
selection.add((
model.get_node_by_name('_model__fc1').operation.parameters['bias'],
model.get_node_by_name('_model__fc2').operation.parameters['bias']
model.graphs['_model__fc1'].hidden_nodes[0].operation.parameters['bias'],
model.graphs['_model__fc2'].hidden_nodes[0].operation.parameters['bias']
))
assert len(selection) == 4
_reset_execution_engine()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.