Unverified Commit a36dc07e authored by Yuge Zhang, committed by GitHub

Composition of `ValueChoice` (#4435)

parent f8327ba0
......@@ -68,6 +68,44 @@ Examples are as follows:
self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))
Value choices support arithmetic operators, which is particularly useful when searching for a network width multiplier:
.. code-block:: python
# init
scale = nn.ValueChoice([1.0, 1.5, 2.0])
self.conv1 = nn.Conv2d(3, round(scale * 16), kernel_size=1)
self.conv2 = nn.Conv2d(round(scale * 16), round(scale * 64), kernel_size=1)
self.conv3 = nn.Conv2d(round(scale * 64), round(scale * 256), kernel_size=1)
# forward
return self.conv3(self.conv2(self.conv1(x)))
Or when kernel size and padding are coupled so as to keep the output size constant:
.. code-block:: python
# init
ks = nn.ValueChoice([3, 5, 7])
self.conv = nn.Conv2d(3, 16, kernel_size=ks, padding=(ks - 1) // 2)
# forward
return self.conv(x)
Or when the outputs of several layers are concatenated before a final layer:
.. code-block:: python
# init
self.linear1 = nn.Linear(3, nn.ValueChoice([1, 2, 3], label='a'))
self.linear2 = nn.Linear(3, nn.ValueChoice([4, 5, 6], label='b'))
self.final = nn.Linear(nn.ValueChoice([1, 2, 3], label='a') + nn.ValueChoice([4, 5, 6], label='b'), 2)
# forward
return self.final(torch.cat([self.linear1(x), self.linear2(x)], 1))
Some advanced operators are also provided, such as ``nn.ValueChoice.max`` and ``nn.ValueChoice.condition``. See the reference of :class:`nni.retiarii.nn.pytorch.ValueChoice` for more details.
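For illustration, a minimal sketch of these operators (the semantics follow the tests added in this commit):
.. code-block:: python
a = nn.ValueChoice([1, 3])
b = nn.ValueChoice([2, 4])
larger = nn.ValueChoice.max(a, b)  # the larger of the two sampled values
# use the sampled value of `a` when it is small enough, otherwise a constant fallback
clamped = nn.ValueChoice.condition(a <= 2, a, 2)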
.. tip::
All the APIs have an optional argument called ``label``; mutations with the same label will share the same choice. A typical example is:
......
......@@ -598,7 +598,6 @@ class GraphConverter:
return {
'candidates': module.candidates,
'label': module.label,
'accessor': module._accessor
}
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
......
......@@ -119,7 +119,7 @@ class Model:
self.graphs: Dict[str, Graph] = {}
self.evaluator: Optional[Evaluator] = None
self.history: List['Model'] = []
self.history: List['Mutation'] = []
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []
......
......@@ -3,7 +3,7 @@
import inspect
from collections import defaultdict
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Dict
import torch.nn as nn
......@@ -13,7 +13,7 @@ from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
from .component import Repeat, NasBench101Cell, NasBench101Mutator
......@@ -65,30 +65,66 @@ class InputChoiceMutator(Mutator):
class ValueChoiceMutator(Mutator):
def __init__(self, nodes: List[Node], candidates: List[Any]):
# use nodes[0] as an example to get label
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes
self.candidates = candidates
def mutate(self, model):
chosen = self.choice(self.candidates)
# no need to support transformation here,
# because it is naturally done in the forward loop
for node in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
class ParameterChoiceLeafMutator(Mutator):
# mutate the leaf node (i.e., ValueChoice) of parameter choices
# should be used together with ParameterChoiceMutator
def __init__(self, candidates: List[Any], label: str):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
# leave a record here
# real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates)
class ParameterChoiceMutator(Mutator):
def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
node, argname = nodes[0]
super().__init__(label=node.operation.parameters[argname].label)
# To deal with ValueChoice used as a parameter of a basic unit
# should be used together with ParameterChoiceLeafMutator
# parameter choice mutator is an empty-shell mutator:
# it calculates all the parameter values based on the previous decisions of the leaf mutators
def __init__(self, nodes: List[Tuple[Node, str]]):
super().__init__()
self.nodes = nodes
self.candidates = candidates
def mutate(self, model):
chosen = self.choice(self.candidates)
def mutate(self, model: Model) -> Model:
# looks like {"label1": "cat", "label2": 123}
value_choice_decisions = {}
for mutation in model.history:
if isinstance(mutation.mutator, ParameterChoiceLeafMutator):
value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
for node, argname in self.nodes:
chosen_value = node.operation.parameters[argname].access(chosen)
# argname is the location of the argument
# e.g., Conv2d(out_channels=nn.ValueChoice([1, 2, 3])) => argname = "out_channels"
value_choice: ValueChoiceX = node.operation.parameters[argname]
# calculate all the values on the leaf node of ValueChoiceX computation graph
leaf_node_values = []
for choice in value_choice.inner_choices():
leaf_node_values.append(value_choice_decisions[choice.label])
result_value = value_choice.evaluate(leaf_node_values)
# update model with graph mutation primitives
target = model.get_node_by_name(node.name)
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
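For orientation, here is a minimal sketch of the decision-replay mechanics above. The composition and the decisions dict are hypothetical; `inner_choices()` and `evaluate()` are the APIs used in this diff.
import nni.retiarii.nn.pytorch as nn
width = nn.ValueChoice([16, 32], label='w')   # leaf choice, sampled by ParameterChoiceLeafMutator
expr = width * 2 + 1                          # a ValueChoiceX composition stored as a node parameter
decisions = {'w': 32}                         # what ParameterChoiceMutator gathers from model.history
leaf_values = [decisions[c.label] for c in expr.inner_choices()]
assert expr.evaluate(leaf_values) == 65       # the value written back via update_operation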
class RepeatMutator(Mutator):
......@@ -145,18 +181,31 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
applied_mutators.append(mutator)
pc_nodes = []
# `pc_nodes` are arguments of basic units. They can be compositions.
pc_nodes: List[Tuple[Node, str, ValueChoiceX]] = []
for node in model.get_nodes():
for name, choice in node.operation.parameters.items():
if isinstance(choice, ValueChoice):
pc_nodes.append((node, name))
pc_nodes = _group_parameters_by_label(pc_nodes)
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
first_node, first_argname = node_list[0]
mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
applied_mutators.append(mutator)
if isinstance(choice, ValueChoiceX):
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
pc_nodes.append((node, name, choice))
# Break `pc_nodes` down to leaf value choices. They should be what we want to sample.
leaf_value_choices: Dict[str, List[Any]] = {}
for _, __, choice in pc_nodes:
for inner_choice in choice.inner_choices():
if inner_choice.label not in leaf_value_choices:
leaf_value_choices[inner_choice.label] = inner_choice.candidates
else:
assert leaf_value_choices[inner_choice.label] == inner_choice.candidates, \
'Value choice with the same label must have the same candidates, but found ' \
f'{leaf_value_choices[inner_choice.label]} vs. {inner_choice.candidates}'
for label, candidates in leaf_value_choices.items():
applied_mutators.append(ParameterChoiceLeafMutator(candidates, label))
# in the end, add another parameter choice mutator for "real" mutations
if pc_nodes:
applied_mutators.append(ParameterChoiceMutator([(node, name) for node, name, _ in pc_nodes]))
# apply layer choice at last as it will delete some nodes
lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
......@@ -236,9 +285,10 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
# tricky case: value choices that serve as parameters are stored in traced arguments
if is_basic_unit(module):
for key, value in module.trace_kwargs.items():
if isinstance(value, ValueChoice):
node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
node.label = value.label
if isinstance(value, ValueChoiceX):
for i, choice in enumerate(value.inner_choices()):
node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
node.label = choice.label
if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
# TODO: check the label of module and warn if it's auto-generated
......@@ -286,46 +336,76 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
# mutations for evaluator
class EvaluatorValueChoiceMutator(Mutator):
def __init__(self, keys: List[str], label: Optional[str]):
self.keys = keys
class EvaluatorValueChoiceLeafMutator(Mutator):
# see "ParameterChoiceLeafMutator"
# works in the same way
def __init__(self, candidates: List[Any], label: str):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
# leave a record here
# real mutations will be done in EvaluatorValueChoiceMutator
self.choice(self.candidates)
class EvaluatorValueChoiceMutator(Mutator):
# works in the same way as `ParameterChoiceMutator`
# we only need one such mutator for one model/evaluator
def mutate(self, model: Model):
# make a copy to mutate the evaluator
model.evaluator = model.evaluator.trace_copy()
chosen = None
for i, key in enumerate(self.keys):
value_choice: ValueChoice = model.evaluator.trace_kwargs[key]
if i == 0:
# i == 0 is needed here because there can be candidates of "None"
chosen = self.choice(value_choice.candidates)
# get the real chosen value after "access"
model.evaluator.trace_kwargs[key] = value_choice.access(chosen)
return model
value_choice_decisions = {}
for mutation in model.history:
if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
result = {}
# for each argument that is a composition of value choices,
# we find all the leaf value choices in the mutation history
# and compute the final result
for key, param in model.evaluator.trace_kwargs.items():
if isinstance(param, ValueChoiceX):
leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
result[key] = param.evaluate(leaf_node_values)
model.evaluator.trace_kwargs.update(result)
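Similarly, a hedged sketch of the evaluator-side replay. The evaluator and the recorded decision are hypothetical; `trace_kwargs`, `inner_choices()` and `evaluate()` are as above.
import nni.retiarii.nn.pytorch as nn
from nni.retiarii.evaluator import FunctionalEvaluator

def train(model_cls, lr):
    ...

evaluator = FunctionalEvaluator(train, lr=nn.ValueChoice([1e-3, 1e-2], label='lr') * 10)
decisions = {'lr': 1e-2}                      # recorded by EvaluatorValueChoiceLeafMutator
param = evaluator.trace_kwargs['lr']
values = [decisions[c.label] for c in param.inner_choices()]
evaluator.trace_kwargs['lr'] = param.evaluate(values)   # 0.1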
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
# take all the value choices in the kwargs of the evaluator into a list
# `existing_mutators` can be mutators generated from `model`
if not is_traceable(evaluator):
return []
mutator_candidates = {}
mutator_keys = defaultdict(list)
for key, param in evaluator.trace_kwargs.items():
if isinstance(param, ValueChoice):
if isinstance(param, ValueChoiceX):
for choice in param.inner_choices():
# merge duplicate labels
for mutator in existing_mutators:
if mutator.name == param.label:
raise ValueError(f'Found duplicated labels for mutators {param.label}. When two mutators have the same name, '
'they would share choices. However, sharing choices between model and evaluator is not yet supported.')
if param.label in mutator_candidates and mutator_candidates[param.label] != param.candidates:
raise ValueError(f'Duplicate labels for evaluator ValueChoice {param.label}. They should share choices.'
f'But their candidate list is not equal: {mutator_candidates[param.label][1]} vs. {param.candidates}')
mutator_keys[param.label].append(key)
mutator_candidates[param.label] = param.candidates
if mutator.name == choice.label:
raise ValueError(
f'Found duplicated labels “{choice.label}”. When two value choices have the same name, '
'they would share choices. However, sharing choices between model and evaluator is not yet supported.'
)
if choice.label in mutator_candidates and mutator_candidates[choice.label] != choice.candidates:
raise ValueError(
f'Duplicate labels for evaluator ValueChoice {choice.label}. They should share choices. '
f'But their candidate list is not equal: {mutator_candidates[choice.label]} vs. {choice.candidates}'
)
mutator_keys[choice.label].append(key)
mutator_candidates[choice.label] = choice.candidates
mutators = []
for key in mutator_keys:
mutators.append(EvaluatorValueChoiceMutator(mutator_keys[key], key))
for label in mutator_keys:
mutators.append(EvaluatorValueChoiceLeafMutator(mutator_candidates[label], label))
if mutators:
# one last mutator to actually apply the mutations
mutators.append(EvaluatorValueChoiceMutator())
return mutators
......@@ -359,13 +439,3 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result[label] = []
result[label].append(node)
return list(result.values())
def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
result = {}
for node, argname in nodes:
label = node.operation.parameters[argname].label
if label not in result:
result[label] = []
result[label].append((node, argname))
return list(result.values())
import math
import random
import unittest
from collections import Counter
import pytest
import nni.retiarii.nn.pytorch as nn
import torch
import torch.nn.functional as F
......@@ -50,7 +53,19 @@ class MutableConv(nn.Module):
return self.conv2(x)
def _apply_all_mutators(model, mutators, samplers):
if not isinstance(samplers, list):
samplers = [samplers for _ in range(len(mutators))]
assert len(samplers) == len(mutators)
model_new = model
for mutator, sampler in zip(mutators, samplers):
model_new = mutator.bind_sampler(sampler).apply(model_new)
return model_new
class GraphIR(unittest.TestCase):
# graph engine will have an extra mutator for parameter choices
value_choice_incr = 1
def _convert_to_ir(self, model):
script_module = torch.jit.script(model)
......@@ -220,7 +235,7 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
......@@ -240,16 +255,16 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(len(mutators), self.value_choice_incr + 1)
samplers = [EnumerateSampler() for _ in range(len(mutators))]
model1 = _apply_all_mutators(model, mutators, samplers)
model2 = _apply_all_mutators(model, mutators, samplers)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self):
def test_value_choice_as_two_parameters(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -260,13 +275,14 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 2 + self.value_choice_incr)
samplers = [EnumerateSampler() for _ in range(len(mutators))]
model1 = _apply_all_mutators(model, mutators, samplers)
model2 = _apply_all_mutators(model, mutators, samplers)
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(model1)(input).size(),
torch.Size([1, 6, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(model2)(input).size(),
torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self):
......@@ -281,10 +297,10 @@ class GraphIR(unittest.TestCase):
return self.conv1(x) + self.conv2(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
sampler = EnumerateSampler()
model1 = _apply_all_mutators(model, mutators, sampler)
model2 = _apply_all_mutators(model, mutators, sampler)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 6, 5, 5]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
......@@ -323,13 +339,11 @@ class GraphIR(unittest.TestCase):
return self.linear(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 3)
self.assertEqual(len(mutators), 3 + self.value_choice_incr)
sz_counter = Counter()
sampler = RandomSampler()
for i in range(100):
model_new = model
for mutator in mutators:
model_new = mutator.bind_sampler(sampler).apply(model_new)
model_new = _apply_all_mutators(model, mutators, sampler)
sz_counter[self._get_converted_pytorch_model(model_new)(torch.randn(1, 3)).size(1)] += 1
self.assertEqual(len(sz_counter), 4)
......@@ -375,7 +389,7 @@ class GraphIR(unittest.TestCase):
self.assertGreater(failed_count, 0)
self.assertLess(failed_count, 30)
def test_valuechoice_access(self):
def test_valuechoice_getitem(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -387,12 +401,12 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
sampler = EnumerateSampler()
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
torch.Size([1, 6, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
torch.Size([1, 8, 1, 1]))
@model_wrapper
......@@ -411,12 +425,11 @@ class GraphIR(unittest.TestCase):
return self.conv1(torch.cat((x, x), 1))
model, mutators = self._get_model_with_mutators(Net2())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
input = torch.randn(1, 3, 5, 5)
self._get_converted_pytorch_model(mutators[0].apply(model))(input)
self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, EnumerateSampler()))(input)
def test_valuechoice_access_functional(self):
def test_valuechoice_getitem_functional(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -435,7 +448,7 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_access_functional_expression(self):
def test_valuechoice_getitem_functional_expression(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -456,6 +469,43 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_multi(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
choice1 = nn.ValueChoice([{"in": 1, "out": 3}, {"in": 2, "out": 6}, {"in": 3, "out": 9}])
choice2 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
choice3 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
self.conv1 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice2), 1)
self.conv2 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice3), 1)
def forward(self, x):
return self.conv1(x) + self.conv2(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2 + self.value_choice_incr)
samplers = [EnumerateSampler()] + [RandomSampler() for _ in range(self.value_choice_incr + 1)]
for i in range(10):
model_new = _apply_all_mutators(model, mutators, samplers)
result = self._get_converted_pytorch_model(model_new)(torch.randn(1, i % 3 + 1, 3, 3))
self.assertIn(result.size(), [torch.Size([1, round((i % 3 + 1) * 3 * k), 3, 3]) for k in [2.5, 3.0, 3.5]])
def test_valuechoice_inconsistent_label(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, nn.ValueChoice([3, 5], label='a'), 1)
self.conv2 = nn.Conv2d(3, nn.ValueChoice([3, 6], label='a'), 1)
def forward(self, x):
return torch.cat([self.conv1(x), self.conv2(x)], 1)
with pytest.raises(AssertionError):
self._get_model_with_mutators(Net())
def test_repeat(self):
class AddOne(nn.Module):
def forward(self, x):
......@@ -645,6 +695,9 @@ class GraphIR(unittest.TestCase):
class Python(GraphIR):
# Python engine doesn't have the extra mutator
value_choice_incr = 0
def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
......@@ -661,10 +714,10 @@ class Python(GraphIR):
def test_value_choice_in_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional(self): ...
def test_valuechoice_getitem_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
def test_valuechoice_getitem_functional_expression(self): ...
def test_cell_loose_end(self):
@model_wrapper
......@@ -744,6 +797,95 @@ class Python(GraphIR):
class Shared(unittest.TestCase):
# These tests are general across execution engines
def test_value_choice_api_purely(self):
a = nn.ValueChoice([1, 2], label='a')
b = nn.ValueChoice([3, 4], label='b')
c = nn.ValueChoice([5, 6], label='c')
d = a + b + 3 * c
for i, choice in enumerate(d.inner_choices()):
if i == 0:
assert choice.candidates == [1, 2]
elif i == 1:
assert choice.candidates == [3, 4]
elif i == 2:
assert choice.candidates == [5, 6]
assert d.evaluate([2, 3, 5]) == 20
a = nn.ValueChoice(['cat', 'dog'])
b = nn.ValueChoice(['milk', 'coffee'])
assert (a + b).evaluate(['dog', 'coffee']) == 'dogcoffee'
assert (a + 2 * b).evaluate(['cat', 'milk']) == 'catmilkmilk'
assert (3 - nn.ValueChoice([1, 2])).evaluate([1]) == 2
with pytest.raises(TypeError):
a + nn.ValueChoice([1, 3])
a = nn.ValueChoice([1, 17])
a = (abs(-a * 3) % 11) ** 5
assert 'abs' in repr(a)
with pytest.raises(ValueError):
a.evaluate([42])
assert a.evaluate([17]) == 7 ** 5
a = round(7 / nn.ValueChoice([2, 5]))
assert a.evaluate([2]) == 4
a = ~(77 ^ (nn.ValueChoice([1, 4]) & 5))
assert a.evaluate([4]) == ~(77 ^ (4 & 5))
a = nn.ValueChoice([5, 3]) * nn.ValueChoice([6.5, 7.5])
assert math.floor(a.evaluate([5, 7.5])) == int(5 * 7.5)
a = nn.ValueChoice([1, 3])
b = nn.ValueChoice([2, 4])
with pytest.raises(RuntimeError):
min(a, b)
with pytest.raises(RuntimeError):
if a < b:
...
assert nn.ValueChoice.min(a, b).evaluate([3, 2]) == 2
assert nn.ValueChoice.max(a, b).evaluate([3, 2]) == 3
assert nn.ValueChoice.max(1, 2, 3) == 3
assert nn.ValueChoice.max([1, 3, 2]) == 3
assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([3]) == 'b'
assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([2]) == 'a'
with pytest.raises(RuntimeError):
assert int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
assert nn.ValueChoice.to_int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
assert nn.ValueChoice.to_float(nn.ValueChoice(['2.5', '3.5'])).evaluate(['3.5']) == 3.5
def test_make_divisible(self):
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
if min_value is None:
min_value = divisor
new_value = nn.ValueChoice.max(min_value, nn.ValueChoice.to_int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
return nn.ValueChoice.condition(new_value < min_ratio * value, new_value + divisor, new_value)
def original_make_divisible(value, divisor, min_value=None, min_ratio=0.9):
if min_value is None:
min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
if new_value < min_ratio * value:
new_value += divisor
return new_value
values = [4, 8, 16, 32, 64, 128]
divisors = [2, 3, 5, 7, 15]
with pytest.raises(RuntimeError):
original_make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
result = make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
for value in values:
for divisor in divisors:
lst = [value if choice.label == 'value' else divisor for choice in result.inner_choices()]
assert result.evaluate(lst) == original_make_divisible(value, divisor)
def test_valuechoice_in_evaluator(self):
def foo():
pass
......@@ -753,28 +895,28 @@ class Shared(unittest.TestCase):
evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([1, 2]), y=ValueChoice([3, 4]))
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 2
assert len(mutators) == 3
init_model = Model(_internal=True)
init_model.evaluator = evaluator
sampler = EnumerateSampler()
model = mutators[0].bind_sampler(sampler).apply(init_model)
samplers = [EnumerateSampler() for _ in range(3)]
model = _apply_all_mutators(init_model, mutators, samplers)
assert model.evaluator.trace_kwargs['x'] == 1
model = mutators[0].bind_sampler(sampler).apply(init_model)
model = _apply_all_mutators(init_model, mutators, samplers)
assert model.evaluator.trace_kwargs['x'] == 2
# share label
evaluator = FunctionalEvaluator(foo, t=ValueChoice([1, 2], label='x'), x=ValueChoice([1, 2], label='x'))
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 1
assert len(mutators) == 2
# getitem
choice = ValueChoice([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
evaluator = FunctionalEvaluator(foo, t=1, x=choice['a'], y=choice['b'])
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 1
assert len(mutators) == 2
init_model = Model(_internal=True)
init_model.evaluator = evaluator
sampler = RandomSampler()
for _ in range(10):
model = mutators[0].bind_sampler(sampler).apply(init_model)
model = _apply_all_mutators(init_model, mutators, sampler)
assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]