choice.py 43.9 KB
Newer Older
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
import itertools
5
import math
6
import operator
7
import warnings
8
9
10
from typing import (Any, Callable, Dict, Generic, Iterable, Iterator, List,
                    NoReturn, Optional, Sequence, SupportsRound, TypeVar,
                    Union, cast)
11
12
13

import torch
import torch.nn as nn
14
from nni.common.hpo_utils import ParameterSpec
15
from nni.common.serializer import Translatable
16
from nni.nas.utils import STATE_DICT_PY_MAPPING_PARTIAL, ModelNamespace, NoContextError, basic_unit
17

18
from .mutation_utils import Mutable, generate_new_label, get_fixed_value
19

20
21
22
23
24
25
26
27
28
29
# Public API of this module, grouped by purpose.
__all__ = [
    # APIs
    'LayerChoice',
    'InputChoice',
    'ValueChoice',
    'ModelParameterChoice',
    'Placeholder',

    # Fixed module
    'ChosenInputs',

    # Type utils
    'ReductionType',
    'MaybeChoice',
    'ChoiceOf',
]
36
37


Yuge Zhang's avatar
Yuge Zhang committed
38
class LayerChoice(Mutable):
    """
    Layer choice selects one of the ``candidates``, then apply it on inputs and return results.

    It allows users to put several candidate operations (e.g., PyTorch modules), one of them is chosen in each explored model.

    *New in v2.2:* Layer choice can be nested.

    Parameters
    ----------
    candidates : list of nn.Module or OrderedDict
        A module list to be selected from.
    prior : list of float
        Prior distribution used in random sampling.
    label : str
        Identifier of the layer choice.

    Attributes
    ----------
    length : int
        Deprecated. Number of ops to choose from. ``len(layer_choice)`` is recommended.
    names : list of str
        Names of candidates.
    choices : list of Module
        Deprecated. A list of all candidate modules in the layer choice module.
        ``list(layer_choice)`` is recommended, which will serve the same purpose.

    Examples
    --------

    ::

        # import nni.retiarii.nn.pytorch as nn
        # declared in `__init__` method
        self.layer = nn.LayerChoice([
            ops.PoolBN('max', channels, 3, stride, 1),
            ops.SepConv(channels, channels, 3, stride, 1),
            nn.Identity()
        ])
        # invoked in `forward` method
        out = self.layer(x)

    Notes
    -----
    ``candidates`` can be a list of modules or a ordered dict of named modules, for example,

    .. code-block:: python

        self.op_choice = LayerChoice(OrderedDict([
            ("conv3x3", nn.Conv2d(3, 16, 128)),
            ("conv5x5", nn.Conv2d(5, 16, 128)),
            ("conv7x7", nn.Conv2d(7, 16, 128))
        ]))

    Elements in layer choice can be modified or deleted. Use ``del self.op_choice["conv5x5"]`` or
    ``self.op_choice[1] = nn.Conv3d(...)``. Adding more choices is not supported yet.
    """

    # FIXME: prior is designed but not supported yet

    @classmethod
    def create_fixed_module(cls, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                            label: Optional[str] = None, **kwargs):
        """Resolve the fixed decision recorded under ``label`` and return the chosen candidate module."""
        chosen = get_fixed_value(label)
        if isinstance(candidates, list):
            # list candidates are addressed by integer index
            result = candidates[int(chosen)]
        else:
            # dict candidates are addressed by name
            result = candidates[chosen]

        # map the named hierarchies to support weight inheritance for python engine
        if hasattr(result, STATE_DICT_PY_MAPPING_PARTIAL):
            # handle cases where layer choices are nested
            # already has a mapping, will merge with it
            prev_mapping = getattr(result, STATE_DICT_PY_MAPPING_PARTIAL)
            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {k: f'{chosen}.{v}' for k, v in prev_mapping.items()})
        else:
            # "result" needs to know where to map itself.
            # Ideally, we should put a _mapping_ in the module where "result" is located,
            # but it's impossible to put mapping into parent module here.
            setattr(result, STATE_DICT_PY_MAPPING_PARTIAL, {'__self__': str(chosen)})
        return result

    def __init__(self, candidates: Union[Dict[str, nn.Module], List[nn.Module]], *,
                 prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
        super(LayerChoice, self).__init__()
        # fix: the deprecation messages below contained no placeholders and were
        # needlessly written as f-strings (flake8 F541); plain literals now.
        if 'key' in kwargs:
            warnings.warn('"key" is deprecated. Assuming label.')
            label = kwargs['key']
        if 'return_mask' in kwargs:
            warnings.warn('"return_mask" is deprecated. Ignoring...')
        if 'reduction' in kwargs:
            warnings.warn('"reduction" is deprecated. Ignoring...')
        self.candidates = candidates
        # default to a uniform prior when none is given
        self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
        assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
        self._label = generate_new_label(label)

        self.names = []
        if isinstance(candidates, dict):
            for name, module in candidates.items():
                assert name not in ["length", "reduction", "return_mask", "_key", "key", "names"], \
                    "Please don't use a reserved name '{}' for your module.".format(name)
                self.add_module(name, module)
                self.names.append(name)
        elif isinstance(candidates, list):
            # list candidates are registered under their stringified index
            for i, module in enumerate(candidates):
                self.add_module(str(i), module)
                self.names.append(str(i))
        else:
            raise TypeError("Unsupported candidates type: {}".format(type(candidates)))
        self._first_module = cast(nn.Module, self._modules[self.names[0]])  # to make the dummy forward meaningful

    @property
    def label(self):
        """Identifier of this layer choice (shared by choices that should mutate together)."""
        return self._label

    def __getitem__(self, idx: Union[int, str]) -> nn.Module:
        if isinstance(idx, str):
            return cast(nn.Module, self._modules[idx])
        return cast(nn.Module, list(self)[idx])

    def __setitem__(self, idx, module):
        key = idx if isinstance(idx, str) else self.names[idx]
        return setattr(self, key, module)

    def __delitem__(self, idx):
        if isinstance(idx, slice):
            for key in self.names[idx]:
                delattr(self, key)
        else:
            if isinstance(idx, str):
                key, idx = idx, self.names.index(idx)
            else:
                key = self.names[idx]
            delattr(self, key)
        del self.names[idx]

    def __len__(self):
        return len(self.names)

    def __iter__(self):
        return map(lambda name: self._modules[name], self.names)

    def forward(self, x):
        """
        The forward of layer choice is simply running the first candidate module.
        It shouldn't be called directly by users in most cases.
        """
        warnings.warn('You should not run forward of this module directly.')
        return self._first_module(x)

    def __repr__(self):
        return f'LayerChoice({self.candidates}, label={repr(self.label)})'

192

193
194
195
196
197
198
199
200
# ``Literal`` only exists in the standard ``typing`` module from Python 3.8 on;
# fall back to the ``typing_extensions`` backport on older interpreters.
try:
    from typing import Literal
except ImportError:
    from typing_extensions import Literal

# The set of reduction policies accepted by ``InputChoice`` / ``ChosenInputs``.
ReductionType = Literal['mean', 'concat', 'sum', 'none']


Yuge Zhang's avatar
Yuge Zhang committed
201
class InputChoice(Mutable):
    """
    Input choice selects ``n_chosen`` inputs from ``choose_from`` (contains ``n_candidates`` keys).

    It is mainly for choosing (or trying) different connections. It takes several tensors and chooses ``n_chosen`` tensors from them.
    When specific inputs are chosen, ``InputChoice`` will become :class:`ChosenInputs`.

    Use ``reduction`` to specify how chosen inputs are reduced into one output. A few options are:

    * ``none``: do nothing and return the list directly.
    * ``sum``: summing all the chosen inputs.
    * ``mean``: taking the average of all chosen inputs.
    * ``concat``: concatenate all chosen inputs at dimension 1.

    We don't support customizing reduction yet.

    Parameters
    ----------
    n_candidates : int
        Number of inputs to choose from. It is required.
    n_chosen : int
        Recommended inputs to choose. If None, mutator is instructed to select any.
    reduction : str
        ``mean``, ``concat``, ``sum`` or ``none``.
    prior : list of float
        Prior distribution used in random sampling.
    label : str
        Identifier of the input choice.

    Examples
    --------
    ::

        # import nni.retiarii.nn.pytorch as nn
        # declared in `__init__` method
        self.input_switch = nn.InputChoice(n_chosen=1)
        # invoked in `forward` method, choose one from the three
        out = self.input_switch([tensor1, tensor2, tensor3])
    """

    @classmethod
    def create_fixed_module(cls, n_candidates: int, n_chosen: Optional[int] = 1,
                            reduction: ReductionType = 'sum', *,
                            prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
        """Resolve the fixed decision recorded under ``label`` into a :class:`ChosenInputs`."""
        return ChosenInputs(get_fixed_value(label), reduction=reduction)

    def __init__(self, n_candidates: int, n_chosen: Optional[int] = 1,
                 reduction: str = 'sum', *,
                 prior: Optional[List[float]] = None, label: Optional[str] = None, **kwargs):
        super(InputChoice, self).__init__()
        # fix: messages contained no placeholders, so the f-prefixes (F541) are dropped
        if 'key' in kwargs:
            warnings.warn('"key" is deprecated. Assuming label.')
            label = kwargs['key']
        if 'return_mask' in kwargs:
            warnings.warn('"return_mask" is deprecated. Ignoring...')
        if 'choose_from' in kwargs:
            # fix: this message previously (and wrongly) claimed "reduction" was deprecated
            warnings.warn('"choose_from" is deprecated. Ignoring...')
        self.n_candidates = n_candidates
        self.n_chosen = n_chosen
        self.reduction = reduction
        # default to a uniform prior when none is given
        self.prior = prior or [1 / n_candidates for _ in range(n_candidates)]
        assert self.reduction in ['mean', 'concat', 'sum', 'none']
        self._label = generate_new_label(label)

    @property
    def label(self):
        """Identifier of this input choice (shared by choices that should mutate together)."""
        return self._label

    def forward(self, candidate_inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        The forward of input choice is simply the first item of ``candidate_inputs``.
        It shouldn't be called directly by users in most cases.
        """
        warnings.warn('You should not run forward of this module directly.')
        return candidate_inputs[0]

    def __repr__(self):
        return f'InputChoice(n_candidates={self.n_candidates}, n_chosen={self.n_chosen}, ' \
            f'reduction={repr(self.reduction)}, label={repr(self.label)})'

281

282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
class ChosenInputs(nn.Module):
    """
    A module that chooses from a tensor list and outputs a reduced tensor.
    The already-chosen version of InputChoice.

    When forward, ``chosen`` will be used to select inputs from ``candidate_inputs``,
    and ``reduction`` will be used to choose from those inputs to form a tensor.

    Attributes
    ----------
    chosen : list of int
        Indices of chosen inputs.
    reduction : ``mean`` | ``concat`` | ``sum`` | ``none``
        How to reduce the inputs when multiple are selected.
    """

    def __init__(self, chosen: Union[List[int], int], reduction: 'ReductionType'):
        super().__init__()
        # normalize a bare index into a one-element list
        self.chosen = [chosen] if not isinstance(chosen, list) else chosen
        self.reduction = reduction

    def forward(self, candidate_inputs):
        """
        Compute the reduced input based on ``chosen`` and ``reduction``.
        """
        picked = [candidate_inputs[index] for index in self.chosen]
        return self._tensor_reduction(self.reduction, picked)

    def _tensor_reduction(self, reduction_type, tensor_list):
        # 'none' hands the (possibly empty) list back untouched
        if reduction_type == 'none':
            return tensor_list
        if not tensor_list:
            return None  # empty. return None for now
        if len(tensor_list) == 1:
            return tensor_list[0]
        reducers = {
            'sum': lambda tensors: sum(tensors),
            'mean': lambda tensors: sum(tensors) / len(tensors),
            'concat': lambda tensors: torch.cat(tensors, dim=1),
        }
        try:
            return reducers[reduction_type](tensor_list)
        except KeyError:
            raise ValueError(f'Unrecognized reduction policy: "{reduction_type}"') from None


# the code in ValueChoice can be generated with this codegen
# this is not done online because I want to have type-hint supports
# $ python -c "from nni.retiarii.nn.pytorch.api import _valuechoice_codegen; _valuechoice_codegen(_internal=True)"
def _valuechoice_codegen(*, _internal: bool = False):
    if not _internal:
        raise RuntimeError("This method is set to be internal. Please don't use it directly.")
    MAPPING = {
        # unary
        'neg': '-', 'pos': '+', 'invert': '~',
        # binary
        'add': '+', 'sub': '-', 'mul': '*', 'matmul': '@',
        'truediv': '//', 'floordiv': '/', 'mod': '%',
        'lshift': '<<', 'rshift': '>>',
        'and': '&', 'xor': '^', 'or': '|',
339
        # no reverse
340
341
342
343
344
345
346
        'lt': '<', 'le': '<=', 'eq': '==',
        'ne': '!=', 'ge': '>=', 'gt': '>',
        # NOTE
        # Currently we don't support operators like __contains__ (b in a),
        # Might support them in future when we actually need them.
    }

347
    binary_template = """    def __{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
348
349
        return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [self, other])"""

350
    binary_r_template = """    def __r{op}__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
351
352
        return ValueChoiceX(operator.{opt}, '{{}} {sym} {{}}', [other, self])"""

353
354
    unary_template = """    def __{op}__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
        return cast(ChoiceOf[_value], ValueChoiceX(operator.{op}, '{sym}{{}}', [self]))"""
355
356
357
358
359
360
361
362
363
364
365

    for op, sym in MAPPING.items():
        if op in ['neg', 'pos', 'invert']:
            print(unary_template.format(op=op, sym=sym) + '\n')
        else:
            opt = op + '_' if op in ['and', 'or'] else op
            print(binary_template.format(op=op, opt=opt, sym=sym) + '\n')
            if op not in ['lt', 'le', 'eq', 'ne', 'ge', 'gt']:
                print(binary_r_template.format(op=op, opt=opt, sym=sym) + '\n')


366
367
368
369
370
371
372
373
# Type variables shared by the value-choice helpers below:
#   _func  -- the type of a function passed through a decorator unchanged,
#   _cand  -- the candidate (element) type carried by a value choice,
#   _value -- a generic value type used in operator signatures.
_func = TypeVar('_func')
_cand = TypeVar('_cand')
_value = TypeVar('_value')


def _valuechoice_staticmethod_helper(orig_func: _func) -> _func:
    if orig_func.__doc__ is not None:
        orig_func.__doc__ += """
374
375
376
377
378
379
380
381
382
        Notes
        -----
        This function performs lazy evaluation.
        Only the expression will be recorded when the function is called.
        The real evaluation happens when the inner value choice has determined its final decision.
        If no value choice is contained in the parameter list, the evaluation will be intermediate."""
    return orig_func


383
class ValueChoiceX(Generic[_cand], Translatable, nn.Module):
    """Internal API. Implementation note:

    The transformed (X) version of value choice.
    It can be the result of composition (transformation) of one or several value choices. For example,

    .. code-block:: python

        nn.ValueChoice([1, 2]) + nn.ValueChoice([3, 4]) + 5

    The instance of base class cannot be created directly. Instead, they should be only the result of transformation of value choice.
    Therefore, there is no need to implement ``create_fixed_module`` in this class, because,
    1. For python-engine, value choice itself has create fixed module. Consequently, the transformation is born to be fixed.
    2. For graph-engine, it uses evaluate to calculate the result.

    Potentially, we have to implement the evaluation logic in oneshot algorithms. I believe we can postpone the discussion till then.

    This class is implemented as a ``nn.Module`` so that it can be scanned by python engine / torchscript.
    """

    def __init__(self, function: Callable[..., _cand] = cast(Callable[..., _cand], None),
                 repr_template: str = cast(str, None),
                 arguments: List[Any] = cast('List[MaybeChoice[_cand]]', None),
                 dry_run: bool = True):
        super().__init__()

        if function is None:
            # this case is a hack for ValueChoice subclass
            # it will reach here only because ``__init__`` in ``nn.Module`` is useful.
            return

        self.function = function
        self.repr_template = repr_template
        self.arguments = arguments

        # a composition must contain at least one actual value choice
        assert any(isinstance(arg, ValueChoiceX) for arg in self.arguments)

        if dry_run:
            # for sanity check
            self.dry_run()

    def forward(self) -> None:
        raise RuntimeError('You should never call forward of the composition of a value-choice.')

    def inner_choices(self) -> Iterable['ValueChoice']:
        """
        Return a generator of all leaf value choices.
        Useful for composition of value choices.
        No deduplication on labels. Mutators should take care.
        """
        for arg in self.arguments:
            if isinstance(arg, ValueChoiceX):
                yield from arg.inner_choices()

    def dry_run(self) -> _cand:
        """
        Dry run the value choice to get one of its possible evaluation results.
        """
        # values are not used
        return self._evaluate(iter([]), True)

    def all_options(self) -> Iterable[_cand]:
        """Explore all possibilities of a value choice.
        """
        # Record all inner choices: label -> candidates, no duplicates.
        dedup_inner_choices: Dict[str, List[_cand]] = {}
        # All labels of leaf nodes on tree, possibly duplicates.
        all_labels: List[str] = []

        for choice in self.inner_choices():
            all_labels.append(choice.label)
            if choice.label in dedup_inner_choices:
                if choice.candidates != dedup_inner_choices[choice.label]:
                    # check for choice with the same label
                    raise ValueError(f'"{choice.candidates}" is not equal to "{dedup_inner_choices[choice.label]}", '
                                     f'but they share the same label: {choice.label}')
            else:
                dedup_inner_choices[choice.label] = choice.candidates

        dedup_labels, dedup_candidates = list(dedup_inner_choices.keys()), list(dedup_inner_choices.values())

        # enumerate the Cartesian product over deduplicated labels,
        # then broadcast each assignment back to the (possibly duplicated) leaves
        for chosen in itertools.product(*dedup_candidates):
            chosen = dict(zip(dedup_labels, chosen))
            yield self.evaluate([chosen[label] for label in all_labels])

    def evaluate(self, values: Iterable[_cand]) -> _cand:
        """
        Evaluate the result of this group.
        ``values`` should in the same order of ``inner_choices()``.
        """
        return self._evaluate(iter(values), False)

    def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
        # "values" iterates in the recursion
        eval_args = []
        for arg in self.arguments:
            if isinstance(arg, ValueChoiceX):
                # recursive evaluation
                eval_args.append(arg._evaluate(values, dry_run))
                # the recursion will stop when it hits a leaf node (value choice)
                # the implementation is in `ValueChoice`
            else:
                # constant value
                eval_args.append(arg)
        return self.function(*eval_args)

    def _translate(self):
        """
        Try to behave like one of its candidates when used in ``basic_unit``.
        """
        return self.dry_run()

    def __repr__(self) -> str:
        reprs = []
        for arg in self.arguments:
            if isinstance(arg, ValueChoiceX) and not isinstance(arg, ValueChoice):
                reprs.append('(' + repr(arg) + ')')  # add parenthesis for operator priority
            else:
                reprs.append(repr(arg))
        return self.repr_template.format(*reprs)

    # the following are a series of methods to create "ValueChoiceX"
    # which is a transformed version of value choice
    # https://docs.python.org/3/reference/datamodel.html#special-method-names

    # Special operators that can be useful in place of built-in conditional operators.
    @staticmethod
    @_valuechoice_staticmethod_helper
    def to_int(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[int]':
        """
        Convert a ``ValueChoice`` to an integer.
        """
        if isinstance(obj, ValueChoiceX):
            return ValueChoiceX(int, 'int({})', [obj])
        return int(obj)

    @staticmethod
    @_valuechoice_staticmethod_helper
    def to_float(obj: 'MaybeChoice[Any]') -> 'MaybeChoice[float]':
        """
        Convert a ``ValueChoice`` to a float.
        """
        if isinstance(obj, ValueChoiceX):
            return ValueChoiceX(float, 'float({})', [obj])
        return float(obj)

    @staticmethod
    @_valuechoice_staticmethod_helper
    def condition(pred: 'MaybeChoice[bool]',
                  true: 'MaybeChoice[_value]',
                  false: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
        """
        Return ``true`` if the predicate ``pred`` is true else ``false``.

        Examples
        --------
        >>> ValueChoice.condition(ValueChoice([1, 2]) > ValueChoice([0, 3]), 2, 1)
        """
        if any(isinstance(obj, ValueChoiceX) for obj in [pred, true, false]):
            return ValueChoiceX(lambda t, c, f: t if c else f, '{} if {} else {}', [true, pred, false])
        return true if pred else false

    @staticmethod
    @_valuechoice_staticmethod_helper
    def max(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
            *args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
        """
        Returns the maximum value from a list of value choices.
        The usage should be similar to Python's built-in value choices,
        where the parameters could be an iterable, or at least two arguments.
        """
        if not args:
            # single argument: must be an iterable of items to compare
            if not isinstance(arg0, Iterable):
                raise TypeError('Expect more than one items to compare max')
            return cast(MaybeChoice[_value], ValueChoiceX.max(*list(arg0)))
        lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
        if any(isinstance(obj, ValueChoiceX) for obj in lst):
            return ValueChoiceX(max, 'max({})', lst)
        return max(cast(Any, lst))

    @staticmethod
    @_valuechoice_staticmethod_helper
    def min(arg0: Union[Iterable['MaybeChoice[_value]'], 'MaybeChoice[_value]'],
            *args: 'MaybeChoice[_value]') -> 'MaybeChoice[_value]':
        """
        Returns the minunum value from a list of value choices.
        The usage should be similar to Python's built-in value choices,
        where the parameters could be an iterable, or at least two arguments.
        """
        if not args:
            # single argument: must be an iterable of items to compare
            if not isinstance(arg0, Iterable):
                raise TypeError('Expect more than one items to compare min')
            return cast(MaybeChoice[_value], ValueChoiceX.min(*list(arg0)))
        lst = list(arg0) if isinstance(arg0, Iterable) else [arg0] + list(args)
        if any(isinstance(obj, ValueChoiceX) for obj in lst):
            return ValueChoiceX(min, 'min({})', lst)
        return min(cast(Any, lst))

    def __hash__(self):
        # this is required because we have implemented ``__eq__``
        return id(self)

    # NOTE:
    # Write operations are not supported. Reasons follow:
    # - Semantics are not clear. It can be applied to "all" the inner candidates, or only the chosen one.
    # - Implementation effort is too huge.
    # As a result, inplace operators like +=, *=, magic methods like `__getattr__` are not included in this list.

    def __getitem__(self: 'ChoiceOf[Any]', key: Any) -> 'ChoiceOf[Any]':
        return ValueChoiceX(lambda x, y: x[y], '{}[{}]', [self, key])

    # region implement int, float, round, trunc, floor, ceil
    # because I believe sometimes we need them to calculate #channels
    # `__int__` and `__float__` are not supported because `__int__` is required to return int.
    def __round__(self: 'ChoiceOf[SupportsRound[_value]]',
                  ndigits: Optional['MaybeChoice[int]'] = None) -> 'ChoiceOf[Union[int, SupportsRound[_value]]]':
        if ndigits is not None:
            return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({}, {})', [self, ndigits]))
        return cast(ChoiceOf[Union[int, SupportsRound[_value]]], ValueChoiceX(round, 'round({})', [self]))

    def __trunc__(self) -> NoReturn:
        raise RuntimeError("Try to use `ValueChoice.to_int()` instead of `math.trunc()` on value choices.")

    def __floor__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
        return ValueChoiceX(math.floor, 'math.floor({})', [self])

    def __ceil__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[int]':
        return ValueChoiceX(math.ceil, 'math.ceil({})', [self])

    def __index__(self) -> NoReturn:
        # https://docs.python.org/3/reference/datamodel.html#object.__index__
        raise RuntimeError("`__index__` is not allowed on ValueChoice, which means you can't "
                           "use int(), float(), complex(), range() on a ValueChoice. "
                           "To cast the type of ValueChoice, please try `ValueChoice.to_int()` or `ValueChoice.to_float()`.")

    def __bool__(self) -> NoReturn:
        raise RuntimeError('Cannot use bool() on ValueChoice. That means, using ValueChoice in a if-clause is illegal. '
                           'Please try methods like `ValueChoice.max(a, b)` to see whether that meets your needs.')
    # endregion

    # region the following code is generated with codegen (see above)
    # Annotated with "region" because I want to collapse them in vscode
    def __neg__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
        return cast(ChoiceOf[_value], ValueChoiceX(operator.neg, '-{}', [self]))

    def __pos__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
        return cast(ChoiceOf[_value], ValueChoiceX(operator.pos, '+{}', [self]))

    def __invert__(self: 'ChoiceOf[_value]') -> 'ChoiceOf[_value]':
        return cast(ChoiceOf[_value], ValueChoiceX(operator.invert, '~{}', [self]))

    def __add__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.add, '{} + {}', [self, other])

    def __radd__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.add, '{} + {}', [other, self])

    def __sub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.sub, '{} - {}', [self, other])

    def __rsub__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.sub, '{} - {}', [other, self])

    def __mul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.mul, '{} * {}', [self, other])

    def __rmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.mul, '{} * {}', [other, self])

    def __matmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.matmul, '{} @ {}', [self, other])

    def __rmatmul__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.matmul, '{} @ {}', [other, self])

    # fix: the repr templates of true division ('/') and floor division ('//')
    # were swapped with respect to the operators actually applied
    def __truediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.truediv, '{} / {}', [self, other])

    def __rtruediv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.truediv, '{} / {}', [other, self])

    def __floordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.floordiv, '{} // {}', [self, other])

    def __rfloordiv__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.floordiv, '{} // {}', [other, self])

    def __mod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.mod, '{} % {}', [self, other])

    def __rmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.mod, '{} % {}', [other, self])

    def __lshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.lshift, '{} << {}', [self, other])

    def __rlshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.lshift, '{} << {}', [other, self])

    def __rshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.rshift, '{} >> {}', [self, other])

    def __rrshift__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.rshift, '{} >> {}', [other, self])

    def __and__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.and_, '{} & {}', [self, other])

    def __rand__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.and_, '{} & {}', [other, self])

    def __xor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.xor, '{} ^ {}', [self, other])

    def __rxor__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.xor, '{} ^ {}', [other, self])

    def __or__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.or_, '{} | {}', [self, other])

    def __ror__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.or_, '{} | {}', [other, self])

    def __lt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.lt, '{} < {}', [self, other])

    def __le__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.le, '{} <= {}', [self, other])

    def __eq__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.eq, '{} == {}', [self, other])

    def __ne__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.ne, '{} != {}', [self, other])

    def __ge__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.ge, '{} >= {}', [self, other])

    def __gt__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(operator.gt, '{} > {}', [self, other])
    # endregion

    # __pow__, __divmod__, __abs__ are special ones.
    # Not easy to cover those cases with codegen.
    def __pow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
        if modulo is not None:
            return ValueChoiceX(pow, 'pow({}, {}, {})', [self, other, modulo])
        return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [self, other])

    def __rpow__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]', modulo: Optional['MaybeChoice[Any]'] = None) -> 'ChoiceOf[Any]':
        if modulo is not None:
            return ValueChoiceX(pow, 'pow({}, {}, {})', [other, self, modulo])
        return ValueChoiceX(lambda a, b: a ** b, '{} ** {}', [other, self])

    def __divmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(divmod, 'divmod({}, {})', [self, other])

    def __rdivmod__(self: 'ChoiceOf[Any]', other: 'MaybeChoice[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(divmod, 'divmod({}, {})', [other, self])

    def __abs__(self: 'ChoiceOf[Any]') -> 'ChoiceOf[Any]':
        return ValueChoiceX(abs, 'abs({})', [self])


747
748
# Public type aliases for annotating APIs that accept value choices:
# ``ChoiceOf[T]`` is an expression tree that evaluates to a ``T``;
# ``MaybeChoice[T]`` is either such an expression or a plain ``T``.
ChoiceOf = ValueChoiceX
MaybeChoice = Union[ValueChoiceX[_cand], _cand]
749
750


751
class ValueChoice(ValueChoiceX[_cand], Mutable):
    """
    ValueChoice is to choose one from ``candidates``. The most common use cases are:

    * Used as input arguments of :class:`~nni.retiarii.basic_unit`
      (i.e., modules in ``nni.retiarii.nn.pytorch`` and user-defined modules decorated with ``@basic_unit``).
    * Used as input arguments of evaluator (*new in v2.7*).

    It can be used in parameters of operators (i.e., a sub-module of the model): ::

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.conv = nn.Conv2d(3, nn.ValueChoice([32, 64]), kernel_size=nn.ValueChoice([3, 5, 7]))

            def forward(self, x):
                return self.conv(x)

    Or evaluator (only if the evaluator is :doc:`traceable </nas/serialization>`, e.g.,
    :class:`FunctionalEvaluator <nni.retiarii.evaluator.FunctionalEvaluator>`): ::

        def train_and_evaluate(model_cls, learning_rate):
            ...

        self.evaluator = FunctionalEvaluator(train_and_evaluate, learning_rate=nn.ValueChoice([1e-3, 1e-2, 1e-1]))

    Value choices supports arithmetic operators, which is particularly useful when searching for a network width multiplier: ::

        # init
        scale = nn.ValueChoice([1.0, 1.5, 2.0])
        self.conv1 = nn.Conv2d(3, round(scale * 16))
        self.conv2 = nn.Conv2d(round(scale * 16), round(scale * 64))
        self.conv3 = nn.Conv2d(round(scale * 64), round(scale * 256))

        # forward
        return self.conv3(self.conv2(self.conv1(x)))

    Or when kernel size and padding are coupled so as to keep the output size constant: ::

        # init
        ks = nn.ValueChoice([3, 5, 7])
        self.conv = nn.Conv2d(3, 16, kernel_size=ks, padding=(ks - 1) // 2)

        # forward
        return self.conv(x)

    Or when several layers are concatenated for a final layer. ::

        # init
        self.linear1 = nn.Linear(3, nn.ValueChoice([1, 2, 3], label='a'))
        self.linear2 = nn.Linear(3, nn.ValueChoice([4, 5, 6], label='b'))
        self.final = nn.Linear(nn.ValueChoice([1, 2, 3], label='a') + nn.ValueChoice([4, 5, 6], label='b'), 2)

        # forward
        return self.final(torch.cat([self.linear1(x), self.linear2(x)], 1))

    Some advanced operators are also provided, such as :meth:`ValueChoice.max` and :meth:`ValueChoice.cond`.

    .. tip::

        All the APIs have an optional argument called ``label``,
        mutations with the same label will share the same choice. A typical example is, ::

            self.net = nn.Sequential(
                nn.Linear(10, nn.ValueChoice([32, 64, 128], label='hidden_dim')),
                nn.Linear(nn.ValueChoice([32, 64, 128], label='hidden_dim'), 3)
            )

        Sharing the same value choice instance has the similar effect. ::

            class Net(nn.Module):
                def __init__(self):
                    super().__init__()
                    hidden_dim = nn.ValueChoice([128, 512])
                    self.fc = nn.Sequential(
                        nn.Linear(64, hidden_dim),
                        nn.Linear(hidden_dim, 10)
                    )

    .. warning::

        It looks as if a specific candidate has been chosen (e.g., how it looks like when you can put ``ValueChoice``
        as a parameter of ``nn.Conv2d``), but in fact it's a syntax sugar as because the basic units and evaluators
        do all the underlying works. That means, you cannot assume that ``ValueChoice`` can be used in the same way
        as its candidates. For example, the following usage will NOT work: ::

            self.blocks = []
            for i in range(nn.ValueChoice([1, 2, 3])):
                self.blocks.append(Block())

            # NOTE: instead you should probably write
            # self.blocks = nn.Repeat(Block(), (1, 3))

    Another use case is to initialize the values to choose from in init and call the module in forward to get the chosen value.
    Usually, this is used to pass a mutable value to a functional API like ``torch.xxx`` or ``nn.functional.xxx``.
    For example, ::

        class Net(nn.Module):
            def __init__(self):
                super().__init__()
                self.dropout_rate = nn.ValueChoice([0., 1.])

            def forward(self, x):
                return F.dropout(x, self.dropout_rate())

    Parameters
    ----------
    candidates : list
        List of values to choose from.
    prior : list of float
        Prior distribution to sample from.
    label : str
        Identifier of the value choice.
    """

    # FIXME: prior is designed but not supported yet

    @classmethod
    def create_fixed_module(cls, candidates: List[_cand], *, label: Optional[str] = None, **kwargs):
        """Return the frozen value for this choice (used when the architecture is fixed).

        Raises
        ------
        ValueError
            If the fixed value does not belong to ``candidates``.
        """
        value = get_fixed_value(label)
        if value not in candidates:
            raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
        return value

    def __init__(self, candidates: List[_cand], *, prior: Optional[List[float]] = None, label: Optional[str] = None):
        super().__init__()  # type: ignore
        self.candidates = candidates
        # Default to a uniform prior when none is provided.
        self.prior = prior or [1 / len(candidates) for _ in range(len(candidates))]
        assert abs(sum(self.prior) - 1) < 1e-5, 'Sum of prior distribution is not 1.'
        self._label = generate_new_label(label)

    @property
    def label(self):
        return self._label

    def forward(self):
        """
        The forward of value choice is simply the first value of ``candidates``.
        It shouldn't be called directly by users in most cases.
        """
        warnings.warn('You should not run forward of this module directly.')
        return self.candidates[0]

    def inner_choices(self) -> Iterable['ValueChoice']:
        # yield self because self is the only value choice here
        yield self

    def dry_run(self) -> _cand:
        # Mirrors forward: the first candidate stands in for the real choice.
        return self.candidates[0]

    def _evaluate(self, values: Iterator[_cand], dry_run: bool = False) -> _cand:
        """Consume one value from ``values`` and validate it against ``candidates``.

        With ``dry_run=True``, no value is consumed and the first candidate is returned.
        """
        if dry_run:
            return self.candidates[0]
        try:
            value = next(values)
        except StopIteration:
            raise ValueError(f'Value list {values} is exhausted when trying to get a chosen value of {self}.')
        if value not in self.candidates:
            raise ValueError(f'Value {value} does not belong to the candidates of {self}.')
        return value

    def __repr__(self):
        return f'ValueChoice({self.candidates}, label={repr(self.label)})'
914

915

916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
ValueType = TypeVar('ValueType')


class ModelParameterChoice:
    """ModelParameterChoice chooses one hyper-parameter from ``candidates``.

    .. attention::

       This API is internal, and does not guarantee forward-compatibility.

    It's quite similar to :class:`ValueChoice`, but unlike :class:`ValueChoice`,
    it always returns a fixed value, even at the construction of base model.

    This makes it highly flexible (e.g., can be used in for-loop, if-condition, as argument of any function). For example: ::

        self.has_auxiliary_head = ModelParameterChoice([False, True])
        # this will raise error if you use `ValueChoice`
        if self.has_auxiliary_head is True:  # or self.has_auxiliary_head
            self.auxiliary_head = Head()
        else:
            self.auxiliary_head = None
        print(type(self.has_auxiliary_head))  # <class 'bool'>

    The working mechanism of :class:`ModelParameterChoice` is that, it registers itself
    in the ``model_wrapper``, as a hyper-parameter of the model, and then returns the value specified with ``default``.
    At base model construction, the default value will be used (as a mocked hyper-parameter).
    In trial, the hyper-parameter selected by strategy will be used.

    Although flexible, we still recommend using :class:`ValueChoice` in favor of :class:`ModelParameterChoice`,
    because information are lost when using :class:`ModelParameterChoice` in exchange of its flexibility,
    making it incompatible with one-shot strategies and non-python execution engines.

    .. warning::

        :class:`ModelParameterChoice` can NOT be nested.

    .. tip::

        Although called :class:`ModelParameterChoice`, it's meant to tune hyper-parameter of architecture.
        It's NOT used to tune model-training hyper-parameters like ``learning_rate``.
        If you need to tune ``learning_rate``, please use :class:`ValueChoice` on arguments of :class:`nni.retiarii.Evaluator`.

    Parameters
    ----------
    candidates : list of any
        List of values to choose from.
    prior : list of float
        Prior distribution to sample from. Currently has no effect.
    default : Callable[[List[Any]], Any] or Any
        Function that selects one from ``candidates``, or a candidate.
        Use :meth:`ModelParameterChoice.FIRST` or :meth:`ModelParameterChoice.LAST` to take the first or last item.
        Default: :meth:`ModelParameterChoice.FIRST`
    label : str
        Identifier of the value choice.

    Warnings
    --------
    :class:`ModelParameterChoice` is incompatible with one-shot strategies and non-python execution engines.

    Sometimes, the same search space implemented **without** :class:`ModelParameterChoice` can be simpler, and explored
    with more types of search strategies. For example, the following usages are equivalent: ::

        # with ModelParameterChoice
        depth = nn.ModelParameterChoice(list(range(3, 10)))
        blocks = []
        for i in range(depth):
            blocks.append(Block())

        # w/o ModelParameterChoice
        blocks = Repeat(Block(), (3, 9))

    Examples
    --------
    Get a dynamic-shaped parameter. Because ``torch.zeros`` is not a basic unit, we can't use :class:`ValueChoice` on it.

    >>> parameter_dim = nn.ModelParameterChoice([64, 128, 256])
    >>> self.token = nn.Parameter(torch.zeros(1, parameter_dim, 32, 32))
    """

    # FIXME: fix signature in docs

    # FIXME: prior is designed but not supported yet

    def __new__(cls, candidates: List[ValueType], *,
                prior: Optional[List[float]] = None,
                default: Union[Callable[[List[ValueType]], ValueType], ValueType, None] = None,
                label: Optional[str] = None) -> ValueType:
        # Actually, creating a `ModelParameterChoice` never creates one.
        # It always return a fixed value, and register a ParameterSpec

        if default is None:
            default = cls.FIRST

        try:
            return cls.create_fixed_module(candidates, label=label)
        except NoContextError:
            return cls.create_default(candidates, default, label)

    @staticmethod
    def create_default(candidates: List[ValueType],
                       default: Union[Callable[[List[ValueType]], ValueType], ValueType],
                       label: Optional[str]) -> ValueType:
        """Resolve ``default``, register this choice as a hyper-parameter, and return the resolved value.

        ``default`` is either one of ``candidates``, or a callable that selects
        one candidate from the list.

        Raises
        ------
        TypeError
            If ``default`` is neither a candidate nor callable.
        """
        if default not in candidates:
            # Not a literal candidate: it must then be a selector callable.
            # An explicit callable() check avoids the fragile pattern of calling it
            # and sniffing the TypeError message for 'not callable'.
            if not callable(default):
                raise TypeError("`default` is not in `candidates`, and it's also not callable.")
            default = cast(Callable[[List[ValueType]], ValueType], default)(candidates)

        default = cast(ValueType, default)

        label = generate_new_label(label)
        parameter_spec = ParameterSpec(
            label,          # name
            'choice',       # TODO: support more types
            candidates,     # value
            (label,),       # we don't have nested now
            True,           # yes, categorical
        )

        # there could be duplicates. Dedup is done in mutator
        ModelNamespace.current_context().parameter_specs.append(parameter_spec)

        return default

    @classmethod
    def create_fixed_module(cls, candidates: List[ValueType], *, label: Optional[str] = None, **kwargs) -> ValueType:
        # same as ValueChoice
        value = get_fixed_value(label)
        if value not in candidates:
            raise ValueError(f'Value {value} does not belong to the candidates: {candidates}.')
        return value

    @staticmethod
    def FIRST(sequence: Sequence[ValueType]) -> ValueType:
        """Get the first item of sequence. Useful in ``default`` argument."""
        return sequence[0]

    @staticmethod
    def LAST(sequence: Sequence[ValueType]) -> ValueType:
        """Get the last item of sequence. Useful in ``default`` argument."""
        return sequence[-1]


1062
@basic_unit
class Placeholder(nn.Module):
    """
    The API that creates an empty module for later mutations.
    For advanced usages only.
    """

    def __init__(self, label, **related_info):
        # Call nn.Module.__init__ before assigning any attributes,
        # per the standard nn.Module subclassing convention.
        super().__init__()
        self.label = label
        self.related_info = related_info

    def forward(self, x):
        """
        Forward of placeholder is not meaningful.
        It returns input directly.
        """
        return x