serializer.py 5.29 KB
Newer Older
1
2
3
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

4
import inspect
Yuge Zhang's avatar
Yuge Zhang committed
5
import os
6
7
import warnings
from typing import Any, TypeVar, Union
8

9
10
from nni.common.serializer import Traceable, is_traceable, trace, _copy_class_wrapper_attributes
from .utils import ModelNamespace
11

12
13
__all__ = [
    'get_init_parameters_or_fail',
    'serialize',
    'serialize_cls',
    'basic_unit',
    'model_wrapper',
    'is_basic_unit',
    'is_model_wrapped',
]

# Type variable used to annotate the class decorators below, which take a class and return it (possibly wrapped).
T = TypeVar('T')
16
17


18
19
20
21
22
23
24
25
def get_init_parameters_or_fail(obj: Any):
    """
    Return the keyword arguments recorded for ``obj`` by the trace wrapper.

    Parameters
    ----------
    obj
        Any object. It must be traceable (e.g., an instance of a class decorated with
        ``@basic_unit`` or ``@nni.trace``) for its init parameters to be available.

    Returns
    -------
    ``obj.trace_kwargs``.

    Raises
    ------
    ValueError
        If ``obj`` is not traceable, so that its ``trace_kwargs`` are not available.
    """
    if is_traceable(obj):
        return obj.trace_kwargs
    raise ValueError(f'Object {obj} needs to be serializable but `trace_kwargs` is not available. '
                     'If it is a built-in module (like Conv2d), please import it from retiarii.nn. '
                     'If it is a customized module, please decorate it with @basic_unit. '
                     'For other complex objects (e.g., trainer, optimizer, dataset, dataloader), '
                     'try to use @nni.trace.')
26
27


28
29
30
def serialize(cls, *args, **kwargs):
    """
    To create a serializable instance inline without decorator. For example,

    .. code-block:: python

        self.op = serialize(MyCustomOp, hidden_units=128)

    Deprecated: use ``nni.trace`` instead, e.g., ``nni.trace(MyCustomOp)(hidden_units=128)``.

    Parameters
    ----------
    cls
        The class to instantiate.
    *args, **kwargs
        Arguments forwarded to the constructor of the traced class.

    Returns
    -------
    The instance created by ``trace(cls)(*args, **kwargs)``.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this function,
    # so users can find the deprecated call site.
    warnings.warn('nni.retiarii.serialize is deprecated and will be removed in future release. ' +
                  'Try to use nni.trace, e.g., nni.trace(torch.optim.Adam)(learning_rate=1e-4) instead.',
                  category=DeprecationWarning, stacklevel=2)
    return trace(cls)(*args, **kwargs)
40
41


42
def serialize_cls(cls):
    """
    To create a serializable class.

    Deprecated: use ``nni.trace(cls)`` instead.

    Parameters
    ----------
    cls
        The class to make serializable.

    Returns
    -------
    ``trace(cls)``.
    """
    # The message names this function (it previously said ``serialize``), and
    # stacklevel=2 points the warning at the caller rather than this module.
    warnings.warn('nni.retiarii.serialize_cls is deprecated and will be removed in future release. ' +
                  'Try to use nni.trace instead.', category=DeprecationWarning, stacklevel=2)
    return trace(cls)
49
50


51
52
53
def basic_unit(cls: T, basic_unit_tag: bool = True) -> Union[T, Traceable]:
    """
    To wrap a module as a basic unit, is to make it a primitive and stop the engine from digging deeper into it.

    ``basic_unit_tag`` is true by default. If set to false, it will not be explicitly marked as a basic unit, and
    graph parser will continue to parse. Currently, this is to handle a special case in ``nn.Sequential``.

    Although ``basic_unit`` calls ``trace`` in its implementation, it is not for serialization. Rather, it is meant
    to capture the initialization arguments for mutation. Also, graph execution engine will stop digging into the inner
    modules when it reaches a module that is decorated with ``basic_unit``.

    .. code-block:: python

        @basic_unit
        class PrimitiveOp(nn.Module):
            ...

    Parameters
    ----------
    cls
        A subclass of ``nn.Module``.
    basic_unit_tag
        Value stored as ``_nni_basic_unit`` on the returned class (read by ``is_basic_unit``).

    Returns
    -------
    ``trace(cls)`` with ``_nni_basic_unit`` set to ``basic_unit_tag``. If the environment variable
    ``NNI_TRACE_FLAG`` is ``disable`` (case-insensitive), ``cls`` is returned unchanged.

    Raises
    ------
    TypeError
        If ``cls`` is already wrapped, or is not a subclass of ``nn.Module``.
    """

    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    _check_wrapped(cls)

    import torch.nn as nn
    # Explicit check rather than ``assert``, so the validation is not stripped under ``python -O``.
    if not (inspect.isclass(cls) and issubclass(cls, nn.Module)):
        raise TypeError('When using @basic_unit, the class must be a subclass of nn.Module.')

    cls = trace(cls)
    cls._nni_basic_unit = basic_unit_tag

    # HACK: for torch script
    # https://github.com/pytorch/pytorch/pull/45261
    # https://github.com/pytorch/pytorch/issues/54688
    # I'm not sure whether there will be potential issues
    # The trace-related attributes are marked ignored / unused so TorchScript does not try to compile them.
    import torch
    cls._get_nni_attr = torch.jit.ignore(cls._get_nni_attr)
    cls.trace_symbol = torch.jit.unused(cls.trace_symbol)
    cls.trace_args = torch.jit.unused(cls.trace_args)
    cls.trace_kwargs = torch.jit.unused(cls.trace_kwargs)

    return cls
93
94


95
def model_wrapper(cls: T) -> Union[T, Traceable]:
    """
    Wrap the base model (search space). For example,

    .. code-block:: python

        @model_wrapper
        class MyModel(nn.Module):
            ...

    The wrapper serves two purposes:

    1. Capture the init parameters of python class so that it can be re-instantiated in another process.
    2. Reset uid in namespace so that the auto label counting in each model stably starts from zero.

    Currently, NNI might not complain in simple cases where ``@model_wrapper`` is actually not needed.
    But in future, we might enforce ``@model_wrapper`` to be required for base model.

    Parameters
    ----------
    cls
        A subclass of ``nn.Module``.

    Returns
    -------
    A subclass of ``trace(cls)`` whose ``__init__`` runs inside a ``ModelNamespace`` context and which
    has ``_nni_model_wrapper`` set to ``True``. If the environment variable ``NNI_TRACE_FLAG`` is
    ``disable`` (case-insensitive), ``cls`` is returned unchanged.

    Raises
    ------
    TypeError
        If ``cls`` is already wrapped, or is not a subclass of ``nn.Module``.
    """

    # Internal flag. See nni.trace
    nni_trace_flag = os.environ.get('NNI_TRACE_FLAG', '')
    if nni_trace_flag.lower() == 'disable':
        return cls

    _check_wrapped(cls)

    import torch.nn as nn
    # Explicit check rather than ``assert``, so the validation is not stripped under ``python -O``.
    if not (inspect.isclass(cls) and issubclass(cls, nn.Module)):
        raise TypeError('When using @model_wrapper, the class must be a subclass of nn.Module.')

    wrapper = trace(cls)

    class reset_wrapper(wrapper):
        def __init__(self, *args, **kwargs):
            # Construct the model inside its own namespace so that auto-generated labels
            # start counting from zero for every model instance.
            with ModelNamespace():
                super().__init__(*args, **kwargs)

    # Make the subclass look like the traced wrapper (copied attributes, same ``__wrapped__``).
    _copy_class_wrapper_attributes(wrapper, reset_wrapper)
    reset_wrapper.__wrapped__ = wrapper.__wrapped__
    reset_wrapper._nni_model_wrapper = True
    return reset_wrapper
135
136


137
138
139
140
def is_basic_unit(cls_or_instance) -> bool:
    """
    Tell whether a class (or the class of the given instance) is tagged as a basic unit.

    Returns the ``_nni_basic_unit`` attribute of the class, or ``False`` when it is absent.
    """
    target_cls = cls_or_instance if inspect.isclass(cls_or_instance) else cls_or_instance.__class__
    return getattr(target_cls, '_nni_basic_unit', False)
141
142


143
144
145
146
147
148
149
150
151
def is_model_wrapped(cls_or_instance) -> bool:
    """
    Tell whether a class (or the class of the given instance) was produced by ``model_wrapper``.

    Returns the ``_nni_model_wrapper`` attribute of the class, or ``False`` when it is absent.
    """
    target_cls = cls_or_instance if inspect.isclass(cls_or_instance) else cls_or_instance.__class__
    return getattr(target_cls, '_nni_model_wrapper', False)


def _check_wrapped(cls: Any) -> None:
    """
    Refuse to wrap a class twice.

    A class counts as already wrapped when it has a truthy ``_traced`` attribute (presumably set by
    ``nni.trace``) or a truthy ``_nni_model_wrapper`` attribute (set by ``model_wrapper``).

    Raises
    ------
    TypeError
        If ``cls`` is already wrapped.
    """
    if getattr(cls, '_traced', False) or getattr(cls, '_nni_model_wrapper', False):
        raise TypeError(f'{cls} is already wrapped with trace wrapper (basic_unit / model_wrapper / trace). '
                        'Cannot wrap again.')