Unverified Commit a36dc07e authored by Yuge Zhang, committed by GitHub

Composition of `ValueChoice` (#4435)

parent f8327ba0
......@@ -68,6 +68,44 @@ Examples are as follows:
self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))
Value choices support arithmetic operators, which is particularly useful when searching for a network width multiplier:
.. code-block:: python
# init
scale = nn.ValueChoice([1.0, 1.5, 2.0])
# Conv2d also requires a kernel size; kernel_size=1 below is only for illustration
self.conv1 = nn.Conv2d(3, round(scale * 16), kernel_size=1)
self.conv2 = nn.Conv2d(round(scale * 16), round(scale * 64), kernel_size=1)
self.conv3 = nn.Conv2d(round(scale * 64), round(scale * 256), kernel_size=1)
# forward
return self.conv3(self.conv2(self.conv1(x)))
Or when kernel size and padding are coupled so as to keep the output size constant:
.. code-block:: python
# init
ks = nn.ValueChoice([3, 5, 7])
self.conv = nn.Conv2d(3, 16, kernel_size=ks, padding=(ks - 1) // 2)
# forward
return self.conv(x)
Or when several layers are concatenated before a final layer:
.. code-block:: python
# init
self.linear1 = nn.Linear(3, nn.ValueChoice([1, 2, 3], label='a'))
self.linear2 = nn.Linear(3, nn.ValueChoice([4, 5, 6], label='b'))
self.final = nn.Linear(nn.ValueChoice([1, 2, 3], label='a') + nn.ValueChoice([4, 5, 6], label='b'), 2)
# forward
return self.final(torch.cat([self.linear1(x), self.linear2(x)], 1))
Some advanced operators are also provided, such as ``nn.ValueChoice.max`` and ``nn.ValueChoice.condition``. See reference of :class:`nni.retiarii.nn.pytorch.ValueChoice` for more details.
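For instance, a minimal sketch combining the two (names and candidate values below are purely illustrative):
.. code-block:: python
ks = nn.ValueChoice([3, 5, 7])
# lazily take the larger of a constant floor and a kernel-dependent width;
# the expression is evaluated only after the final decision is made
width = nn.ValueChoice.max(16, ks * 4)
# choose the padding rule conditionally on the kernel size
padding = nn.ValueChoice.condition(ks > 3, (ks - 1) // 2, 1)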
.. tip::
All the APIs have an optional argument called ``label``; mutations with the same label will share the same choice. A typical example is:
......
......@@ -598,7 +598,6 @@ class GraphConverter:
return {
'candidates': module.candidates,
'label': module.label,
'accessor': module._accessor
}
def _convert_module(self, script_module, module, module_name, module_python_name, ir_model):
......
......@@ -119,7 +119,7 @@ class Model:
self.graphs: Dict[str, Graph] = {}
self.evaluator: Optional[Evaluator] = None
self.history: List['Model'] = []
self.history: List['Mutation'] = []
self.metric: Optional[MetricData] = None
self.intermediate_metrics: List[MetricData] = []
......
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import math
import operator
import warnings
from typing import Any, List, Union, Dict, Optional
from typing import Any, List, Union, Dict, Optional, Callable, Iterable, NoReturn, TypeVar
import torch
import torch.nn as nn
......@@ -156,6 +157,14 @@ class LayerChoice(Mutable):
return f'LayerChoice({self.candidates}, label={repr(self.label)})'
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
ReductionType = Literal['mean', 'concat', 'sum', 'none']
class InputChoice(Mutable):
"""
Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).
......@@ -183,7 +192,8 @@ class InputChoice(Mutable):
"""
@classmethod
def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1, reduction: str = 'sum', *,
def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1,
reduction: ReductionType = 'sum', *,
prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
return ChosenInputs(get_fixed_value(label), reduction=reduction)
......@@ -228,7 +238,428 @@ class InputChoice(Mutable):
f'reduction={repr(self.reduction)}, label={repr(self.label)})'
class ValueChoice(Translatable, Mutable):
class ChosenInputs(nn.Module):
"""
A module that chooses from a tensor list and outputs a reduced tensor.
The already-chosen version of InputChoice.
In the forward pass, ``chosen`` is used to select inputs from ``candidate_inputs``,
and ``reduction`` is used to reduce the selected inputs into a single tensor.
Attributes
----------
chosen : list of int
Indices of chosen inputs.
reduction : ``mean`` | ``concat`` | ``sum`` | ``none``
How to reduce the inputs when multiple are selected.
"""
def __init__(self, chosen: Union[List[int], int], reduction: ReductionType):
super().__init__()
self.chosen = chosen if isinstance(chosen, list) else [chosen]
self.reduction = reduction
def forward(self, candidate_inputs):
return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'none':
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == 'sum':
return sum(tensor_list)
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
return torch.cat(tensor_list, dim=1)
raise ValueError(f'Unrecognized reduction policy: "{reduction_type}"')
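# A minimal usage sketch of ChosenInputs (shapes below are made up for illustration):
#   fixed = ChosenInputs([0, 2], reduction='sum')
#   out = fixed([torch.zeros(2, 4), torch.ones(2, 4), torch.ones(2, 4)])
#   # out is the element-wise sum of inputs 0 and 2, i.e., a tensor of ones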
# the code in ValueChoice can be generated with this codegen
# this is not done on the fly because I want to have type-hint support
# $ python -c "from nni.retiarii.nn.pytorch.api import _valuechoice_codegen; _valuechoice_codegen(_internal=True)"
def _valuechoice_codegen(*, _internal: bool = False):
if not _internal:
raise RuntimeError("This method is set to be internal. Please don't use it directly.")
MAPPING = {
# unary
'neg': '-', 'pos': '+', 'invert': '~',
# binary
'add': '+', 'sub': '-', 'mul': '*', 'matmul': '@',
'truediv': '/', 'floordiv': '//', 'mod': '%',
'lshift': '<<', 'rshift': '>>',
'and': '&', 'xor': '^', 'or': '|',
# no reflection
'lt': '<', 'le': '<=', 'eq': '==',
'ne': '!=', 'ge': '>=', 'gt': '>',
# NOTE
# Currently we don't support operators like __contains__ (b in a).
# We might support them in the future when we actually need them.
}
binary_template = """ def __{op}__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""
binary_r_template = """ def __r{op}__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""
unary_template = """ def __{op}__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.{op}, '{sym}{{}}', [self])"""
for op, sym in MAPPING.items():
if op in ['neg', 'pos', 'invert']:
print(unary_template.format(op=op, sym=sym) + '\n')
else:
opt = op + '_' if op in ['and', 'or'] else op
print(binary_template.format(op=op, opt=opt, sym=sym) + '\n')
if op not in ['lt', 'le', 'eq', 'ne', 'ge', 'gt']:
print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')
def _valuechoice_staticmethod_helper(orig_func):
orig_func.__doc__ += """
Notes
-----
This function performs lazy evaluation.
Only the expression will be recorded when the function is called.
The real evaluation happens when the inner value choice has determined its final decision.
If no value choice is contained in the parameter list, the evaluation will be immediate."""
return orig_func
class ValueChoiceX(Translatable):
"""Internal API. Implementation note:
The transformed (X) version of value choice.
It can be the result of composition (transformation) of one or several value choices. For example,
.. code-block:: python
nn.ValueChoice([1, 2]) + nn.ValueChoice([3, 4]) + 5
Instances of this base class cannot be created directly. Instead, they should only be the result of transformations of value choices.
Therefore, there is no need to implement ``create_fixed_module`` in this class, because:
1. For the Python engine, a value choice itself has ``create_fixed_module``. Consequently, the transformation is fixed by construction.
2. For the graph engine, ``evaluate`` is used to calculate the result.
Potentially, we will have to implement the evaluation logic in one-shot algorithms; that discussion can be postponed until then.
"""
def __init__(self, function: Callable[..., Any], repr_template: str, arguments: List[Any], dry_run: bool = True):
super().__init__()
if function is None:
# this case is a hack for the ValueChoice subclass:
# it reaches here only because ``nn.Module.__init__`` further down the MRO still needs to run
return
self.function = function
self.repr_template = repr_template
self.arguments = arguments
assert any(isinstance(arg, ValueChoiceX) for arg in self.arguments)
if dry_run:
# for sanity check
self.dry_run()
def inner_choices(self) -> Iterable['ValueChoice']:
"""
Return an iterable of all leaf value choices.
Useful for composition of value choices.
Leaf choices with duplicate labels are not deduplicated here; mutators should take care of that.
"""
for arg in self.arguments:
if isinstance(arg, ValueChoiceX):
yield from arg.inner_choices()
def dry_run(self) -> Any:
"""
Dry run the value choice to get one of its possible evaluation results.
"""
# values are not used
return self._evaluate(iter([]), True)
def evaluate(self, values: Iterable[Any]) -> Any:
"""
Evaluate the result of this group.
``values`` should be in the same order as ``inner_choices()``.
"""
return self._evaluate(iter(values), False)
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
# "values" iterates in the recursion
eval_args = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX):
# recursive evaluation
eval_args.append(arg._evaluate(values, dry_run))
# the recursion will stop when it hits a leaf node (value choice)
# the implementation is in `ValueChoice`
else:
# constant value
eval_args.append(arg)
return self.function(*eval_args)
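# An illustrative walk-through of the recursion above (candidates are made up):
#   expr = ValueChoice([16, 32]) * ValueChoice([1, 2]) + 8
#   leaves = list(expr.inner_choices())  # the two leaf value choices, in order
#   expr.evaluate([32, 2])               # == 32 * 2 + 8 == 72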
def _translate(self):
"""
Try to behave like one of its candidates when used in ``basic_unit``.
"""
return self.dry_run()
def __repr__(self):
reprs = []
for arg in self.arguments:
if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
reprs.append('(' + repr(arg) + ')') # add parenthesis for operator priority
else:
reprs.append(repr(arg))
return self.repr_template.format(*reprs)
# the following are a series of methods to create "ValueChoiceX"
# which is a transformed version of value choice
# https://docs.python.org/3/reference/datamodel.html#special-method-names
# Special operators that can be useful in place of built-in conditional operators.
@staticmethod
@_valuechoice_staticmethod_helper
def to_int(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', int]:
"""
Convert a ``ValueChoice`` to an integer.
"""
if isinstance(obj, ValueChoiceX):
return ValueChoiceX(int, 'int({})', [obj])
return int(obj)
@staticmethod
@_valuechoice_staticmethod_helper
def to_float(obj: 'ValueChoiceOrAny') -> Union['ValueChoiceX', float]:
"""
Convert a ``ValueChoice`` to a float.
"""
if isinstance(obj, ValueChoiceX):
return ValueChoiceX(float, 'float({})', [obj])
return float(obj)
@staticmethod
@_valuechoice_staticmethod_helper
def condition(pred: 'ValueChoiceOrAny',
true: 'ValueChoiceOrAny',
false: 'ValueChoiceOrAny') -> 'ValueChoiceOrAny':
"""
Return ``true`` if the predicate ``pred`` is true else ``false``.
Examples
--------
>>> ValueChoice.condition(ValueChoice([1, 2]) > ValueChoice([0, 3]), 2, 1)
"""
if any(isinstance(obj, ValueChoiceX) for obj in [pred, true, false]):
return ValueChoiceX(lambda t, c, f: t if c else f, '{} if {} else {}', [true, pred, false])
return true if pred else false
@staticmethod
@_valuechoice_staticmethod_helper
def max(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
"""
Returns the maximum value from a list of value choices.
The usage is similar to Python's built-in ``max``:
the parameters can be either a single iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.max(*list(arg0))
lst = [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(max, 'max({})', lst)
return max(lst)
@staticmethod
@_valuechoice_staticmethod_helper
def min(arg0: Union[Iterable['ValueChoiceOrAny'], 'ValueChoiceOrAny'],
*args: List['ValueChoiceOrAny']) -> 'ValueChoiceOrAny':
"""
Returns the minimum value from a list of value choices.
The usage is similar to Python's built-in ``min``:
the parameters can be either a single iterable, or at least two arguments.
"""
if not args:
return ValueChoiceX.min(*list(arg0))
lst = [arg0] + list(args)
if any(isinstance(obj, ValueChoiceX) for obj in lst):
return ValueChoiceX(min, 'min({})', lst)
return min(lst)
def __hash__(self):
# this is required because we have implemented ``__eq__``
return id(self)
# NOTE:
# Write operations are not supported. Reasons follow:
# - Semantics are not clear. It can be applied to "all" the inner candidates, or only the chosen one.
# - Implementation effort is too huge.
# As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.
def __getitem__(self, key: Any) -> 'ValueChoiceX':
return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])
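# An illustrative sketch: subscripting builds a lazy expression, e.g.,
#   ch = ValueChoice([{'in': 1, 'out': 3}, {'in': 2, 'out': 6}])
#   ch['out'].evaluate([{'in': 2, 'out': 6}])  # == 6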
# region implement int, float, round, trunc, floor, ceil
# because I believe sometimes we need them to calculate #channels
# `__int__` and `__float__` are not supported because `__int__` is required to return int.
def __round__(self, ndigits: Optional[Any] = None) -> 'ValueChoiceX':
if ndigits is not None:
return ValueChoiceX(round, 'round({}, {})', [self, ndigits])
return ValueChoiceX(round, 'round({})', [self])
def __trunc__(self) -> 'ValueChoiceX':
raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")
def __floor__(self) -> 'ValueChoiceX':
return ValueChoiceX(math.floor, 'math.floor({})', [self])
def __ceil__(self) -> 'ValueChoiceX':
return ValueChoiceX(math.ceil, 'math.ceil({})', [self])
def __index__(self) -> NoReturn:
# https://docs.python.org/3/reference/datamodel.html#object.__index__
raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't "
"use int(), float(), complex(), range() on a ValueChoice.")
def __bool__(self) -> NoReturn:
raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. '
'Please try methods like `ValueChoice.max(a, b)` to see whether that meets your needs.')
# endregion
# region the following code is generated with codegen (see above)
# Annotated with "region" because I want to collapse them in vscode
def __neg__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.neg, '-{}', [self])
def __pos__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.pos, '+{}', [self])
def __invert__(self) -> 'ValueChoiceX':
return ValueChoiceX(operator.invert, '~{}', [self])
def __add__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.add, '{} + {}', [self, other])
def __radd__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.add, '{} + {}', [other, self])
def __sub__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.sub, '{} - {}', [self, other])
def __rsub__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.sub, '{} - {}', [other, self])
def __mul__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.mul, '{} * {}', [self, other])
def __rmul__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.mul, '{} * {}', [other, self])
def __matmul__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])
def __rmatmul__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])
def __truediv__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.truediv, '{} / {}', [self, other])
def __rtruediv__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.truediv, '{} / {}', [other, self])
def __floordiv__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.floordiv, '{} // {}', [self, other])
def __rfloordiv__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.floordiv, '{} // {}', [other, self])
def __mod__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.mod, '{} % {}', [self, other])
def __rmod__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.mod, '{} % {}', [other, self])
def __lshift__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.lshift, '{} << {}', [self, other])
def __rlshift__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.lshift, '{} << {}', [other, self])
def __rshift__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])
def __rrshift__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])
def __and__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.and_, '{} & {}', [self, other])
def __rand__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.and_, '{} & {}', [other, self])
def __xor__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])
def __rxor__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])
def __or__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.or_, '{} | {}', [self, other])
def __ror__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.or_, '{} | {}', [other, self])
def __lt__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.lt, '{} < {}', [self, other])
def __le__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.le, '{} <= {}', [self, other])
def __eq__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.eq, '{} == {}', [self, other])
def __ne__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.ne, '{} != {}', [self, other])
def __ge__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.ge, '{} >= {}', [self, other])
def __gt__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(operator.gt, '{} > {}', [self, other])
# endregion
# __pow__, __divmod__, __abs__ are special ones.
# Not easy to cover those cases with codegen.
def __pow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])
def __rpow__(self, other: Any, modulo: Optional[Any] = None) -> 'ValueChoiceX':
if modulo is not None:
return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])
def __divmod__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])
def __rdivmod__(self, other: Any) -> 'ValueChoiceX':
return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])
def __abs__(self) -> 'ValueChoiceX':
return ValueChoiceX(abs, 'abs({})', [self])
ValueChoiceOrAny = TypeVar('ValueChoiceOrAny', ValueChoiceX, Any)
class ValueChoice(ValueChoiceX, Mutable):
"""
ValueChoice chooses one value from ``candidates``.
......@@ -298,10 +729,13 @@ class ValueChoice(Translatable, Mutable):
@classmethod
def create_fixed_module(cls, candidates: List[Any], *, label: Optional[str] = None, **kwargs):
return get_fixed_value(label)
value = get_fixed_value(label)
if value not in candidates:
raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
return value
def __init__(self, candidates: List[Any], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
super().__init__()
super().__init__(None, None, None)
self.candidates = candidates
self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
......@@ -316,50 +750,34 @@ class ValueChoice(Translatable, Mutable):
warnings.warn('You should not run forward of this module directly.')
return self.candidates[0]
def _translate(self):
# Will function as a value when used in serializer.
return self.access(self.candidates[0])
def inner_choices(self) -> Iterable['ValueChoice']:
# yield self because self is the only value choice here
yield self
def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
def dry_run(self) -> Any:
return self.candidates[0]
def access(self, value):
if not self._accessor:
return value
def _evaluate(self, values: Iterable[Any], dry_run: bool = False) -> Any:
if dry_run:
return self.candidates[0]
try:
v = value
for a in self._accessor:
v = v[a]
except KeyError:
raise KeyError(''.join([f'[{a}]' for a in self._accessor]) + f' does not work on {value}')
return v
def __copy__(self):
return self
def __deepcopy__(self, memo):
new_item = ValueChoice(self.candidates, label=self.label)
new_item._accessor = [*self._accessor]
return new_item
def __getitem__(self, item):
"""
Get a sub-element of value choice.
value = next(values)
except StopIteration:
raise ValueError(f'Value list {values} is exhausted when trying to get a chosen value of {self}.')
if value not in self.candidates:
raise ValueError(f'Value {value} does not belong to the candidates of {self}.')
return value
The underlying implementation is to clone the current instance, and append item to "accessor", which records all
the history getitem calls. For example, when accessor is ``[a, b, c]``, the value choice will return ``vc[a][b][c]``
where ``vc`` is the original value choice.
"""
access = copy.deepcopy(self)
access._accessor.append(item)
for candidate in self.candidates:
access.access(candidate)
return access
def __repr__(self):
return f'ValueChoice({self.candidates}, label={repr(self.label)})'
@basic_unit
class Placeholder(nn.Module):
# TODO: docstring
"""
The API that creates an empty module for later mutations.
For advanced usages only.
"""
def __init__(self, label, **related_info):
self.label = label
......@@ -368,38 +786,3 @@ class Placeholder(nn.Module):
def forward(self, x):
return x
class ChosenInputs(nn.Module):
"""
A module that chooses from a tensor list and outputs a reduced tensor.
The already-chosen version of InputChoice.
Attributes
----------
chosen : list of int
Indices of chosen candidates.
"""
def __init__(self, chosen: Union[List[int], int], reduction: str):
super().__init__()
self.chosen = chosen if isinstance(chosen, list) else [chosen]
self.reduction = reduction
def forward(self, candidate_inputs):
return self._tensor_reduction(self.reduction, [candidate_inputs[i] for i in self.chosen])
def _tensor_reduction(self, reduction_type, tensor_list):
if reduction_type == 'none':
return tensor_list
if not tensor_list:
return None # empty. return None for now
if len(tensor_list) == 1:
return tensor_list[0]
if reduction_type == 'sum':
return sum(tensor_list)
if reduction_type == 'mean':
return sum(tensor_list) / len(tensor_list)
if reduction_type == 'concat':
return torch.cat(tensor_list, dim=1)
raise ValueError(f'Unrecognized reduction policy: "{reduction_type}"')
......@@ -3,7 +3,7 @@
import inspect
from collections import defaultdict
from typing import Any, List, Optional, Tuple
from typing import Any, List, Optional, Tuple, Dict
import torch.nn as nn
......@@ -13,7 +13,7 @@ from nni.retiarii.mutator import Mutator
from nni.retiarii.serializer import is_basic_unit, is_model_wrapped
from nni.retiarii.utils import uid
from .api import LayerChoice, InputChoice, ValueChoice, Placeholder
from .api import LayerChoice, InputChoice, ValueChoice, ValueChoiceX, Placeholder
from .component import Repeat, NasBench101Cell, NasBench101Mutator
......@@ -65,30 +65,66 @@ class InputChoiceMutator(Mutator):
class ValueChoiceMutator(Mutator):
def __init__(self, nodes: List[Node], candidates: List[Any]):
# use nodes[0] as an example to get label
super().__init__(label=nodes[0].operation.parameters['label'])
self.nodes = nodes
self.candidates = candidates
def mutate(self, model):
chosen = self.choice(self.candidates)
# no need to support transformation here,
# because it is naturally done in forward loop
for node in self.nodes:
target = model.get_node_by_name(node.name)
target.update_operation('prim::Constant', {'type': type(chosen).__name__, 'value': chosen})
class ParameterChoiceLeafMutator(Mutator):
# mutate the leaf node (i.e., ValueChoice) of parameter choices
# should be used together with ParameterChoiceMutator
def __init__(self, candidates: List[Any], label: str):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
# leave a record here
# real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates)
class ParameterChoiceMutator(Mutator):
def __init__(self, nodes: List[Tuple[Node, str]], candidates: List[Any]):
node, argname = nodes[0]
super().__init__(label=node.operation.parameters[argname].label)
# Deals with ValueChoice used as a parameter of a basic unit.
# Should be used together with ParameterChoiceLeafMutator.
# This mutator is an empty shell: it samples nothing itself, and
# calculates all the parameter values based on previous mutations of the leaf mutators.
def __init__(self, nodes: List[Tuple[Node, str]]):
super().__init__()
self.nodes = nodes
self.candidates = candidates
def mutate(self, model):
chosen = self.choice(self.candidates)
def mutate(self, model: Model) -> Model:
# looks like {"label1": "cat", "label2": 123}
value_choice_decisions = {}
for mutation in model.history:
if isinstance(mutation.mutator, ParameterChoiceLeafMutator):
value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
for node, argname in self.nodes:
chosen_value = node.operation.parameters[argname].access(chosen)
# argname is the location of the argument
# e.g., Conv2d(out_channels=nn.ValueChoice([1, 2, 3])) => argname = "out_channels"
value_choice: ValueChoiceX = node.operation.parameters[argname]
# calculate all the values on the leaf node of ValueChoiceX computation graph
leaf_node_values = []
for choice in value_choice.inner_choices():
leaf_node_values.append(value_choice_decisions[choice.label])
result_value = value_choice.evaluate(leaf_node_values)
# update model with graph mutation primitives
target = model.get_node_by_name(node.name)
target.update_operation(target.operation.type, {**target.operation.parameters, argname: chosen_value})
target.update_operation(target.operation.type, {**target.operation.parameters, argname: result_value})
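# An illustrative flow (labels made up): if `out_channels` was declared as
# `nn.ValueChoice([16, 32], label='w') * 2`, a ParameterChoiceLeafMutator first
# samples 'w' (say, 32) and records it in `model.history`; this mutator then
# evaluates the composition (32 * 2 == 64) and writes `out_channels=64` into
# the target node's operation parameters.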
class RepeatMutator(Mutator):
......@@ -145,18 +181,31 @@ def process_inline_mutation(model: Model) -> Optional[List[Mutator]]:
mutator = ValueChoiceMutator(node_list, node_list[0].operation.parameters['candidates'])
applied_mutators.append(mutator)
pc_nodes = []
# `pc_nodes` are arguments of basic units. They can be compositions.
pc_nodes: List[Tuple[Node, str, ValueChoiceX]] = []
for node in model.get_nodes():
for name, choice in node.operation.parameters.items():
if isinstance(choice, ValueChoice):
pc_nodes.append((node, name))
pc_nodes = _group_parameters_by_label(pc_nodes)
for node_list in pc_nodes:
assert _is_all_equal([node.operation.parameters[name].candidates for node, name in node_list]), \
'Value choice with the same label must have the same candidates.'
first_node, first_argname = node_list[0]
mutator = ParameterChoiceMutator(node_list, first_node.operation.parameters[first_argname].candidates)
applied_mutators.append(mutator)
if isinstance(choice, ValueChoiceX):
# e.g., (conv_node, "out_channels", ValueChoice([1, 3]))
pc_nodes.append((node, name, choice))
# Break `pc_nodes` down to leaf value choices. They should be what we want to sample.
leaf_value_choices: Dict[str, List[Any]] = {}
for _, __, choice in pc_nodes:
for inner_choice in choice.inner_choices():
if inner_choice.label not in leaf_value_choices:
leaf_value_choices[inner_choice.label] = inner_choice.candidates
else:
assert leaf_value_choices[inner_choice.label] == inner_choice.candidates, \
'Value choice with the same label must have the same candidates, but found ' \
f'{leaf_value_choices[inner_choice.label]} vs. {inner_choice.candidates}'
for label, candidates in leaf_value_choices.items():
applied_mutators.append(ParameterChoiceLeafMutator(candidates, label))
# in the end, add another parameter choice mutator for "real" mutations
if pc_nodes:
applied_mutators.append(ParameterChoiceMutator([(node, name) for node, name, _ in pc_nodes]))
# apply layer choice last, as it will delete some nodes
lc_nodes = _group_by_label(filter(lambda d: d.operation.parameters.get('mutation') == 'layerchoice',
......@@ -236,9 +285,10 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
# tricky case: value choices that serve as parameters are stored in traced arguments
if is_basic_unit(module):
for key, value in module.trace_kwargs.items():
if isinstance(value, ValueChoice):
node = graph.add_node(name + '.init.' + key, 'ValueChoice', {'candidates': value.candidates})
node.label = value.label
if isinstance(value, ValueChoiceX):
for i, choice in enumerate(value.inner_choices()):
node = graph.add_node(f'{name}.init.{key}.{i}', 'ValueChoice', {'candidates': choice.candidates})
node.label = choice.label
if isinstance(module, (LayerChoice, InputChoice, ValueChoice)):
# TODO: check the label of module and warn if it's auto-generated
......@@ -286,46 +336,76 @@ def extract_mutation_from_pt_module(pytorch_model: nn.Module) -> Tuple[Model, Op
# mutations for evaluator
class EvaluatorValueChoiceMutator(Mutator):
def __init__(self, keys: List[str], label: Optional[str]):
self.keys = keys
class EvaluatorValueChoiceLeafMutator(Mutator):
# see "ParameterChoiceLeafMutator"
# works in the same way
def __init__(self, candidates: List[Any], label: str):
super().__init__(label=label)
self.candidates = candidates
def mutate(self, model: Model) -> Model:
# leave a record here
# real mutations will be done in ParameterChoiceMutator
self.choice(self.candidates)
class EvaluatorValueChoiceMutator(Mutator):
# works in the same way as `ParameterChoiceMutator`
# we only need one such mutator for one model/evaluator
def mutate(self, model: Model):
# make a copy to mutate the evaluator
model.evaluator = model.evaluator.trace_copy()
chosen = None
for i, key in enumerate(self.keys):
value_choice: ValueChoice = model.evaluator.trace_kwargs[key]
if i == 0:
# i == 0 is needed here because there can be candidates of "None"
chosen = self.choice(value_choice.candidates)
# get the real chosen value after "access"
model.evaluator.trace_kwargs[key] = value_choice.access(chosen)
return model
value_choice_decisions = {}
for mutation in model.history:
if isinstance(mutation.mutator, EvaluatorValueChoiceLeafMutator):
value_choice_decisions[mutation.mutator.label] = mutation.samples[0]
result = {}
# for each argument that is a composition of value choices,
# find all the leaf value choices from the mutation history
# and compute the final result
for key, param in model.evaluator.trace_kwargs.items():
if isinstance(param, ValueChoiceX):
leaf_node_values = [value_choice_decisions[choice.label] for choice in param.inner_choices()]
result[key] = param.evaluate(leaf_node_values)
model.evaluator.trace_kwargs.update(result)
def process_evaluator_mutations(evaluator: Evaluator, existing_mutators: List[Mutator]) -> List[Mutator]:
# gather all the value choices in the kwargs of the evaluator into a list
# `existing_mutators` can be mutators generated from `model`
if not is_traceable(evaluator):
return []
mutator_candidates = {}
mutator_keys = defaultdict(list)
for key, param in evaluator.trace_kwargs.items():
if isinstance(param, ValueChoice):
if isinstance(param, ValueChoiceX):
for choice in param.inner_choices():
# merge duplicate labels
for mutator in existing_mutators:
if mutator.name == param.label:
raise ValueError(f'Found duplicated labels for mutators {param.label}. When two mutators have the same name, '
'they would share choices. However, sharing choices between model and evaluator is not yet supported.')
if param.label in mutator_candidates and mutator_candidates[param.label] != param.candidates:
raise ValueError(f'Duplicate labels for evaluator ValueChoice {param.label}. They should share choices.'
f'But their candidate list is not equal: {mutator_candidates[param.label][1]} vs. {param.candidates}')
mutator_keys[param.label].append(key)
mutator_candidates[param.label] = param.candidates
if mutator.name == choice.label:
raise ValueError(
f'Found duplicated labels "{choice.label}". When two value choices have the same name, '
'they would share choices. However, sharing choices between model and evaluator is not yet supported.'
)
if choice.label in mutator_candidates and mutator_candidates[choice.label] != choice.candidates:
raise ValueError(
f'Duplicate labels for evaluator ValueChoice {choice.label}. They should share choices. '
f'But their candidate lists are not equal: {mutator_candidates[choice.label]} vs. {choice.candidates}'
)
mutator_keys[choice.label].append(key)
mutator_candidates[choice.label] = choice.candidates
mutators = []
for key in mutator_keys:
mutators.append(EvaluatorValueChoiceMutator(mutator_keys[key], key))
for label in mutator_keys:
mutators.append(EvaluatorValueChoiceLeafMutator(mutator_candidates[label], label))
if mutators:
# one last mutator to actually apply the mutations
mutators.append(EvaluatorValueChoiceMutator())
return mutators
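# An illustrative outcome (argument names made up): for
#   FunctionalEvaluator(foo, lr=ValueChoice([1e-3, 1e-2], label='lr') * 2)
# this returns one EvaluatorValueChoiceLeafMutator for label 'lr', plus a final
# EvaluatorValueChoiceMutator that writes the evaluated `lr * 2` back into the
# evaluator's trace_kwargs.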
......@@ -359,13 +439,3 @@ def _group_by_label(nodes: List[Node]) -> List[List[Node]]:
result[label] = []
result[label].append(node)
return list(result.values())
def _group_parameters_by_label(nodes: List[Tuple[Node, str]]) -> List[List[Tuple[Node, str]]]:
result = {}
for node, argname in nodes:
label = node.operation.parameters[argname].label
if label not in result:
result[label] = []
result[label].append((node, argname))
return list(result.values())
import math
import random
import unittest
from collections import Counter
import pytest
import nni.retiarii.nn.pytorch as nn
import torch
import torch.nn.functional as F
......@@ -50,7 +53,19 @@ class MutableConv(nn.Module):
return self.conv2(x)
def _apply_all_mutators(model, mutators, samplers):
if not isinstance(samplers, list):
samplers = [samplers for _ in range(len(mutators))]
assert len(samplers) == len(mutators)
model_new = model
for mutator, sampler in zip(mutators, samplers):
model_new = mutator.bind_sampler(sampler).apply(model_new)
return model_new
class GraphIR(unittest.TestCase):
# graph engine will have an extra mutator for parameter choices
value_choice_incr = 1
def _convert_to_ir(self, model):
script_module = torch.jit.script(model)
......@@ -220,7 +235,7 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
......@@ -240,16 +255,16 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(len(mutators), self.value_choice_incr + 1)
samplers = [EnumerateSampler() for _ in range(len(mutators))]
model1 = _apply_all_mutators(model, mutators, samplers)
model2 = _apply_all_mutators(model, mutators, samplers)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 5, 1, 1]))
def test_value_choice_as_parameter(self):
def test_value_choice_as_two_parameters(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -260,13 +275,14 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2)
mutators[0].bind_sampler(EnumerateSampler())
mutators[1].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 2 + self.value_choice_incr)
samplers = [EnumerateSampler() for _ in range(len(mutators))]
model1 = _apply_all_mutators(model, mutators, samplers)
model2 = _apply_all_mutators(model, mutators, samplers)
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(model1)(input).size(),
torch.Size([1, 6, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(mutators[1].apply(mutators[0].apply(model)))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(model2)(input).size(),
torch.Size([1, 8, 1, 1]))
def test_value_choice_as_parameter_shared(self):
......@@ -281,10 +297,10 @@ class GraphIR(unittest.TestCase):
return self.conv1(x) + self.conv2(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutator = mutators[0].bind_sampler(EnumerateSampler())
model1 = mutator.apply(model)
model2 = mutator.apply(model)
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
sampler = EnumerateSampler()
model1 = _apply_all_mutators(model, mutators, sampler)
model2 = _apply_all_mutators(model, mutators, sampler)
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 5, 5)).size(),
torch.Size([1, 6, 5, 5]))
self.assertEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 5, 5)).size(),
......@@ -323,13 +339,11 @@ class GraphIR(unittest.TestCase):
return self.linear(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 3)
self.assertEqual(len(mutators), 3 + self.value_choice_incr)
sz_counter = Counter()
sampler = RandomSampler()
for i in range(100):
model_new = model
for mutator in mutators:
model_new = mutator.bind_sampler(sampler).apply(model_new)
model_new = _apply_all_mutators(model, mutators, sampler)
sz_counter[self._get_converted_pytorch_model(model_new)(torch.randn(1, 3)).size(1)] += 1
self.assertEqual(len(sz_counter), 4)
......@@ -375,7 +389,7 @@ class GraphIR(unittest.TestCase):
self.assertGreater(failed_count, 0)
self.assertLess(failed_count, 30)
def test_valuechoice_access(self):
def test_valuechoice_getitem(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -387,12 +401,12 @@ class GraphIR(unittest.TestCase):
return self.conv(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
sampler = EnumerateSampler()
input = torch.randn(1, 3, 5, 5)
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
torch.Size([1, 6, 3, 3]))
self.assertEqual(self._get_converted_pytorch_model(mutators[0].apply(model))(input).size(),
self.assertEqual(self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, sampler))(input).size(),
torch.Size([1, 8, 1, 1]))
@model_wrapper
......@@ -411,12 +425,11 @@ class GraphIR(unittest.TestCase):
return self.conv1(torch.cat((x, x), 1))
model, mutators = self._get_model_with_mutators(Net2())
self.assertEqual(len(mutators), 1)
mutators[0].bind_sampler(EnumerateSampler())
self.assertEqual(len(mutators), 1 + self.value_choice_incr)
input = torch.randn(1, 3, 5, 5)
self._get_converted_pytorch_model(mutators[0].apply(model))(input)
self._get_converted_pytorch_model(_apply_all_mutators(model, mutators, EnumerateSampler()))(input)
def test_valuechoice_access_functional(self):
def test_valuechoice_getitem_functional(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -435,7 +448,7 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_access_functional_expression(self):
def test_valuechoice_getitem_functional_expression(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
......@@ -456,6 +469,43 @@ class GraphIR(unittest.TestCase):
self.assertEqual(self._get_converted_pytorch_model(model1)(torch.randn(1, 3, 3, 3)).size(), torch.Size([1, 3, 3, 3]))
self.assertAlmostEqual(self._get_converted_pytorch_model(model2)(torch.randn(1, 3, 3, 3)).abs().sum().item(), 0)
def test_valuechoice_multi(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
choice1 = nn.ValueChoice([{"in": 1, "out": 3}, {"in": 2, "out": 6}, {"in": 3, "out": 9}])
choice2 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
choice3 = nn.ValueChoice([2.5, 3.0, 3.5], label='multi')
self.conv1 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice2), 1)
self.conv2 = nn.Conv2d(choice1["in"], round(choice1["out"] * choice3), 1)
def forward(self, x):
return self.conv1(x) + self.conv2(x)
model, mutators = self._get_model_with_mutators(Net())
self.assertEqual(len(mutators), 2 + self.value_choice_incr)
samplers = [EnumerateSampler()] + [RandomSampler() for _ in range(self.value_choice_incr + 1)]
for i in range(10):
model_new = _apply_all_mutators(model, mutators, samplers)
result = self._get_converted_pytorch_model(model_new)(torch.randn(1, i % 3 + 1, 3, 3))
self.assertIn(result.size(), [torch.Size([1, round((i % 3 + 1) * 3 * k), 3, 3]) for k in [2.5, 3.0, 3.5]])
def test_valuechoice_inconsistent_label(self):
@model_wrapper
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(3, nn.ValueChoice([3, 5], label='a'), 1)
self.conv2 = nn.Conv2d(3, nn.ValueChoice([3, 6], label='a'), 1)
def forward(self, x):
return torch.cat([self.conv1(x), self.conv2(x)], 1)
with pytest.raises(AssertionError):
self._get_model_with_mutators(Net())
def test_repeat(self):
class AddOne(nn.Module):
def forward(self, x):
......@@ -645,6 +695,9 @@ class GraphIR(unittest.TestCase):
class Python(GraphIR):
# Python engine doesn't have the extra mutator
value_choice_incr = 0
def _get_converted_pytorch_model(self, model_ir):
mutation = {mut.mutator.label: _unpack_if_only_one(mut.samples) for mut in model_ir.history}
with ContextStack('fixed', mutation):
......@@ -661,10 +714,10 @@ class Python(GraphIR):
def test_value_choice_in_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional(self): ...
def test_valuechoice_getitem_functional(self): ...
@unittest.skip
def test_valuechoice_access_functional_expression(self): ...
def test_valuechoice_getitem_functional_expression(self): ...
def test_cell_loose_end(self):
@model_wrapper
......@@ -744,6 +797,95 @@ class Python(GraphIR):
class Shared(unittest.TestCase):
# Tests of this kind are general across execution engines
def test_value_choice_api_purely(self):
a = nn.ValueChoice([1, 2], label='a')
b = nn.ValueChoice([3, 4], label='b')
c = nn.ValueChoice([5, 6], label='c')
d = a + b + 3 * c
for i, choice in enumerate(d.inner_choices()):
if i == 0:
assert choice.candidates == [1, 2]
elif i == 1:
assert choice.candidates == [3, 4]
elif i == 2:
assert choice.candidates == [5, 6]
assert d.evaluate([2, 3, 5]) == 20
a = nn.ValueChoice(['cat', 'dog'])
b = nn.ValueChoice(['milk', 'coffee'])
assert (a + b).evaluate(['dog', 'coffee']) == 'dogcoffee'
assert (a + 2 * b).evaluate(['cat', 'milk']) == 'catmilkmilk'
assert (3 - nn.ValueChoice([1, 2])).evaluate([1]) == 2
with pytest.raises(TypeError):
a + nn.ValueChoice([1, 3])
a = nn.ValueChoice([1, 17])
a = (abs(-a * 3) % 11) ** 5
assert 'abs' in repr(a)
with pytest.raises(ValueError):
a.evaluate([42])
assert a.evaluate([17]) == 7 ** 5
a = round(7 / nn.ValueChoice([2, 5]))
assert a.evaluate([2]) == 4
a = ~(77 ^ (nn.ValueChoice([1, 4]) & 5))
assert a.evaluate([4]) == ~(77 ^ (4 & 5))
a = nn.ValueChoice([5, 3]) * nn.ValueChoice([6.5, 7.5])
assert math.floor(a.evaluate([5, 7.5])) == int(5 * 7.5)
a = nn.ValueChoice([1, 3])
b = nn.ValueChoice([2, 4])
with pytest.raises(RuntimeError):
min(a, b)
with pytest.raises(RuntimeError):
if a < b:
...
assert nn.ValueChoice.min(a, b).evaluate([3, 2]) == 2
assert nn.ValueChoice.max(a, b).evaluate([3, 2]) == 3
assert nn.ValueChoice.max(1, 2, 3) == 3
assert nn.ValueChoice.max([1, 3, 2]) == 3
assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([3]) == 'b'
assert nn.ValueChoice.condition(nn.ValueChoice([2, 3]) <= 2, 'a', 'b').evaluate([2]) == 'a'
with pytest.raises(RuntimeError):
assert int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
assert nn.ValueChoice.to_int(nn.ValueChoice([2.5, 3.5])).evaluate([2.5]) == 2
assert nn.ValueChoice.to_float(nn.ValueChoice(['2.5', '3.5'])).evaluate(['3.5']) == 3.5
def test_make_divisible(self):
def make_divisible(value, divisor, min_value=None, min_ratio=0.9):
if min_value is None:
min_value = divisor
new_value = nn.ValueChoice.max(min_value, nn.ValueChoice.to_int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
return nn.ValueChoice.condition(new_value < min_ratio * value, new_value + divisor, new_value)
def original_make_divisible(value, divisor, min_value=None, min_ratio=0.9):
if min_value is None:
min_value = divisor
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than (1-min_ratio).
if new_value < min_ratio * value:
new_value += divisor
return new_value
values = [4, 8, 16, 32, 64, 128]
divisors = [2, 3, 5, 7, 15]
with pytest.raises(RuntimeError):
original_make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
result = make_divisible(nn.ValueChoice(values, label='value'), nn.ValueChoice(divisors, label='divisor'))
for value in values:
for divisor in divisors:
lst = [value if choice.label == 'value' else divisor for choice in result.inner_choices()]
assert result.evaluate(lst) == original_make_divisible(value, divisor)
def test_valuechoice_in_evaluator(self):
def foo():
pass
......@@ -753,28 +895,28 @@ class Shared(unittest.TestCase):
evaluator = FunctionalEvaluator(foo, t=1, x=ValueChoice([1, 2]), y=ValueChoice([3, 4]))
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 2
assert len(mutators) == 3
init_model = Model(_internal=True)
init_model.evaluator = evaluator
sampler = EnumerateSampler()
model = mutators[0].bind_sampler(sampler).apply(init_model)
samplers = [EnumerateSampler() for _ in range(3)]
model = _apply_all_mutators(init_model, mutators, samplers)
assert model.evaluator.trace_kwargs['x'] == 1
model = mutators[0].bind_sampler(sampler).apply(init_model)
model = _apply_all_mutators(init_model, mutators, samplers)
assert model.evaluator.trace_kwargs['x'] == 2
# share label
evaluator = FunctionalEvaluator(foo, t=ValueChoice([1, 2], label='x'), x=ValueChoice([1, 2], label='x'))
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 1
assert len(mutators) == 2
# getitem
choice = ValueChoice([{"a": 1, "b": 2}, {"a": 3, "b": 4}])
evaluator = FunctionalEvaluator(foo, t=1, x=choice['a'], y=choice['b'])
mutators = process_evaluator_mutations(evaluator, [])
assert len(mutators) == 1
assert len(mutators) == 2
init_model = Model(_internal=True)
init_model.evaluator = evaluator
sampler = RandomSampler()
for _ in range(10):
model = mutators[0].bind_sampler(sampler).apply(init_model)
model = _apply_all_mutators(init_model, mutators, sampler)
assert (model.evaluator.trace_kwargs['x'], model.evaluator.trace_kwargs['y']) in [(1, 2), (3, 4)]