# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

import logging

import torch.nn as nn

from nni.nas.pytorch.utils import global_mutable_counting

logger = logging.getLogger(__name__)


class Mutable(nn.Module):
    """
    Mutable is designed to function as a normal layer, with all necessary operators' weights.
    States and weights of architectures should be included in mutator, instead of the layer itself.

    Mutable has a key, which marks the identity of the mutable. This key can be used by users to share
    decisions among different mutables. In mutator's implementation, mutators should use the key to
    distinguish different mutables. Mutables that share the same key should be "similar" to each other.

    Currently the default scope for keys is global.
    """

    def __init__(self, key=None):
        """
        Parameters
        ----------
        key : str, optional
            Identity of this mutable. Non-string keys are converted to string with a
            warning. If omitted, a globally-unique key is generated from the class
            name and a global counter.
        """
        super().__init__()
        if key is not None:
            if not isinstance(key, str):
                key = str(key)
                logger.warning("Warning: key \"%s\" is not string, converted to string.", key)
            self._key = key
        else:
            # Auto-generate a unique key such as "LayerChoice1".
            self._key = self.__class__.__name__ + str(global_mutable_counting())
        self.init_hook = self.forward_hook = None

    def __deepcopy__(self, memodict=None):
        raise NotImplementedError("Deep copy doesn't work for mutables.")

    def __call__(self, *args, **kwargs):
        # Refuse to run the forward pass before a mutator has been attached.
        self._check_built()
        return super().__call__(*args, **kwargs)

    def set_mutator(self, mutator):
        """
        Attach the controlling mutator. May be called at most once per mutable.

        Raises
        ------
        RuntimeError
            If a mutator has already been set on this mutable.
        """
        if "mutator" in self.__dict__:
            raise RuntimeError("`set_mutator` is called more than once. Did you parse the search space multiple times? "
                               "Or did you apply multiple fixed architectures?")
        # Write directly into __dict__ to bypass nn.Module.__setattr__, which would
        # otherwise attempt to register the mutator as a submodule/parameter/buffer.
        self.__dict__["mutator"] = mutator

    def forward(self, *inputs):
        raise NotImplementedError

    @property
    def key(self):
        """Identity of the mutable; mutables sharing a key should share decisions."""
        return self._key

    @property
    def name(self):
        # Bug fix: the original returned the literal string "_key" when no name was
        # set, instead of falling back to the actual key value.
        return self._name if hasattr(self, "_name") else self._key

    @name.setter
    def name(self, name):
        self._name = name

    def _check_built(self):
        # Raise a descriptive error if `set_mutator` was never called.
        if not hasattr(self, "mutator"):
            raise ValueError(
                "Mutator not set for {}. You might have forgotten to initialize and apply your mutator. "
                "Or did you initialize a mutable on the fly in forward pass? Move to `__init__` "
                "so that trainer can locate all your mutables. See NNI docs for more details.".format(self))

    def __repr__(self):
        return "{} ({})".format(self.name, self.key)


class MutableScope(Mutable):
    """
    Mutable scope marks a subgraph/submodule to help mutators make better decisions.

    Mutators get notified when a mutable scope is entered and exited. Mutators can override ``enter_mutable_scope``
    and ``exit_mutable_scope`` to catch corresponding events, and do status dump or update.
    MutableScope are also mutables that are listed in the mutables (search space).
    """

    def __init__(self, key):
        super().__init__(key=key)

    def __call__(self, *args, **kwargs):
        # Bug fix: check for the mutator BEFORE entering the try-block. Previously,
        # when the mutator was unset, the `finally` clause dereferenced
        # `self.mutator` and raised AttributeError, masking the informative
        # ValueError from `_check_built`.
        self._check_built()
        try:
            self.mutator.enter_mutable_scope(self)
            return super().__call__(*args, **kwargs)
        finally:
            # Guarantee the exit notification even if the forward pass raises.
            self.mutator.exit_mutable_scope(self)


class LayerChoice(Mutable):
    """
    Layer choice holds several candidate operations as submodules. At forward time
    the attached mutator decides which candidate(s) actually run and returns the
    (possibly reduced) output together with the selection mask.
    """

    def __init__(self, op_candidates, reduction="sum", return_mask=False, key=None):
        super().__init__(key=key)
        # Register all candidates so their parameters are tracked by PyTorch.
        self.choices = nn.ModuleList(op_candidates)
        self.length = len(op_candidates)
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, *inputs):
        # The actual decision is delegated to the mutator.
        out, mask = self.mutator.on_forward_layer_choice(self, *inputs)
        return (out, mask) if self.return_mask else out


class InputChoice(Mutable):
    """
    Input choice selects `n_chosen` inputs from `choose_from` (contains `n_candidates` keys). For beginners,
    use `n_candidates` instead of `choose_from` is a safe option. To get the most power out of it, you might want to
    know about `choose_from`.

    The keys in `choose_from` can be keys that appear in past mutables, or ``NO_KEY`` if there are no suitable ones.
    The keys are designed to be the keys of the sources. To help mutators make better decisions,
    mutators might be interested in how the tensors to choose from come into place. For example, the tensor is the
    output of some operator, some node, some cell, or some module. If this operator happens to be a mutable (e.g.,
    ``LayerChoice`` or ``InputChoice``), it has a key naturally that can be used as a source key. If it's a
    module/submodule, it needs to be annotated with a key: that's where a ``MutableScope`` is needed.
    """

    # Placeholder key used when a candidate has no meaningful source key.
    NO_KEY = ""

    def __init__(self, n_candidates=None, choose_from=None, n_chosen=None,
                 reduction="sum", return_mask=False, key=None):
        """
        Initialization.

        Parameters
        ----------
        n_candidates : int
            Number of inputs to choose from.
        choose_from : list of str
            List of source keys to choose from. At least of one of `choose_from` and `n_candidates` must be fulfilled.
            If `n_candidates` has a value but `choose_from` is None, it will be automatically treated as `n_candidates`
            number of empty string.
        n_chosen : int
            Recommended inputs to choose. If None, mutator is instructed to select any.
        reduction : str
            `mean`, `concat`, `sum` or `none`.
        return_mask : bool
            If `return_mask`, return output tensor and a mask. Otherwise return tensor only.
        key : str
            Key of the input choice.
        """
        super().__init__(key=key)
        # precondition check
        assert n_candidates is not None or choose_from is not None, "At least one of `n_candidates` and `choose_from`" \
                                                                    "must be not None."
        if choose_from is not None and n_candidates is None:
            n_candidates = len(choose_from)
        elif choose_from is None and n_candidates is not None:
            choose_from = [self.NO_KEY] * n_candidates
        assert n_candidates == len(choose_from), "Number of candidates must be equal to the length of `choose_from`."
        assert n_candidates > 0, "Number of candidates must be greater than 0."
        assert n_chosen is None or 0 <= n_chosen <= n_candidates, "Expected selected number must be None or no more " \
                                                                  "than number of candidates."

        self.n_candidates = n_candidates
        self.choose_from = choose_from
        self.n_chosen = n_chosen
        self.reduction = reduction
        self.return_mask = return_mask

    def forward(self, optional_inputs):
        """
        Forward method of InputChoice.

        Parameters
        ----------
        optional_inputs : list or dict
            Recommended to be a dict. As a dict, inputs will be converted to a list that follows the order of
            `choose_from` in initialization. As a list, inputs must follow the semantic order that is the same as
            `choose_from`.

        Returns
        -------
        tuple of torch.Tensor and torch.Tensor or torch.Tensor
        """
        optional_input_list = optional_inputs
        if isinstance(optional_inputs, dict):
            optional_input_list = [optional_inputs[tag] for tag in self.choose_from]
        assert isinstance(optional_input_list, list), \
            "Optional input list must be a list, not a {}.".format(type(optional_input_list))
        # Bug fix: validate the converted list rather than the raw input. Previously a
        # dict containing extra keys failed this check even though the conversion above
        # extracts exactly the `choose_from` entries in the correct order.
        assert len(optional_input_list) == self.n_candidates, \
            "Length of the input list must be equal to number of candidates."
        out, mask = self.mutator.on_forward_input_choice(self, optional_input_list)
        if self.return_mask:
            return out, mask
        return out