"vscode:/vscode.git/clone" did not exist on "a9b2b1dcd7aec58885cc1ca5cd20dc3fa77e5a5b"
Unverified commit aea82d71, authored by Yuge Zhang, committed by GitHub
Browse files

[Retiarii] Use ValueChoice inline in a serializable instance (#3382)

parent fddc8adc
......@@ -83,7 +83,7 @@ For easy usability and also backward compatibility, we provide some APIs for use
# invoked in `forward` function, choose one from the three
out = self.input_switch([tensor1, tensor2, tensor3])
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules. *Note that it has not been officially supported.*
* ``nn.ValueChoice``. It is for choosing one value from some candidate values. It can only be used as input argument of the modules in ``nn.modules`` and ``@blackbox_module`` decorated user-defined modules.
.. code-block:: python
......
......@@ -6,7 +6,7 @@ import abc
import copy
import json
from enum import Enum
from typing import (Any, Dict, List, Optional, Tuple, Union, overload)
from typing import (Any, Dict, Iterable, List, Optional, Tuple, Union, overload)
from .operation import Cell, Operation, _IOPseudoOperation
from .utils import get_full_class_name, import_, uid
......@@ -152,6 +152,14 @@ class Model:
}
return ret
def get_nodes(self) -> Iterable['Node']:
    """
    Yield every node contained in any graph of this model.

    Flattens the ``nodes`` collections of all graphs stored in the
    ``self.graphs`` mapping into a single iterator.
    """
    for graph in self.graphs.values():
        yield from graph.nodes
def get_nodes_by_label(self, label: str) -> List['Node']:
"""
Traverse all the nodes to find the matched node(s) with the given name.
......
......@@ -5,7 +5,7 @@ import warnings
import torch
import torch.nn as nn
from ...utils import uid, add_record, del_record
from ...utils import uid, add_record, del_record, Translatable
__all__ = ['LayerChoice', 'InputChoice', 'ValueChoice', 'Placeholder', 'ChosenInputs']
......@@ -130,6 +130,9 @@ class LayerChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.')
return x
def __repr__(self):
    # Mirror the constructor signature: candidates first, then the label (repr-quoted).
    return f'LayerChoice({self.candidates}, label={self.label!r})'
class InputChoice(nn.Module):
"""
......@@ -188,33 +191,66 @@ class InputChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.')
return candidate_inputs[0]
def __repr__(self):
    # Same four fields as before; parenthesized continuation instead of a
    # backslash, and !r conversions instead of explicit repr() calls.
    return (f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, '
            f'reduction={self.reduction!r}, label={self.label!r})')
class ValueChoice(nn.Module):
class ValueChoice(Translatable, nn.Module):
"""
ValueChoice is to choose one from ``candidates``.
Should initialize the values to choose from in init and call the module in forward to get the chosen value.
In most use scenarios, ValueChoice should be passed to the init parameters of a serializable module. For example,
.. code-block:: python
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))
def forward(self, x):
return self.conv(x)
A common use is to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``. For example,
In case, you want to search a parameter that is used repeatedly, this is also possible by sharing the same value choice instance.
(Sharing the label should have the same effect.) For example,
.. code-block:: python
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])
hidden_dim = nn.ValueChoice([128, 512])
self.fc = nn.Sequential(
nn.Linear(64, hidden_dim),
nn.Linear(hidden_dim, 10)
)
# the following code has the same effect.
# self.fc = nn.Sequential(
# nn.Linear(64, nn.ValueChoice([128, 512], label='dim')),
# nn.Linear(nn.ValueChoice([128, 512], label='dim'), 10)
# )
def forward(self, x):
return F.dropout(x, self.dropout_rate())
return self.fc(x)
The following use case is currently not supported because ValueChoice cannot be called in ``__init__``.
Please use LayerChoice as a workaround.
Note that ValueChoice should be used directly. Transformations like ``nn.Linear(32, nn.ValueChoice([64, 128]) * 2)``
are not supported.
Another common use case is to initialize the values to choose from in init and call the module in forward to get the chosen value.
Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``.
For example,
.. code-block:: python
# in __init__ code
self.kernel_size = nn.ValueChoice([3, 5])
self.conv = nn.Conv2d(3, self.out_channels, kernel_size=self.kernel_size())
class Net(nn.Module):
def __init__(self):
super().__init__()
self.dropout_rate = nn.ValueChoice([0., 1.])
def forward(self, x):
return F.dropout(x, self.dropout_rate())
Parameters
----------
......@@ -237,6 +273,13 @@ class ValueChoice(nn.Module):
warnings.warn('You should not run forward of this module directly.')
return self.candidates[0]
def _translate(self):
    """Stand in for a concrete value inside the serializer: the first candidate."""
    default_value = self.candidates[0]
    return default_value
def __repr__(self):
    # Candidate list followed by the repr-quoted label, matching the constructor.
    return f'ValueChoice({self.candidates}, label={self.label!r})'
class Placeholder(nn.Module):
# TODO: docstring
......
from typing import Any, List, Optional
from typing import Any, List, Optional, Tuple
from ...mutator import Mutator
from ...graph import Model, Node
from .api import ValueChoice
class LayerChoiceMutator(Mutator):
......@@ -48,6 +49,19 @@ class ValueChoiceMutator(Mutator):
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
class ParameterChoiceMutator(Mutator):
    """
    Mutator that samples one value from ``candidates`` and writes it into the
    named constructor parameter of every ``(node, argname)`` pair in ``nodes``.
    """

    def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
        super().__init__()
        self.nodes = nodes
        self.candidates = candidates

    def mutate(self, model):
        # A single sample is shared across all nodes, so choices carrying the
        # same label stay consistent with each other.
        chosen_value = self.choice(self.candidates)
        for node, argname in self.nodes:
            target = model.get_node_by_name(node.name)
            updated_params = dict(target.operation.parameters)
            updated_params[argname] = chosen_value
            target.update_operation(target.operation.type, updated_params)
def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
applied_mutators = []
......@@ -73,6 +87,18 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
applied_mutators.append(mutator)
pc_nodes = []
for node in model.get_nodes():
for name, choice in node.operation.parameters.items():
if isinstance(choice, ValueChoice):
pc_nodes.append((node, name))
pc_nodes = _group_parameters_by_label(pc_nodes)
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
mutator = ParameterChoiceMutator(node_list, node_list[0][0].operation.parameters[node_list[0][1]].candidates)
applied_mutators.append(mutator)
if applied_mutators:
return applied_mutators
return None
......@@ -95,3 +121,13 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result[label] = []
result[label].append(node)
return list(result.values())
def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
result = {}
for node, argname in nodes:
label = node.operation.parameters[argname].label
if label not in result:
result[label] = []
result[label].append((node, argname))
return list(result.values())
import abc
import functools
import inspect
from collections import defaultdict
......@@ -89,6 +90,17 @@ def del_record(key):
_records.pop(key, None)
class Translatable(abc.ABC):
    """
    Mixin for objects that must be replaced by a plain value when recorded by
    the serializer: subclass this and implement ``_translate`` when the inner
    class needs a different init parameter than the wrapper class received.
    """

    @abc.abstractmethod
    def _translate(self) -> Any:
        ...
def _blackbox_cls(cls):
class wrapper(cls):
def __init__(self, *args, **kwargs):
......@@ -100,6 +112,15 @@ def _blackbox_cls(cls):
for argname, value in zip(argname_list, args):
full_args[argname] = value
# translate parameters
args = list(args)
for i, value in enumerate(args):
if isinstance(value, Translatable):
args[i] = value._translate()
for i, value in kwargs.items():
if isinstance(value, Translatable):
kwargs[i] = value._translate()
add_record(id(self), full_args) # for compatibility. Will remove soon.
self.__init_parameters__ = full_args
......
......@@ -10,7 +10,7 @@ from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.nn.pytorch.mutator import process_inline_mutation
class EnuemrateSampler(Sampler):
class EnumerateSampler(Sampler):
def __init__(self):
self.index = 0
......@@ -70,7 +70,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
......@@ -94,7 +94,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
......@@ -119,7 +119,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net(reduction))
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model = mutator.apply(model)
result = self._get_converted_pytorch_model(model)(torch.randn(1, 3, 3, 3))
if reduction == 'none':
......@@ -144,7 +144,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(),
......@@ -152,6 +152,87 @@ class TestHighLevelAPI(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).size(),
torch.Size([1, 5, 3, 3]))
def test_value_choice_as_parameter(self):
    """A single ValueChoice used as a Conv2d kernel_size produces one mutator."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    mutated_first = mutator.apply(model)
    mutated_second = mutator.apply(model)
    # kernel_size=3 keeps a 5x5 input at 3x3; kernel_size=5 shrinks it to 1x1.
    self.assertEqual(self._get_converted_pytorch_model(mutated_first)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(mutated_second)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter_2(self):
    """
    NOTE(review): this was a second definition of ``test_value_choice_as_parameter``;
    Python silently rebinds the later ``def``, so the earlier test never ran.
    Renamed so every definition is discovered and executed by unittest.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, 5, kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    model1 = mutator.apply(model)
    model2 = mutator.apply(model)
    # kernel_size=3 -> 3x3 output on a 5x5 input; kernel_size=5 -> 1x1.
    self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 5, 1, 1]))
def test_value_choice_as_two_parameters(self):
    """
    Two independent ValueChoices (out_channels and kernel_size) must yield two
    mutators that compose.

    NOTE(review): this was a third definition of ``test_value_choice_as_parameter``,
    shadowing the earlier ones so only this body ever ran — renamed to a unique,
    descriptive name. Also renamed the local ``input`` to avoid shadowing the builtin.
    """
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv = nn.Conv2d(3, nn.ValueChoice([6, 8]), kernel_size=nn.ValueChoice([3, 5]))

        def forward(self, x):
            return self.conv(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 2)
    mutators[0].bind_sampler(EnumerateSampler())
    mutators[1].bind_sampler(EnumerateSampler())
    inp = torch.randn(1, 3, 5, 5)
    # First application: out_channels=6, kernel_size=3; second: 8 and 5.
    self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(inp).size(),
                     torch.Size([1, 6, 3, 3]))
    self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(inp).size(),
                     torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self):
    """Two ValueChoices with the same label collapse into a single shared mutator."""
    class Net(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)
            self.conv2 = nn.Conv2d(3, nn.ValueChoice([6, 8], label='shared'), 1)

        def forward(self, x):
            return self.conv1(x) + self.conv2(x)

    model = self._convert_to_ir(Net())
    mutators = process_inline_mutation(model)
    self.assertEqual(len(mutators), 1)
    mutator = mutators[0].bind_sampler(EnumerateSampler())
    sampled_first = mutator.apply(model)
    sampled_second = mutator.apply(model)
    # Both convs share one choice, so the summed outputs have 6 then 8 channels.
    self.assertEqual(self._get_converted_pytorch_model(sampled_first)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 6, 5, 5]))
    self.assertEqual(self._get_converted_pytorch_model(sampled_second)(torch.randn(1, 3, 5, 5)).size(),
                     torch.Size([1, 8, 5, 5]))
def test_value_choice_in_functional(self):
class Net(nn.Module):
def __init__(self):
......@@ -164,7 +245,7 @@ class TestHighLevelAPI(unittest.TestCase):
model = self._convert_to_ir(Net())
mutators = process_inline_mutation(model)
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnuemrateSampler())
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment