Unverified Commit aea82d71 authored by Yuge Zhang's avatar Yuge Zhang Committed by GitHub
Browse files

[Retiarii] Use ValueChoice inline in a serializable instance (#3382)

parent fddc8adc
...@@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use ...@@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use
# invoked in `forward` function, choose one from the three # invoked in `forward` function, choose one from the three
out = self.input_switch([tensor1, tensor2, tensor3]) out = self.input_switch([tensor1, tensor2, tensor3])
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules. *Note that it has not been officially supported.* * ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules.
.. code-block:: python .. code-block:: python
......
...@@ -6,7 +6,7 @@ import abc ...@@ -6,7 +6,7 @@ import abc
import copy import copy
import json import json
from enum import Enum from enum import Enum
from typing import (Any, Dict, List, Optional, Tuple, Union, overload) from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_full_class_name, import_, uid from .utils import get_full_class_name, import_, uid
...@@ -152,6 +152,14 @@ class Model: ...@@ -152,6 +152,14 @@ class Model:
} }
return ret return ret
def get_nodes(self) -> Iterable['Node']:
    """
    Iterate over every node of every graph in this model.

    Yields
    ------
    Node
        Each node contained in any graph of ``self.graphs``, graph by graph.
    """
    for graph in self.graphs.values():
        yield from graph.nodes
