Unverified commit 53ae92cc authored by Yuge Zhang, committed by GitHub

Support hyper-parameter tuning in Retiarii (#4399)

parent ab22a5a5
@@ -3,9 +3,18 @@ Mutation Primitives

To make users easily express a model space within their PyTorch/TensorFlow model, NNI provides some inline mutation APIs as shown below.
.. note:: We are actively adding more mutation primitives. If you have any suggestions, feel free to `ask here <https://github.com/microsoft/nni/issues>`__.

``nn.LayerChoice``
""""""""""""""""""

API reference: :class:`nni.retiarii.nn.pytorch.LayerChoice`

It allows users to put several candidate operations (e.g., PyTorch modules); one of them is chosen in each explored model.

.. code-block:: python
   # import nni.retiarii.nn.pytorch as nn

   # declared in `__init__` method
@@ -17,9 +26,14 @@ To make users easily express a model space within their PyTorch/TensorFlow model
   # invoked in `forward` method
   out = self.layer(x)
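To make the semantics concrete, here is a schematic, NNI-free sketch of what ``nn.LayerChoice`` expresses: the model space holds several candidate layers, and each explored model fixes exactly one of them. All names below are illustrative stand-ins, not the NNI API.

```python
# Candidate "layers" as plain callables; in NNI these would be PyTorch modules.
candidates = [
    lambda x: x * 2,   # stand-in for e.g. a conv layer
    lambda x: x + 1,   # stand-in for e.g. a pooling layer
    lambda x: x,       # stand-in for nn.Identity()
]

def build_model(chosen_index):
    """Return the single candidate that this explored model uses."""
    return candidates[chosen_index]

# The search strategy picks one index per explored model; here we enumerate all.
outputs = [build_model(i)(10) for i in range(len(candidates))]
print(outputs)  # [20, 11, 10]
```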
``nn.InputChoice``
""""""""""""""""""

API reference: :class:`nni.retiarii.nn.pytorch.InputChoice`

It is mainly for choosing (or trying) different connections. It takes several tensors and chooses ``n_chosen`` tensors from them.

.. code-block:: python
   # import nni.retiarii.nn.pytorch as nn

   # declared in `__init__` method
@@ -27,25 +41,87 @@ To make users easily express a model space within their PyTorch/TensorFlow model
   # invoked in `forward` method, choose one from the three
   out = self.input_switch([tensor1, tensor2, tensor3])
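The semantics can be sketched without NNI: an explored model keeps ``n_chosen`` of the candidate inputs and reduces them (summation is one of the supported reductions). Plain floats stand in for tensors here; the function name is illustrative.

```python
def input_choice(candidate_inputs, chosen_indices):
    """Keep only the chosen candidates and reduce them by summation."""
    kept = [candidate_inputs[i] for i in chosen_indices]
    return sum(kept)

tensors = [1.0, 10.0, 100.0]          # stand-ins for real tensors
out = input_choice(tensors, [0, 2])   # a model where inputs 0 and 2 are chosen
print(out)  # 101.0
```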
``nn.ValueChoice``
""""""""""""""""""

API reference: :class:`nni.retiarii.nn.pytorch.ValueChoice`

It is for choosing one value from some candidate values. The most common use cases are:

* as an input argument of `basic units <LINK_TBD>`__, i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``;
* as an input argument of an evaluator (*new in v2.7*).

Examples are as follows:

.. code-block:: python
   # import nni.retiarii.nn.pytorch as nn

   # used in `__init__` method
   self.conv = nn.Conv2d(XX, XX, kernel_size=nn.ValueChoice([1, 3, 5]))
   self.op = MyOp(nn.ValueChoice([0, 1]), nn.ValueChoice([-1, 1]))

   # used in evaluator
   def train_and_evaluate(model_cls, learning_rate):
       ...

   self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))
.. tip::
   All the APIs have an optional argument called ``label``; mutations with the same label will share the same choice. A typical example is,
   .. code-block:: python

      self.net = nn.Sequential(
          nn.Linear(10, nn.ValueChoice([32, 64, 128], label='hidden_dim')),
          nn.Linear(nn.ValueChoice([32, 64, 128], label='hidden_dim'), 3)
      )
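Label sharing can be sketched in plain Python: every choice carrying the same label is resolved by a single sampled value, so the two ``Linear`` layers above always agree on ``hidden_dim``. The helper and dictionary layout below are hypothetical, not NNI internals.

```python
def resolve(space, sample):
    """space: name -> (label, candidates); sample: label -> chosen index."""
    return {name: cands[sample[label]] for name, (label, cands) in space.items()}

space = {
    'linear1_out': ('hidden_dim', [32, 64, 128]),
    'linear2_in':  ('hidden_dim', [32, 64, 128]),  # same label => same choice
}
resolved = resolve(space, {'hidden_dim': 1})
print(resolved)  # {'linear1_out': 64, 'linear2_in': 64}
```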
.. warning::

   It looks as if a specific candidate has been chosen (e.g., the way you can put a ``ValueChoice`` as a parameter of ``nn.Conv2d``), but in fact it is syntactic sugar: the basic units and evaluators do all the underlying work. That means you cannot assume that a ``ValueChoice`` can be used in the same way as its candidates. For example, the following usage will NOT work:

   .. code-block:: python

      self.blocks = []
      for i in range(nn.ValueChoice([1, 2, 3])):
          self.blocks.append(Block())

      # NOTE: instead you should probably write
      # self.blocks = nn.Repeat(Block(), (1, 3))
``nn.Repeat``
"""""""""""""

API reference: :class:`nni.retiarii.nn.pytorch.Repeat`

Repeat a block a variable number of times.

.. code-block:: python

   # import nni.retiarii.nn.pytorch as nn

   # used in `__init__` method

   # Block() will be deep copied and repeated 3 times
   self.blocks = nn.Repeat(Block(), 3)

   # Block() will be repeated 1, 2, or 3 times
   self.blocks = nn.Repeat(Block(), (1, 3))

   # FIXME
   # The following use cases have known issues and will be fixed in the current release.

   # Can be used together with layer choice.
   # With deep copy, the 3 layers will have the same label and thus share the choice.
   self.blocks = nn.Repeat(nn.LayerChoice([...]), (1, 3))

   # To make the three layer choices independent,
   # use a factory function that accepts an index (0, 1, 2, ...) and returns the module of the `index`-th layer.
   self.blocks = nn.Repeat(lambda index: nn.LayerChoice([...], label=f'layer{index}'), (1, 3))
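The deep-copy-vs-factory distinction can be illustrated with the standard library alone: deep copying an object carries its original label along, so all copies resolve to one shared choice, while a factory can stamp a fresh label per repetition. ``Choice`` here is a hypothetical stand-in, not an NNI class.

```python
import copy

class Choice:
    """Stand-in for a labeled mutation primitive."""
    def __init__(self, label):
        self.label = label

block = Choice(label='layer')
copies = [copy.deepcopy(block) for _ in range(3)]
print({c.label for c in copies})  # {'layer'} -> one shared choice

# A factory function gives each repetition its own label instead:
independent = [Choice(label=f'layer{i}') for i in range(3)]
print([c.label for c in independent])  # ['layer0', 'layer1', 'layer2']
```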
``nn.Cell``
"""""""""""

API reference: :class:`nni.retiarii.nn.pytorch.Cell`

This cell structure is popularly used in the `NAS literature <https://arxiv.org/abs/1611.01578>`__. Specifically, the cell consists of multiple "nodes". Each node is a sum of multiple operators. Each operator is chosen from user-specified candidates, and takes one input from previous nodes and predecessors ("predecessor" means an input of the cell). The output of the cell is the concatenation of some of the nodes in the cell (currently all of them).
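The dataflow described above can be sketched schematically in plain Python: each node sums the outputs of its chosen operators, each operator reads one earlier node or the predecessor, and the cell output "concatenates" all nodes. The operator set and spec format are illustrative simplifications, not the NNI API.

```python
# Toy operator candidates; in NNI these would be modules.
ops = {'double': lambda x: x * 2, 'inc': lambda x: x + 1}

def run_cell(predecessor, node_specs):
    """node_specs: one list of (op_name, input_index) pairs per node.
    Index 0 refers to the cell input; index k refers to node k."""
    values = [predecessor]
    for spec in node_specs:
        values.append(sum(ops[name](values[idx]) for name, idx in spec))
    return values[1:]  # "concatenation" of all nodes

out = run_cell(3, [[('double', 0)], [('inc', 0), ('double', 1)]])
print(out)  # node1 = 3*2 = 6; node2 = (3+1) + (6*2) = 16
```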
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import nni

from nni.retiarii.graph import Evaluator


@nni.trace
class FunctionalEvaluator(Evaluator):
    """
    Functional evaluator that directly takes a function and thus should be general.
...
@@ -34,7 +34,7 @@ from ..execution.utils import get_mutation_dict
from ..graph import Evaluator
from ..integration import RetiariiAdvisor
from ..mutator import Mutator
from ..nn.pytorch.mutator import extract_mutation_from_pt_module, process_inline_mutation, process_evaluator_mutations
from ..oneshot.interface import BaseOneShotTrainer
from ..serializer import is_model_wrapped
from ..strategy import BaseStrategy
@@ -200,6 +200,7 @@ class RetiariiExperiment(Experiment):
            full_ir=self.config.execution_engine not in ['py', 'benchmark'],
            dummy_input=self.config.dummy_input
        )
        self.applied_mutators += process_evaluator_mutations(self.trainer, self.applied_mutators)

        _logger.info('Start strategy...')
        search_space = dry_run_for_formatted_search_space(base_model_ir, self.applied_mutators)
...
@@ -2,11 +2,13 @@
# Licensed under the MIT license.

import inspect
from collections import defaultdict
from typing import Any, List, Optional, Tuple

import torch.nn as nn

from nni.common.serializer import is_traceable
from nni.retiarii.graph import Cell, Graph, Model, ModelStatus, Node, Evaluator
from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid
@@ -282,6 +284,51 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
    return model, mutators + mutators_final

# mutations for evaluator

class EvaluatorValueChoiceMutator(Mutator):
    def __init__(self, keys: List[str], label: Optional[str]):
        self.keys = keys
        super().__init__(label=label)

    def mutate(self, model: Model):
        # make a copy to mutate the evaluator
        model.evaluator = model.evaluator.trace_copy()
        chosen = None
        for i, key in enumerate(self.keys):
            value_choice: ValueChoice = model.evaluator.trace_kwargs[key]
            if i == 0:
                # i == 0 is needed here because there can be candidates of "None"
                chosen = self.choice(value_choice.candidates)
            # get the real chosen value after "access"
            model.evaluator.trace_kwargs[key] = value_choice.access(chosen)
        return model

def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
    # collect all the value choices in the kwargs of the evaluator
    if not is_traceable(evaluator):
        return []
    mutator_candidates = {}
    mutator_keys = defaultdict(list)
    for key, param in evaluator.trace_kwargs.items():
        if isinstance(param, ValueChoice):
            # merge duplicate labels
            for mutator in existing_mutators:
                if mutator.name == param.label:
                    raise ValueError(f'Found duplicated labels for mutators {param.label}. When two mutators have the same name, '
                                     'they would share choices. However, sharing choices between model and evaluator is not yet supported.')
            if param.label in mutator_candidates and mutator_candidates[param.label] != param.candidates:
                raise ValueError(f'Duplicate labels for evaluator ValueChoice {param.label}. They should share choices. '
                                 f'But their candidate lists are not equal: {mutator_candidates[param.label]} vs. {param.candidates}')
            mutator_keys[param.label].append(key)
            mutator_candidates[param.label] = param.candidates
    mutators = []
    for label in mutator_keys:
        mutators.append(EvaluatorValueChoiceMutator(mutator_keys[label], label))
    return mutators

# utility functions
...
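The grouping logic of `process_evaluator_mutations` above can be distilled into a stdlib-only sketch: scan the evaluator's keyword arguments for value choices, group the argument names by label (shared labels must mutate together and share candidates), and emit one mutator per label. The `ValueChoice` class and `group_choices` helper here are simplified stand-ins for illustration.

```python
from collections import defaultdict

class ValueChoice:
    """Minimal stand-in: a candidate list plus an optional sharing label."""
    def __init__(self, candidates, label=None):
        self.candidates = candidates
        self.label = label if label is not None else f'choice_{id(self)}'

def group_choices(kwargs):
    keys_by_label = defaultdict(list)
    candidates_by_label = {}
    for key, param in kwargs.items():
        if isinstance(param, ValueChoice):
            if param.label in candidates_by_label and candidates_by_label[param.label] != param.candidates:
                raise ValueError('choices with the same label must share candidates')
            keys_by_label[param.label].append(key)
            candidates_by_label[param.label] = param.candidates
    return dict(keys_by_label)

groups = group_choices({'t': 1,
                        'x': ValueChoice([1, 2], label='lr'),
                        'y': ValueChoice([1, 2], label='lr')})
print(groups)  # {'lr': ['x', 'y']} -> one mutator drives both kwargs
```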
@@ -8,8 +8,11 @@ import torch.nn.functional as F
from nni.retiarii import InvalidMutation, Sampler, basic_unit
from nni.retiarii.converter import convert_to_graph
from nni.retiarii.codegen import model_to_pytorch_script
from nni.retiarii.evaluator import FunctionalEvaluator
from nni.retiarii.execution.utils import _unpack_if_only_one
from nni.retiarii.graph import Model
from nni.retiarii.nn.pytorch.api import ValueChoice
from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations, process_inline_mutation, extract_mutation_from_pt_module
from nni.retiarii.serializer import model_wrapper
from nni.retiarii.utils import ContextStack
@@ -64,14 +67,8 @@ class GraphIR(unittest.TestCase):
        mutators = process_inline_mutation(model)
        return model, mutators

    def test_layer_choice(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -94,7 +91,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 3, 3]))

    def test_layer_choice_multiple(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -112,7 +109,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, i, 3, 3]))

    def test_nested_layer_choice(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -139,7 +136,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 5, 5]))

    def test_input_choice(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -163,7 +160,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 3, 3]))

    def test_chosen_inputs(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self, reduction):
                super().__init__()
@@ -192,7 +189,7 @@ class GraphIR(unittest.TestCase):
        self.assertEqual(result.size(), torch.Size([1, 3, 3, 3]))

    def test_value_choice(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -213,7 +210,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 3, 3]))

    def test_value_choice_as_parameter(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -233,7 +230,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 1, 1]))

    def test_value_choice_as_parameter(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -253,7 +250,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 5, 1, 1]))

    def test_value_choice_as_parameter(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -273,7 +270,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 8, 1, 1]))

    def test_value_choice_as_parameter_shared(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -294,7 +291,7 @@ class GraphIR(unittest.TestCase):
                         torch.Size([1, 8, 5, 5]))

    def test_value_choice_in_functional(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -313,7 +310,7 @@ class GraphIR(unittest.TestCase):
        self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

    def test_value_choice_in_layer_choice(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -337,7 +334,7 @@ class GraphIR(unittest.TestCase):
        self.assertEqual(len(sz_counter), 4)

    def test_shared(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self, shared=True):
                super().__init__()
@@ -379,7 +376,7 @@ class GraphIR(unittest.TestCase):
        self.assertLess(failed_count, 30)

    def test_valuechoice_access(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -398,7 +395,7 @@ class GraphIR(unittest.TestCase):
        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
                         torch.Size([1, 8, 1, 1]))

        @model_wrapper
        class Net2(nn.Module):
            def __init__(self):
                super().__init__()
@@ -420,7 +417,7 @@ class GraphIR(unittest.TestCase):
            self._get_converted_pytorch_model(mutators[0].apply(model))(input)

    def test_valuechoice_access_functional(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -439,7 +436,7 @@ class GraphIR(unittest.TestCase):
        self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

    def test_valuechoice_access_functional_expression(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -464,7 +461,7 @@ class GraphIR(unittest.TestCase):
            def forward(self, x):
                return x + 1

        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -532,7 +529,7 @@ class GraphIR(unittest.TestCase):
        self.assertIn(1., result)

    def test_cell(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -551,7 +548,7 @@ class GraphIR(unittest.TestCase):
        self.assertTrue(self._get_converted_pytorch_model(model)(
            torch.randn(1, 16), torch.randn(1, 16)).size() == torch.Size([1, 64]))

        @model_wrapper
        class Net2(nn.Module):
            def __init__(self):
                super().__init__()
@@ -569,7 +566,7 @@ class GraphIR(unittest.TestCase):
        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(1, 16)).size() == torch.Size([1, 64]))

    def test_nasbench201_cell(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -590,7 +587,7 @@ class GraphIR(unittest.TestCase):
        self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))

    def test_autoactivation(self):
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -618,9 +615,6 @@ class Python(GraphIR):
    def _get_model_with_mutators(self, pytorch_model):
        return extract_mutation_from_pt_module(pytorch_model)

    @unittest.skip
    def test_value_choice(self): ...
@@ -635,7 +629,7 @@ class Python(GraphIR):
    def test_nasbench101_cell(self):
        # this is only supported in python engine for now.
        @model_wrapper
        class Net(nn.Module):
            def __init__(self):
                super().__init__()
@@ -658,3 +652,42 @@ class Python(GraphIR):
            except InvalidMutation:
                continue
            self.assertTrue(self._get_converted_pytorch_model(model)(torch.randn(2, 10)).size() == torch.Size([2, 16]))

class Shared(unittest.TestCase):
    # these tests are general across execution engines

    def test_valuechoice_in_evaluator(self):
        def foo():
            pass

        evaluator = FunctionalEvaluator(foo, t=1, x=2)
        assert process_evaluator_mutations(evaluator, []) == []

        evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([1, 2]), y=ValueChoice([3, 4]))
        mutators = process_evaluator_mutations(evaluator, [])
        assert len(mutators) == 2

        init_model = Model(_internal=True)
        init_model.evaluator = evaluator
        sampler = EnumerateSampler()
        model = mutators[0].bind_sampler(sampler).apply(init_model)
        assert model.evaluator.trace_kwargs['x'] == 1
        model = mutators[0].bind_sampler(sampler).apply(init_model)
        assert model.evaluator.trace_kwargs['x'] == 2

        # share label
        evaluator = FunctionalEvaluator(foo, t=ValueChoice([1, 2], label='x'), x=ValueChoice([1, 2], label='x'))
        mutators = process_evaluator_mutations(evaluator, [])
        assert len(mutators) == 1

        # getitem
        choice = ValueChoice([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
        evaluator = FunctionalEvaluator(foo, t=1, x=choice['a'], y=choice['b'])
        mutators = process_evaluator_mutations(evaluator, [])
        assert len(mutators) == 1
        init_model = Model(_internal=True)
        init_model.evaluator = evaluator
        sampler = RandomSampler()
        for _ in range(10):
            model = mutators[0].bind_sampler(sampler).apply(init_model)
            assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]