".github/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "6cef7d2366c05a72f6b1e034e9260636d1eccd8d"
Unverified commit a36dc07e, authored by Yuge Zhang, committed by GitHub

Composition of `ValueChoice` (#4435)

parent f8327ba0
@@ -68,6 +68,44 @@ Examples are as follows:
         self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))

+Value choices support arithmetic operators, which is particularly useful when searching for a network width multiplier:
+
+.. code-block:: python
+
+  # init
+  scale = nn.ValueChoice([1.0, 1.5, 2.0])
+  self.conv1 = nn.Conv2d(3, round(scale * 16), kernel_size=1)
+  self.conv2 = nn.Conv2d(round(scale * 16), round(scale * 64), kernel_size=1)
+  self.conv3 = nn.Conv2d(round(scale * 64), round(scale * 256), kernel_size=1)
+
+  # forward
+  return self.conv3(self.conv2(self.conv1(x)))
+
+Or when kernel size and padding are coupled, so as to keep the output size constant:
+
+.. code-block:: python
+
+  # init
+  ks = nn.ValueChoice([3, 5, 7])
+  self.conv = nn.Conv2d(3, 16, kernel_size=ks, padding=(ks - 1) // 2)
+
+  # forward
+  return self.conv(x)
+
+Or when the outputs of several layers are concatenated before a final layer:
+
+.. code-block:: python
+
+  # init
+  self.linear1 = nn.Linear(3, nn.ValueChoice([1, 2, 3], label='a'))
+  self.linear2 = nn.Linear(3, nn.ValueChoice([4, 5, 6], label='b'))
+  self.final = nn.Linear(nn.ValueChoice([1, 2, 3], label='a') + nn.ValueChoice([4, 5, 6], label='b'), 2)
+
+  # forward
+  return self.final(torch.cat([self.linear1(x), self.linear2(x)], 1))
+
+Some advanced operators are also provided, such as ``nn.ValueChoice.max`` and ``nn.ValueChoice.condition``. See the reference of :class:`nni.retiarii.nn.pytorch.ValueChoice` for more details.
 .. tip::
     All the APIs have an optional argument called ``label``; mutations with the same label will share the same choice. A typical example is,
......
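As a quick illustration of the advanced operators mentioned above: ``max`` and ``condition`` build symbolic expressions, just like the arithmetic operators, and are only resolved once leaf decisions are supplied. A minimal sketch, mirroring the assertions in the tests added later in this commit:

.. code-block:: python

  import nni.retiarii.nn.pytorch as nn

  a = nn.ValueChoice([1, 3], label='a')
  b = nn.ValueChoice([2, 4], label='b')

  # symbolic max: resolved only when leaf values are given (here a=3, b=2)
  assert nn.ValueChoice.max(a, b).evaluate([3, 2]) == 3

  # symbolic if-else: condition(pred, true_value, false_value)
  assert nn.ValueChoice.condition(a <= 2, 'small', 'large').evaluate([3]) == 'large'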
@@ -598,7 +598,6 @@ class GraphConverter:
         return {
             'candidates': module.candidates,
             'label': module.label,
-            'accessor': module._accessor
         }

     def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
......
@@ -119,7 +119,7 @@ class Model:
         self.graphs: Dict[str, Graph] = {}
         self.evaluator: Optional[Evaluator] = None
-        self.history: List['Model'] = []
+        self.history: List['Mutation'] = []
        self.metric: Optional[MetricData] = None
         self.intermediate_metrics: List[MetricData] = []
......
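One consequence of the ``history`` retyping above: downstream code, including the new mutators in this diff, reads decisions back from the history as records carrying the mutator that acted and the samples it drew. A sketch of that consumption pattern, with field names taken from how this diff itself uses them (this is an inferred shape, not the ``Mutation`` class definition):

.. code-block:: python

  def collect_decisions(model):
      # each Mutation record exposes `.mutator` and `.samples`;
      # this is how ParameterChoiceMutator recovers leaf decisions later in this diff
      return {mutation.mutator.label: mutation.samples[0] for mutation in model.history}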
@@ -3,7 +3,7 @@
 import inspect
 from collections import defaultdict
-from typing import Any, List, Optional, Tuple
+from typing import Any, List, Optional, Tuple, Dict

 import torch.nn as nn
@@ -13,7 +13,7 @@ from nni.retiarii.mutator import Mutator
 from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
 from nni.retiarii.utils import uid
-from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
+from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
 from .component import Repeat, NasBench101Cell, NasBench101Mutator
@@ -65,30 +65,66 @@ class InputChoiceMutator(Mutator):
 class ValueChoiceMutator(Mutator):
     def __init__(self, nodes: List[Node], candidates: List[Any]):
+        # use nodes[0] as an example to get the label
         super().__init__(label=nodes[0].operation.parameters['label'])
         self.nodes = nodes
         self.candidates = candidates

     def mutate(self, model):
         chosen = self.choice(self.candidates)
+        # no need to support transformation here,
+        # because it is naturally done in the forward loop
         for node in self.nodes:
             target = model.get_node_by_name(node.name)
             target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})


+class ParameterChoiceLeafMutator(Mutator):
+    # mutates the leaf nodes (i.e., ValueChoice) of parameter choices;
+    # should be used together with ParameterChoiceMutator
+
+    def __init__(self, candidates: List[Any], label: str):
+        super().__init__(label=label)
+        self.candidates = candidates
+
+    def mutate(self, model: Model) -> None:
+        # only leave a record here;
+        # real mutations will be done in ParameterChoiceMutator
+        self.choice(self.candidates)
+
+
 class ParameterChoiceMutator(Mutator):
-    def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
-        node, argname = nodes[0]
-        super().__init__(label=node.operation.parameters[argname].label)
+    # deals with ValueChoice used as a parameter of a basic unit;
+    # should be used together with ParameterChoiceLeafMutator.
+    # ParameterChoiceMutator is an empty-shell mutator:
+    # it calculates all the parameter values based on previous mutations of the leaf mutators.
+
+    def __init__(self, nodes: List[Tuple[Node, str]]):
+        super().__init__()
         self.nodes = nodes
-        self.candidates = candidates

-    def mutate(self, model):
-        chosen = self.choice(self.candidates)
+    def mutate(self, model: Model) -> None:
+        # value_choice_decisions looks like {"label1": "cat", "label2": 123}
+        value_choice_decisions = {}
+        for mutation in model.history:
+            if isinstance(mutation.mutator, ParameterChoiceLeafMutator):
+                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
+
         for node, argname in self.nodes:
-            chosen_value = node.operation.parameters[argname].access(chosen)
+            # argname is the location of the argument,
+            # e.g., Conv2d(out_channels=nn.ValueChoice([1, 2, 3])) => argname = "out_channels"
+            value_choice: ValueChoiceX = node.operation.parameters[argname]
+
+            # calculate all the values on the leaf nodes of the ValueChoiceX computation graph
+            leaf_node_values = []
+            for choice in value_choice.inner_choices():
+                leaf_node_values.append(value_choice_decisions[choice.label])
+            result_value = value_choice.evaluate(leaf_node_values)
+
+            # update the model with graph mutation primitives
             target = model.get_node_by_name(node.name)
-            target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
+            target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})


 class RepeatMutator(Mutator):
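The two classes above implement a record-then-apply split: each ``ParameterChoiceLeafMutator`` samples one leaf ``ValueChoice`` and merely records its decision in ``model.history``, while the single ``ParameterChoiceMutator`` folds all recorded decisions into every composed argument through ``inner_choices()`` and ``evaluate()``. A rough sketch of that folding step, using a hypothetical decision dict in place of the history:

.. code-block:: python

  import nni.retiarii.nn.pytorch as nn

  # a composed argument: padding stays symbolic (a ValueChoiceX expression)
  ks = nn.ValueChoice([3, 5, 7], label='ks')
  padding = (ks - 1) // 2

  # hypothetical decisions recorded by the leaf mutators: {label: chosen value}
  decisions = {'ks': 5}

  # what ParameterChoiceMutator does for each (node, argname) pair
  leaf_values = [decisions[c.label] for c in padding.inner_choices()]
  assert padding.evaluate(leaf_values) == 2  # (5 - 1) // 2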
@@ -145,18 +181,31 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
         mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
         applied_mutators.append(mutator)

-    pc_nodes = []
+    # `pc_nodes` are arguments of basic units. They can be compositions.
+    pc_nodes: List[Tuple[Node, str, ValueChoiceX]] = []
     for node in model.get_nodes():
         for name, choice in node.operation.parameters.items():
-            if isinstance(choice, ValueChoice):
-                pc_nodes.append((node, name))
-    pc_nodes = _group_parameters_by_label(pc_nodes)
-    for node_list in pc_nodes:
-        assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
-            'Value choice with the same label must have the same candidates.'
-        first_node, first_argname = node_list[0]
-        mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
-        applied_mutators.append(mutator)
+            if isinstance(choice, ValueChoiceX):
+                # e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
+                pc_nodes.append((node, name, choice))
+
+    # Break `pc_nodes` down to leaf value choices. They are what we actually sample.
+    leaf_value_choices: Dict[str, List[Any]] = {}
+    for _, __, choice in pc_nodes:
+        for inner_choice in choice.inner_choices():
+            if inner_choice.label not in leaf_value_choices:
+                leaf_value_choices[inner_choice.label] = inner_choice.candidates
+            else:
+                assert leaf_value_choices[inner_choice.label] == inner_choice.candidates, \
+                    'Value choices with the same label must have the same candidates, but found ' \
+                    f'{leaf_value_choices[inner_choice.label]} vs. {inner_choice.candidates}'
+
+    for label, candidates in leaf_value_choices.items():
+        applied_mutators.append(ParameterChoiceLeafMutator(candidates, label))
+
+    # in the end, add one more parameter choice mutator for the "real" mutations
+    if pc_nodes:
+        applied_mutators.append(ParameterChoiceMutator([(node, name) for node, name, _ in pc_nodes]))

     # apply layer choice at last as it will delete some nodes
     lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
@@ -236,9 +285,10 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
         # tricky case: value choices that serve as parameters are stored in traced arguments
         if is_basic_unit(module):
             for key, value in module.trace_kwargs.items():
-                if isinstance(value, ValueChoice):
-                    node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
-                    node.label = value.label
+                if isinstance(value, ValueChoiceX):
+                    for i, choice in enumerate(value.inner_choices()):
+                        node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
+                        node.label = choice.label

         if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
             # TODO: check the label of module and warn if it's auto-generated
@@ -286,46 +336,76 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
 # mutations for evaluator

-class EvaluatorValueChoiceMutator(Mutator):
-    def __init__(self, keys: List[str], label: Optional[str]):
-        self.keys = keys
+class EvaluatorValueChoiceLeafMutator(Mutator):
+    # see `ParameterChoiceLeafMutator`;
+    # works in the same way
+
+    def __init__(self, candidates: List[Any], label: str):
         super().__init__(label=label)
+        self.candidates = candidates
+
+    def mutate(self, model: Model) -> None:
+        # only leave a record here;
+        # real mutations will be done in EvaluatorValueChoiceMutator
+        self.choice(self.candidates)
+
+
+class EvaluatorValueChoiceMutator(Mutator):
+    # works in the same way as `ParameterChoiceMutator`;
+    # we only need one such mutator per model/evaluator

     def mutate(self, model: Model):
         # make a copy to mutate the evaluator
         model.evaluator = model.evaluator.trace_copy()
-        chosen = None
-        for i, key in enumerate(self.keys):
-            value_choice: ValueChoice = model.evaluator.trace_kwargs[key]
-            if i == 0:
-                # i == 0 is needed here because there can be candidates of "None"
-                chosen = self.choice(value_choice.candidates)
-            # get the real chosen value after "access"
-            model.evaluator.trace_kwargs[key] = value_choice.access(chosen)
-        return model
+
+        value_choice_decisions = {}
+        for mutation in model.history:
+            if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
+                value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
+
+        result = {}
+        # for each argument that is a composition of value choices,
+        # find all the leaf value choices in the mutation history
+        # and compute the final result
+        for key, param in model.evaluator.trace_kwargs.items():
+            if isinstance(param, ValueChoiceX):
+                leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
+                result[key] = param.evaluate(leaf_node_values)
+
+        model.evaluator.trace_kwargs.update(result)


 def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
     # take all the value choices in the kwargs of the evaluator into a list
+    # `existing_mutators` can be mutators generated from `model`
     if not is_traceable(evaluator):
         return []
     mutator_candidates = {}
     mutator_keys = defaultdict(list)
     for key, param in evaluator.trace_kwargs.items():
-        if isinstance(param, ValueChoice):
-            # merge duplicate labels
-            for mutator in existing_mutators:
-                if mutator.name == param.label:
-                    raise ValueError(f'Found duplicated labels for mutators {param.label}. When two mutators have the same name, '
-                                     'they would share choices. However, sharing choices between model and evaluator is not yet supported.')
-            if param.label in mutator_candidates and mutator_candidates[param.label] != param.candidates:
-                raise ValueError(f'Duplicate labels for evaluator ValueChoice {param.label}. They should share choices.'
-                                 f'But their candidate list is not equal: {mutator_candidates[param.label][1]} vs. {param.candidates}')
-            mutator_keys[param.label].append(key)
-            mutator_candidates[param.label] = param.candidates
+        if isinstance(param, ValueChoiceX):
+            for choice in param.inner_choices():
+                # merge duplicate labels
+                for mutator in existing_mutators:
+                    if mutator.name == choice.label:
+                        raise ValueError(
+                            f'Found duplicated label "{choice.label}". When two value choices have the same label, '
+                            'they would share choices. However, sharing choices between model and evaluator is not yet supported.'
+                        )
+                if choice.label in mutator_candidates and mutator_candidates[choice.label] != choice.candidates:
+                    raise ValueError(
+                        f'Duplicate labels for evaluator ValueChoice {choice.label}. They should share choices, '
+                        f'but their candidate lists are not equal: {mutator_candidates[choice.label]} vs. {choice.candidates}'
+                    )
+                mutator_keys[choice.label].append(key)
+                mutator_candidates[choice.label] = choice.candidates
     mutators = []
-    for key in mutator_keys:
-        mutators.append(EvaluatorValueChoiceMutator(mutator_keys[key], key))
+    for label in mutator_keys:
+        mutators.append(EvaluatorValueChoiceLeafMutator(mutator_candidates[label], label))
+    if mutators:
+        # one last mutator to actually apply the mutations
+        mutators.append(EvaluatorValueChoiceMutator())
     return mutators
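The evaluator path mirrors the parameter path one-to-one: ``process_evaluator_mutations`` emits one leaf mutator per distinct label, plus a single ``EvaluatorValueChoiceMutator`` at the end that recomputes every composed ``trace_kwargs`` entry. A hedged end-to-end sketch (module paths are assumed from the imports in this diff; ``foo`` is a stand-in training function):

.. code-block:: python

  from nni.retiarii.evaluator import FunctionalEvaluator
  from nni.retiarii.nn.pytorch import ValueChoice
  from nni.retiarii.nn.pytorch.mutator import process_evaluator_mutations

  def foo():  # stand-in training function
      pass

  lr = ValueChoice([0.1, 0.01], label='lr')
  # a composed argument: resolved only by the final EvaluatorValueChoiceMutator
  evaluator = FunctionalEvaluator(foo, learning_rate=lr * 10)

  mutators = process_evaluator_mutations(evaluator, [])
  assert len(mutators) == 2  # one leaf mutator for 'lr' + the final applier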
@@ -359,13 +439,3 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
             result[label] = []
         result[label].append(node)
     return list(result.values())
-
-
-def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
-    result = {}
-    for node, argname in nodes:
-        label = node.operation.parameters[argname].label
-        if label not in result:
-            result[label] = []
-        result[label].append((node, argname))
-    return list(result.values())
+import math
 import random
 import unittest
 from collections import Counter

+import pytest
+
 import nni.retiarii.nn.pytorch as nn
 import torch
 import torch.nn.functional as F
@@ -50,7 +53,19 @@ class MutableConv(nn.Module):
         return self.conv2(x)


+def _apply_all_mutators(model, mutators, samplers):
+    if not isinstance(samplers, list):
+        samplers = [samplers for _ in range(len(mutators))]
+    assert len(samplers) == len(mutators)
+    model_new = model
+    for mutator, sampler in zip(mutators, samplers):
+        model_new = mutator.bind_sampler(sampler).apply(model_new)
+    return model_new
+
+
 class GraphIR(unittest.TestCase):
+    # the graph engine will have an extra mutator for parameter choices
+    value_choice_incr = 1
+
     def _convert_to_ir(self, model):
         script_module = torch.jit.script(model)
@@ -220,7 +235,7 @@ class GraphIR(unittest.TestCase):
             return self.conv(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 1)
+        self.assertEqual(len(mutators), 1 + self.value_choice_incr)
         mutator = mutators[0].bind_sampler(EnumerateSampler())
         model1 = mutator.apply(model)
         model2 = mutator.apply(model)
@@ -240,16 +255,16 @@ class GraphIR(unittest.TestCase):
             return self.conv(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 1)
-        mutator = mutators[0].bind_sampler(EnumerateSampler())
-        model1 = mutator.apply(model)
-        model2 = mutator.apply(model)
+        self.assertEqual(len(mutators), self.value_choice_incr + 1)
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]
+        model1 = _apply_all_mutators(model, mutators, samplers)
+        model2 = _apply_all_mutators(model, mutators, samplers)
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                          torch.Size([1, 5, 3, 3]))
         self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
                          torch.Size([1, 5, 1, 1]))

-    def test_value_choice_as_parameter(self):
+    def test_value_choice_as_two_parameters(self):
         @model_wrapper
         class Net(nn.Module):
             def __init__(self):
@@ -260,13 +275,14 @@ class GraphIR(unittest.TestCase):
             return self.conv(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 2)
-        mutators[0].bind_sampler(EnumerateSampler())
-        mutators[1].bind_sampler(EnumerateSampler())
+        self.assertEqual(len(mutators), 2 + self.value_choice_incr)
+        samplers = [EnumerateSampler() for _ in range(len(mutators))]
+        model1 = _apply_all_mutators(model, mutators, samplers)
+        model2 = _apply_all_mutators(model, mutators, samplers)
         input = torch.randn(1, 3, 5, 5)
-        self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
+        self.assertEqual(self._get_converted_pytorch_model(model1)(input).size(),
                          torch.Size([1, 6, 3, 3]))
-        self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
+        self.assertEqual(self._get_converted_pytorch_model(model2)(input).size(),
                          torch.Size([1, 8, 1, 1]))

     def test_value_choice_as_parameter_shared(self):
@@ -281,10 +297,10 @@ class GraphIR(unittest.TestCase):
             return self.conv1(x) + self.conv2(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 1)
-        mutator = mutators[0].bind_sampler(EnumerateSampler())
-        model1 = mutator.apply(model)
-        model2 = mutator.apply(model)
+        self.assertEqual(len(mutators), 1 + self.value_choice_incr)
+        sampler = EnumerateSampler()
+        model1 = _apply_all_mutators(model, mutators, sampler)
+        model2 = _apply_all_mutators(model, mutators, sampler)
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
                          torch.Size([1, 6, 5, 5]))
         self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
@@ -323,13 +339,11 @@ class GraphIR(unittest.TestCase):
             return self.linear(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 3)
+        self.assertEqual(len(mutators), 3 + self.value_choice_incr)
         sz_counter = Counter()
         sampler = RandomSampler()
         for i in range(100):
-            model_new = model
-            for mutator in mutators:
-                model_new = mutator.bind_sampler(sampler).apply(model_new)
+            model_new = _apply_all_mutators(model, mutators, sampler)
             sz_counter[self._get_converted_pytorch_model(model_new)(torch.randn(1, 3)).size(1)] += 1
         self.assertEqual(len(sz_counter), 4)
@@ -375,7 +389,7 @@ class GraphIR(unittest.TestCase):
         self.assertGreater(failed_count, 0)
         self.assertLess(failed_count, 30)

-    def test_valuechoice_access(self):
+    def test_valuechoice_getitem(self):
         @model_wrapper
         class Net(nn.Module):
             def __init__(self):
@@ -387,12 +401,12 @@ class GraphIR(unittest.TestCase):
             return self.conv(x)

         model, mutators = self._get_model_with_mutators(Net())
-        self.assertEqual(len(mutators), 1)
-        mutators[0].bind_sampler(EnumerateSampler())
+        self.assertEqual(len(mutators), 1 + self.value_choice_incr)
+        sampler = EnumerateSampler()
         input = torch.randn(1, 3, 5, 5)
-        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
+        self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
                          torch.Size([1, 6, 3, 3]))
-        self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
+        self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
                          torch.Size([1, 8, 1, 1]))

     @model_wrapper
@@ -411,12 +425,11 @@ class GraphIR(unittest.TestCase):
             return self.conv1(torch.cat((x, x), 1))

         model, mutators = self._get_model_with_mutators(Net2())
-        self.assertEqual(len(mutators), 1)
-        mutators[0].bind_sampler(EnumerateSampler())
+        self.assertEqual(len(mutators), 1 + self.value_choice_incr)
         input = torch.randn(1, 3, 5, 5)
-        self._get_converted_pytorch_model(mutators[0].apply(model))(input)
+        self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, EnumerateSampler()))(input)

-    def test_valuechoice_access_functional(self):
+    def test_valuechoice_getitem_functional(self):
         @model_wrapper
         class Net(nn.Module):
             def __init__(self):
@@ -435,7 +448,7 @@ class GraphIR(unittest.TestCase):
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

-    def test_valuechoice_access_functional_expression(self):
+    def test_valuechoice_getitem_functional_expression(self):
         @model_wrapper
         class Net(nn.Module):
             def __init__(self):
@@ -456,6 +469,43 @@ class GraphIR(unittest.TestCase):
         self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
         self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)

+    def test_valuechoice_multi(self):
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                choice1 = nn.ValueChoice([{"in": 1, "out": 3}, {"in": 2, "out": 6}, {"in": 3, "out": 9}])
+                choice2 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
+                choice3 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
+                self.conv1 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice2), 1)
+                self.conv2 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice3), 1)
+
+            def forward(self, x):
+                return self.conv1(x) + self.conv2(x)
+
+        model, mutators = self._get_model_with_mutators(Net())
+        self.assertEqual(len(mutators), 2 + self.value_choice_incr)
+
+        samplers = [EnumerateSampler()] + [RandomSampler() for _ in range(self.value_choice_incr + 1)]
+
+        for i in range(10):
+            model_new = _apply_all_mutators(model, mutators, samplers)
+            result = self._get_converted_pytorch_model(model_new)(torch.randn(1, i % 3 + 1, 3, 3))
+            self.assertIn(result.size(), [torch.Size([1, round((i % 3 + 1) * 3 * k), 3, 3]) for k in [2.5, 3.0, 3.5]])
+
+    def test_valuechoice_inconsistent_label(self):
+        @model_wrapper
+        class Net(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.conv1 = nn.Conv2d(3, nn.ValueChoice([3, 5], label='a'), 1)
+                self.conv2 = nn.Conv2d(3, nn.ValueChoice([3, 6], label='a'), 1)
+
+            def forward(self, x):
+                return torch.cat([self.conv1(x), self.conv2(x)], 1)
+
+        with pytest.raises(AssertionError):
+            self._get_model_with_mutators(Net())
+
     def test_repeat(self):
         class AddOne(nn.Module):
             def forward(self, x):
@@ -645,6 +695,9 @@ class GraphIR(unittest.TestCase):
 class Python(GraphIR):
+    # the Python engine doesn't have the extra mutator
+    value_choice_incr = 0
+
     def _get_converted_pytorch_model(self, model_ir):
         mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
         with ContextStack('fixed', mutation):
@@ -661,10 +714,10 @@ class Python(GraphIR):
     def test_value_choice_in_functional(self): ...

     @unittest.skip
-    def test_valuechoice_access_functional(self): ...
+    def test_valuechoice_getitem_functional(self): ...

     @unittest.skip
-    def test_valuechoice_access_functional_expression(self): ...
+    def test_valuechoice_getitem_functional_expression(self): ...

     def test_cell_loose_end(self):
         @model_wrapper
@@ -744,6 +797,95 @@ class Python(GraphIR):
 class Shared(unittest.TestCase):
     # These tests are general across execution engines

+    def test_value_choice_api_purely(self):
+        a = nn.ValueChoice([1, 2], label='a')
+        b = nn.ValueChoice([3, 4], label='b')
+        c = nn.ValueChoice([5, 6], label='c')
+        d = a + b + 3 * c
+        for i, choice in enumerate(d.inner_choices()):
+            if i == 0:
+                assert choice.candidates == [1, 2]
+            elif i == 1:
+                assert choice.candidates == [3, 4]
+            elif i == 2:
+                assert choice.candidates == [5, 6]
+        assert d.evaluate([2, 3, 5]) == 20
+
+        a = nn.ValueChoice(['cat', 'dog'])
+        b = nn.ValueChoice(['milk', 'coffee'])
+        assert (a + b).evaluate(['dog', 'coffee']) == 'dogcoffee'
+        assert (a + 2 * b).evaluate(['cat', 'milk']) == 'catmilkmilk'
+
+        assert (3 - nn.ValueChoice([1, 2])).evaluate([1]) == 2
+        with pytest.raises(TypeError):
+            a + nn.ValueChoice([1, 3])
+
+        a = nn.ValueChoice([1, 17])
+        a = (abs(-a * 3) % 11) ** 5
+        assert 'abs' in repr(a)
+        with pytest.raises(ValueError):
+            a.evaluate([42])
+        assert a.evaluate([17]) == 7 ** 5
+
+        a = round(7 / nn.ValueChoice([2, 5]))
+        assert a.evaluate([2]) == 4
+        a = ~(77 ^ (nn.ValueChoice([1, 4]) & 5))
+        assert a.evaluate([4]) == ~(77 ^ (4 & 5))
+
+        a = nn.ValueChoice([5, 3]) * nn.ValueChoice([6.5, 7.5])
+        assert math.floor(a.evaluate([5, 7.5])) == int(5 * 7.5)
+
+        a = nn.ValueChoice([1, 3])
+        b = nn.ValueChoice([2, 4])
+        with pytest.raises(RuntimeError):
+            min(a, b)
+        with pytest.raises(RuntimeError):
+            if a < b:
+                ...
+        assert nn.ValueChoice.min(a, b).evaluate([3, 2]) == 2
+        assert nn.ValueChoice.max(a, b).evaluate([3, 2]) == 3
+        assert nn.ValueChoice.max(1, 2, 3) == 3
+        assert nn.ValueChoice.max([1, 3, 2]) == 3
+
+        assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([3]) == 'b'
+        assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([2]) == 'a'
+
+        with pytest.raises(RuntimeError):
+            assert int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
+        assert nn.ValueChoice.to_int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
+        assert nn.ValueChoice.to_float(nn.ValueChoice(['2.5', '3.5'])).evaluate(['3.5']) == 3.5
+
+    def test_make_divisible(self):
+        def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+            if min_value is None:
+                min_value = divisor
+            new_value = nn.ValueChoice.max(min_value, nn.ValueChoice.to_int(value + divisor / 2) // divisor * divisor)
+            # make sure that rounding down does not go down by more than (1 - min_ratio)
+            return nn.ValueChoice.condition(new_value < min_ratio * value, new_value + divisor, new_value)
+
+        def original_make_divisible(value, divisor, min_value=None, min_ratio=0.9):
+            if min_value is None:
+                min_value = divisor
+            new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
+            # make sure that rounding down does not go down by more than (1 - min_ratio)
+            if new_value < min_ratio * value:
+                new_value += divisor
+            return new_value
+
+        values = [4, 8, 16, 32, 64, 128]
+        divisors = [2, 3, 5, 7, 15]
+        with pytest.raises(RuntimeError):
+            original_make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
+        result = make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
+        for value in values:
+            for divisor in divisors:
+                lst = [value if choice.label == 'value' else divisor for choice in result.inner_choices()]
+                assert result.evaluate(lst) == original_make_divisible(value, divisor)
+
     def test_valuechoice_in_evaluator(self):
         def foo():
             pass
@@ -753,28 +895,28 @@ class Shared(unittest.TestCase):
         evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([1, 2]), y=ValueChoice([3, 4]))
         mutators = process_evaluator_mutations(evaluator, [])
-        assert len(mutators) == 2
+        assert len(mutators) == 3
         init_model = Model(_internal=True)
         init_model.evaluator = evaluator

-        sampler = EnumerateSampler()
-        model = mutators[0].bind_sampler(sampler).apply(init_model)
+        samplers = [EnumerateSampler() for _ in range(3)]
+        model = _apply_all_mutators(init_model, mutators, samplers)
         assert model.evaluator.trace_kwargs['x'] == 1
-        model = mutators[0].bind_sampler(sampler).apply(init_model)
+        model = _apply_all_mutators(init_model, mutators, samplers)
         assert model.evaluator.trace_kwargs['x'] == 2

         # share label
         evaluator = FunctionalEvaluator(foo, t=ValueChoice([1, 2], label='x'), x=ValueChoice([1, 2], label='x'))
         mutators = process_evaluator_mutations(evaluator, [])
-        assert len(mutators) == 1
+        assert len(mutators) == 2

         # getitem
         choice = ValueChoice([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
         evaluator = FunctionalEvaluator(foo, t=1, x=choice['a'], y=choice['b'])
         mutators = process_evaluator_mutations(evaluator, [])
-        assert len(mutators) == 1
+        assert len(mutators) == 2
         init_model = Model(_internal=True)
         init_model.evaluator = evaluator

         sampler = RandomSampler()
         for _ in range(10):
-            model = mutators[0].bind_sampler(sampler).apply(init_model)
+            model = _apply_all_mutators(init_model, mutators, sampler)
             assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]