def get_nodes_by_label(self, label: str) -> List['Node']: def get_nodes_by_label(self, label: str) -> List['Node']:
""" """
Traverse all the nodes to find the matched node(s) with the given name. Traverse all the nodes to find the matched node(s) with the given name.
......
...@@ -5,7 +5,7 @@ import warnings ...@@ -5,7 +5,7 @@ import warnings
import torch import torch
import torch.nn as nn import torch.nn as nn
from ...utils import uid, add_record, del_record from ...utils import uid, add_record, del_record, Translatable
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs'] __all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
...@@ -130,6 +130,9 @@ class LayerChoice(nn.Module): ...@@ -130,6 +130,9 @@ class LayerChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.') warnings.warn('You should not run forward of this module directly.')
return x return x
def __repr__(self):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'
class InputChoice(nn.Module): class InputChoice(nn.Module):
""" """
...@@ -188,33 +191,66 @@ class InputChoice(nn.Module): ...@@ -188,33 +191,66 @@ class InputChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.') warnings.warn('You should not run forward of this module directly.')
return candidate_inputs[0] return candidate_inputs[0]
def __repr__(self):
return f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, ' \
f'reduction={repr(self.reduction)}, label={repr(self.label)})'
class ValueChoice(nn.Module): class ValueChoice(Translatable, nn.Module):
""" """
ValueChoice is to choose one from ``candidates``. ValueChoice is to choose one from ``candidates``.
Should initialize the values to choose from in init and call the module in forward to get the chosen value. In most use scenarios, ValueChoice should be passed to the init parameters of a serializable module. For example,
.. code-block:: python
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))
def forward(self, x):
return self.conv(x)
A common use is to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx```. For example, In case, you want to search a parameter that is used repeatedly, this is also possible by sharing the same value choice instance.
(Sharing the label should have the same effect.) For example,
.. code-block:: python .. code-block:: python
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.]) hidden_dim = nn.ValueChoice([128, 512])
self.fc = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.Linear(hidden_dim, 10)
)
# the following code has the same effect.
# self.fc = nn.Sequential(
# nn.Linear(64, nn.ValueChoice([128, 512], label='dim')),
# nn.Linear(nn.ValueChoice([128, 512], label='dim'), 10)
# )
def forward(self, x): def forward(self, x):
return F.dropout(x, self.dropout_rate()) return self.fc(x)
The following use case is currently not supported because ValueChoice cannot be called in ``__init__``. Note that ValueChoice should be used directly. Transformations like ``nn.Linear(32, nn.ValueChoice([64, 128]) * 2)``
Please use LayerChoice as a workaround. are not supported.
Another common use case is to initialize the values to choose from in init and call the module in forward to get the chosen value.
Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``.
For example,
.. code-block:: python .. code-block:: python
# in __init__ code class Net(nn.Module):
self.kernel_size = nn.ValueChoice([3, 5]) def __init__(self):
self.conv = nn.Conv2d(3, self.out_channels, kernel_size=self.kernel_size()) super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])
def forward(self, x):
return F.dropout(x, self.dropout_rate())
Parameters Parameters
---------- ----------
...@@ -237,6 +273,13 @@ class ValueChoice(nn.Module): ...@@ -237,6 +273,13 @@ class ValueChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.') warnings.warn('You should not run forward of this module directly.')
return self.candidates[0] return self.candidates[0]
def _translate(self):
# Will function as a value when used in serializer.
return self.candidates[0]
def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
class Placeholder(nn.Module): class Placeholder(nn.Module):
# TODO: docstring # TODO: docstring
......
from typing import Any, List, Optional from typing import Any, List, Optional, Tuple
from ...mutator import Mutator from ...mutator import Mutator
from ...graph import Model, Node from ...graph import Model, Node
from .api import ValueChoice
class LayerChoiceMutator(Mutator): class LayerChoiceMutator(Mutator):
...@@ -48,6 +49,19 @@ class ValueChoiceMutator(Mutator): ...@@ -48,6 +49,19 @@ class ValueChoiceMutator(Mutator):
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen}) target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
class ParameterChoiceMutator(Mutator):
    """
    Mutator for ValueChoice objects used as init parameters.

    Samples one value from ``candidates`` and writes it into the operation
    parameter ``argname`` of every ``(node, argname)`` pair in ``nodes``, so
    value choices sharing a label all receive the same sampled value.
    """

    def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
        super().__init__()
        self.nodes = nodes
        self.candidates = candidates

    def mutate(self, model):
        # One sample drives every node in the group (shared-label semantics).
        picked = self.choice(self.candidates)
        for node, argname in self.nodes:
            target = model.get_node_by_name(node.name)
            updated = dict(target.operation.parameters)
            updated[argname] = picked
            target.update_operation(target.operation.type, updated)
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = [] applied_mutators = []
...@@ -73,6 +87,18 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]: ...@@ -73,6 +87,18 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates']) mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
applied_mutators.append(mutator) applied_mutators.append(mutator)
pc_nodes = []
for node in model.get_nodes():
for name, choice in node.operation.parameters.items():
if isinstance(choice, ValueChoice):
pc_nodes.append((node, name))
pc_nodes = _group_parameters_by_label(pc_nodes)
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates)
applied_mutators.append(mutator)
if applied_mutators: if applied_mutators:
return applied_mutators return applied_mutators
return None return None
...@@ -95,3 +121,13 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]: ...@@ -95,3 +121,13 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result[label] = [] result[label] = []
result[label].append(node) result[label].append(node)
return list(result.values()) return list(result.values())
def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
result = {}
for node, argname in nodes:
label = node.operation.parameters[argname].label
if label not in result:
result[label] = []
result[label].append((node, argname))
return list(result.values())
import abc
import functools import functools
import inspect import inspect
from collections import defaultdict from collections import defaultdict
...@@ -89,6 +90,17 @@ def del_record(key): ...@@ -89,6 +90,17 @@ def del_record(key):
_records.pop(key, None) _records.pop(key, None)
class Translatable(abc.ABC):
    """
    Inherit this class and implement ``_translate`` when the inner class needs a different
    parameter from the wrapper class in its init function.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        """Return the concrete value that should replace this object when it is recorded as an init parameter."""
        pass
def _blackbox_cls(cls): def _blackbox_cls(cls):
class wrapper(cls): class wrapper(cls):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -100,6 +112,15 @@ def _blackbox_cls(cls): ...@@ -100,6 +112,15 @@ def _blackbox_cls(cls):
for argname, value in zip(argname_list, args): for argname, value in zip(argname_list, args):
full_args[argname] = value full_args[argname] = value
# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value._translate()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value._translate()
add_record(id(self), full_args) # for compatibility. Will remove soon. add_record(id(self), full_args) # for compatibility. Will remove soon.
self.__init_parameters__ = full_args self.__init_parameters__ = full_args
......
...@@ -10,7 +10,7 @@ from nni.retiarii.codegen import model_to_pytorch_script ...@@ -10,7 +10,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
class EnuemrateSampler(Sampler): class EnumerateSampler(Sampler):
def __init__(self): def __init__(self):
self.index = 0 self.index = 0
...@@ -70,7 +70,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -70,7 +70,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net()) model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model) mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
model2 = mutator.apply(model) model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
...@@ -94,7 +94,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -94,7 +94,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net()) model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model) mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
model2 = mutator.apply(model) model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
...@@ -119,7 +119,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -119,7 +119,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net(reduction)) model = self._convert_to_ir(Net(reduction))
mutators = process_inline_mutation(model) mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model = mutator.apply(model) model = mutator.apply(model)
result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3)) result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
if reduction == 'none': if reduction == 'none':
...@@ -144,7 +144,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -144,7 +144,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net()) model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model) mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
model2 = mutator.apply(model) model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
...@@ -152,6 +152,87 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -152,6 +152,87 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(), self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, 5, 3, 3])) torch.Size([1, 5, 3, 3]))
def test_value_choice_as_parameter(self):
    """A ValueChoice used inline as an init argument (kernel_size) yields one mutator whose choices change the output spatial size."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    # EnumerateSampler is stateful: the first apply picks kernel_size=3, the second picks 5.
    model1 = mutator.apply(model)
    model2 = mutator.apply(model)
    # 5x5 input, no padding: kernel 3 -> 3x3 output, kernel 5 -> 1x1 output.
    self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter_duplicate(self):
    """Byte-identical re-run of ``test_value_choice_as_parameter``.

    NOTE(review): this method was defined with the same name as the test
    above, so Python silently shadowed it and it never ran. Renamed so it
    executes; it duplicates the previous test and is a candidate for removal.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    model1 = mutator.apply(model)
    model2 = mutator.apply(model)
    self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 1, 1]))
def test_value_choice_as_two_parameters(self):
    """Two independent ValueChoices (out_channels and kernel_size) produce two mutators.

    NOTE(review): this method was defined as a third ``test_value_choice_as_parameter``,
    shadowing the earlier definitions; renamed so all three tests run and the
    name reflects that two parameters are mutated here.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, nn.ValueChoice([6, 8]), kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    # One mutator per distinct value choice (they carry different labels).
    self.assertEqual(len(mutators), 2)
    mutators[0].bind_sampler(EnumerateSampler())
    mutators[1].bind_sampler(EnumerateSampler())
    input = torch.randn(1, 3, 5, 5)
    # First application enumerates (6, kernel 3); second enumerates (8, kernel 5).
    self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
                     torch.Size([1, 6, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
                     torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self):
    """Two ValueChoices with the same label collapse into one mutator and always receive the same sampled value."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)
            self.conv2 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)

        def forward(self, x):
            # Addition only type-checks if both convs got the same out_channels.
            return self.conv1(x) + self.conv2(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    # Shared label -> a single mutator drives both parameters.
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    model1 = mutator.apply(model)
    model2 = mutator.apply(model)
    self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 6, 5, 5]))
    self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 8, 5, 5]))
def test_value_choice_in_functional(self): def test_value_choice_in_functional(self):
class Net(nn.Module): class Net(nn.Module):
def __init__(self): def __init__(self):
...@@ -164,7 +245,7 @@ class TestHighLevelAPI(unittest.TestCase): ...@@ -164,7 +245,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net()) model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model) mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1) self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler()) mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model) model1 = mutator.apply(model)
model2 = mutator.apply(model) model2 = mutator.apply(model)
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)) self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment