utils.py 11.1 KB
Newer Older
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
import inspect
5
import itertools
6
import warnings
7
from collections import defaultdict
8
from contextlib import contextmanager
9
from typing import Any, List, Dict
10
from pathlib import Path
11

12
13
14
15
from nni.common.hpo_utils import ParameterSpec

# Names exported by a star-import of this module; every other helper here must be imported explicitly.
__all__ = ['NoContextError', 'ContextStack', 'ModelNamespace', 'original_state_dict_hooks']

QuanluZhang's avatar
QuanluZhang committed
16

17
def import_(target: str, allow_none: bool = False) -> Any:
    """Import and return the object designated by a dotted path.

    Parameters
    ----------
    target
        Dotted name of the form ``'package.module.attribute'``. The part before the last
        dot is imported as a module, the part after it is fetched from that module.
        ``None`` is passed through and yields ``None``.
    allow_none
        Not consulted by this function; kept so that existing callers keep working.

    Returns
    -------
    The attribute named by ``target``, or ``None`` when ``target`` is ``None``.

    Raises
    ------
    ValueError
        If ``target`` contains no dot, so it cannot be split into module and attribute.
    """
    if target is None:
        return None
    if '.' not in target:
        # Without this guard the tuple unpacking below fails with an opaque
        # "not enough values to unpack" ValueError.
        raise ValueError(f'Cannot import {target!r}: expected a dotted path like "module.attribute".')
    path, identifier = target.rsplit('.', 1)
    # A non-empty ``fromlist`` makes ``__import__`` return the leaf module instead of the top-level package.
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)
QuanluZhang's avatar
QuanluZhang committed
23

24

25
26
# Per-namespace counters: maps a namespace string to the last integer handed out by ``uid``.
# Missing namespaces start at 0 because of ``defaultdict(int)``.
_last_uid = defaultdict(int)

27
28
# Key used by ``ModelNamespace`` (constructor, ``current_context`` and ``next_label``) when the caller gives none.
_DEFAULT_MODEL_NAMESPACE = 'model'

29

30
31
32
def uid(namespace: str = 'default') -> int:
    """Advance the counter of ``namespace`` by one and return the new value.

    Counters are kept per namespace in the module-level ``_last_uid`` mapping.
    """
    issued = _last_uid[namespace] + 1
    _last_uid[namespace] = issued
    return issued
33
34


35
36
37
38
def reset_uid(namespace: str = 'default') -> None:
    """Set the counter of ``namespace`` in ``_last_uid`` back to zero."""
    _last_uid.update({namespace: 0})


39
40
def get_module_name(cls_or_func):
    """Return the module name of a class or function, resolving ``'__main__'`` to the launching script.

    When ``cls_or_func.__module__`` is ``'__main__'``, the call stack is searched for a frame
    that belongs to the ``__main__`` module, and the stem of that frame's source file is used
    as the module name.

    Parameters
    ----------
    cls_or_func
        Object exposing ``__module__`` and ``__name__``.

    Returns
    -------
    str
        The resolved module name. It stays ``'__main__'`` (with a warning) when no usable
        ``__main__`` frame is found on the stack.

    Raises
    ------
    RuntimeError
        If the main script is not located in the current working directory.
    """
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        # infer the module name with inspect
        for frm in inspect.stack():
            frame_module = inspect.getmodule(frm[0])
            # ``inspect.getmodule`` returns None when it cannot determine the module of a frame.
            if frame_module is None or frame_module.__name__ != '__main__':
                continue
            source_file = inspect.getsourcefile(frm[0])
            if source_file is None:
                # ``inspect.getsourcefile`` returns None when the frame has no source file on disk.
                continue
            # main module found
            main_file_path = Path(source_file)
            if not Path().samefile(main_file_path.parent):
                raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                   f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
            module_name = main_file_path.stem
            break
    if module_name == '__main__':
        warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                      'function/class cannot be imported.')

    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
    # the wrapped LSTM's __module__
    # TODO: find out all the modules that have the same requirement as LSTM
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name


66
def get_importable_name(cls, relocate_module=False):
    """Return ``'<module>.<name>'`` for ``cls``.

    Parameters
    ----------
    cls
        Object exposing ``__module__`` and ``__name__``.
    relocate_module
        When true, the module part comes from :func:`get_module_name` instead of
        ``cls.__module__`` as-is.

    Returns
    -------
    str
        Module name and ``cls.__name__`` joined by a dot.
    """
    module_name = get_module_name(cls) if relocate_module else cls.__module__
    return module_name + '.' + cls.__name__
69
70


Yuge Zhang's avatar
Yuge Zhang committed
71
class NoContextError(Exception):
    """Exception raised when context is missing."""


76
77
class ContextStack:
    """
    This is to maintain a globally-accessible context environment that is visible to everywhere.

    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
    get the corresponding value in the namespace.

    Note that this is not multi-processing safe. Also, the values will get cleared for a new process.
    """

    # Class-level storage shared by every instance: one stack (list) of values per key.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        # Entering the ``with`` block makes ``value`` the innermost context of ``key``.
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        # Returns None, so exceptions raised inside the ``with`` block propagate.
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        """Put ``value`` on top of the stack of ``key``."""
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        """Discard the top value of the stack of ``key``."""
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        """Return the top value of the stack of ``key`` without removing it.

        Raises ``NoContextError`` if nothing is currently pushed under ``key``.
        """
        # ``.get`` keeps this a read-only query: indexing the defaultdict would leave
        # an empty list behind for every key that is merely asked about.
        stack = cls._stack.get(key)
        if not stack:
            raise NoContextError('Context is empty.')
        return stack[-1]


114
115
class ModelNamespace:
    """
    To create an individual namespace for models:

    1. to enable automatic numbering;
    2. to trace general information (like creation of hyper-parameters) of model.

    A namespace is bound to a key. Namespaces bound to different keys are completely isolated.
    Namespace can have sub-namespaces (with the same key). The numbering will be chained (e.g., ``model_1_4_2``).
    """

    def __init__(self, key: str = _DEFAULT_MODEL_NAMESPACE):
        # for example, key: "model_wrapper"
        self.key = key

        # the "path" of current name
        # By default, it's ``[]``
        # If a ``@model_wrapper`` is nested inside a model_wrapper, it will become something like ``[1, 3, 2]``.
        # See ``__enter__``.
        self.name_path: List[int] = []

        # parameter specs.
        # Currently only used trace calls of ModelParameterChoice.
        self.parameter_specs: List[ParameterSpec] = []

    def __enter__(self):
        # For example, currently the top of stack is [1, 2, 2], and [1, 2, 2, 3] is used,
        # the next thing up is [1, 2, 2, 4].
        # `reset_uid` to count from zero for "model_wrapper_1_2_2_4"
        try:
            # Only this lookup can raise NoContextError, so it is the only statement guarded.
            parent_context: 'ModelNamespace' = ModelNamespace.current_context(self.key)
        except NoContextError:
            # not found, no existing namespace
            self.name_path = []
        else:
            next_uid = uid(parent_context._simple_name())
            self.name_path = parent_context.name_path + [next_uid]
        ContextStack.push(self.key, self)
        reset_uid(self._simple_name())
        # Returned so that ``with ModelNamespace(...) as ns:`` binds the namespace itself.
        return self

    def __exit__(self, *args, **kwargs):
        ContextStack.pop(self.key)

    def _simple_name(self) -> str:
        """Return the key followed by ``_<n>`` for every number in ``name_path``."""
        return self.key + ''.join(['_' + str(k) for k in self.name_path])

    def __repr__(self):
        return f'ModelNamespace(name={self._simple_name()}, num_specs={len(self.parameter_specs)})'

    # Access the current context in the model #

    @staticmethod
    def current_context(key: str = _DEFAULT_MODEL_NAMESPACE) -> 'ModelNamespace':
        """Get the current context in key.

        Raises ``NoContextError`` if no namespace is active for ``key``.
        """
        try:
            return ContextStack.top(key)
        except NoContextError:
            # ``from None``: the inner "Context is empty." error adds nothing to this message.
            raise NoContextError(
                'ModelNamespace context is missing. You might have forgotten to use `@model_wrapper`.'
            ) from None

    @staticmethod
    def next_label(key: str = _DEFAULT_MODEL_NAMESPACE) -> str:
        """Get the next label for API calls, with automatic numbering."""
        try:
            current_context = ContextStack.top(key)
        except NoContextError:
            # fallback to use "default" namespace
            # it won't be registered
            warnings.warn('ModelNamespace is missing. You might have forgotten to use `@model_wrapper`. '
                          'Some features might not work. This will be an error in future releases.', RuntimeWarning)
            current_context = ModelNamespace('default')

        next_uid = uid(current_context._simple_name())
        return current_context._simple_name() + '_' + str(next_uid)
188
189


190
191
def get_current_context(key: str) -> Any:
    """Return the value currently on top of the ``ContextStack`` for ``key``.

    Delegates to ``ContextStack.top``, which raises ``NoContextError`` when the stack of
    ``key`` is empty.
    """
    innermost = ContextStack.top(key)
    return innermost
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307


# map variables to prefix in the state dict
# e.g., {'upsample': 'mynet.module.deconv2.upsample_layer'}
# (name of the module attribute that ``original_state_dict_hooks`` looks up)
STATE_DICT_PY_MAPPING = '_mapping_'

# map variables to `prefix`.`value` in the state dict
# e.g., {'upsample': 'choice3.upsample_layer'},
# which actually means {'upsample': 'mynet.module.choice3.upsample_layer'},
# and 'upsample' is also in `mynet.module`.
# (only consulted when the module has no ``STATE_DICT_PY_MAPPING`` attribute)
STATE_DICT_PY_MAPPING_PARTIAL = '_mapping_partial_'


@contextmanager
def original_state_dict_hooks(model: Any):
    """
    Use this patch if you want to save/load state dict in the original state dict hierarchy.

    For example, when you already have a state dict for the base model / search space (which often
    happens when you have trained a supernet with one-shot strategies), the state dict isn't organized
    in the same way as when a sub-model is sampled from the search space. This patch will help
    the modules in the sub-model find the corresponding module in the base model.

    The code looks like,

    .. code-block:: python

        with original_state_dict_hooks(model):
            model.load_state_dict(state_dict_from_supernet, strict=False)  # supernet has extra keys

    Or vice-versa,

    .. code-block:: python

        with original_state_dict_hooks(model):
            supernet_style_state_dict = model.state_dict()

    The hooks are registered on ``model`` on entry and removed on exit. They honour the ``prefix``
    PyTorch passes to them, so ``model.state_dict(prefix=...)`` and saving/loading through a parent
    module that contains ``model`` are remapped as well.
    """

    import torch.nn as nn
    assert isinstance(model, nn.Module), 'PyTorch is the only supported framework for now.'

    # the following are written for pytorch only

    # first get the full mapping: key in the current model -> key in the original hierarchy
    full_mapping = {}

    def full_mapping_in_module(src_prefix, tar_prefix, module):
        if hasattr(module, STATE_DICT_PY_MAPPING):
            # only values are complete
            local_map = getattr(module, STATE_DICT_PY_MAPPING)
        elif hasattr(module, STATE_DICT_PY_MAPPING_PARTIAL):
            # keys and values are both incomplete
            local_map = getattr(module, STATE_DICT_PY_MAPPING_PARTIAL)
            local_map = {k: tar_prefix + v for k, v in local_map.items()}
        else:
            # no mapping
            local_map = {}

        if '__self__' in local_map:
            # special case, overwrite prefix
            tar_prefix = local_map['__self__'] + '.'

        for key, value in local_map.items():
            if key != '' and key not in module._modules:  # not a sub-module, probably a parameter
                full_mapping[src_prefix + key] = value

        if src_prefix != tar_prefix:  # To deal with leaf nodes.
            for name, value in itertools.chain(module._parameters.items(), module._buffers.items()):  # direct children
                if value is None or name in module._non_persistent_buffers_set:
                    # it won't appear in state dict
                    continue
                if (src_prefix + name) not in full_mapping:
                    full_mapping[src_prefix + name] = tar_prefix + name

        for name, child in module.named_children():
            # sub-modules
            full_mapping_in_module(
                src_prefix + name + '.',
                local_map.get(name, tar_prefix + name) + '.',  # if mapping doesn't exist, respect the prefix
                child
            )

    full_mapping_in_module('', '', model)

    # original-hierarchy key -> every current-model key fed from it.
    # ``full_mapping`` is final at this point, so this is built once rather than on every load.
    reverse_mapping = defaultdict(list)
    for src, tar in full_mapping.items():
        reverse_mapping[tar].append(src)

    def load_state_dict_hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs):
        # Rewrites the incoming (original-hierarchy) keys into the keys this model expects.
        # Renamed entries are collected first so that a new key cannot clash with a pending old one.
        transf_state_dict = {}
        for tar, src_keys in reverse_mapping.items():
            incoming_key = prefix + tar
            if incoming_key in state_dict:
                value = state_dict.pop(incoming_key)
                for src in src_keys:
                    transf_state_dict[prefix + src] = value
            else:
                missing_keys.append(incoming_key)
        state_dict.update(transf_state_dict)

    def state_dict_hook(module, destination, prefix, local_metadata):
        # Rewrites the keys produced by this model into the original hierarchy.
        result = {}
        for src, tar in full_mapping.items():
            produced_key = prefix + src
            if produced_key in destination:
                result[prefix + tar] = destination.pop(produced_key)
            else:
                raise KeyError(f'"{produced_key}" not in state dict, but found in mapping.')
        destination.update(result)

    hooks = []
    try:
        hooks.append(model._register_load_state_dict_pre_hook(load_state_dict_hook))
        hooks.append(model._register_state_dict_hook(state_dict_hook))
        yield
    finally:
        # Remove whatever was registered, also when registration or the ``with`` body failed.
        for hook in hooks:
            hook.remove()