import torch
from ._initialize import _initialize
from ._amp_state import _amp_state, warn_or_err, maybe_print


class Properties(object):
    """
    This class has two purposes: to establish a set of default properties,
    and to route setting of these attributes through __setattr__ so that (in theory)
    they can be checked for consistency with other existing args.
    """
    def __init__(self):
        self.options = {
            "enabled" : False,
            "opt_level" : None,
            "cast_model_type" : None,
            "patch_torch_functions" : False,
            "keep_batchnorm_fp32" : None,
            "master_weights" : None,
            "loss_scale" : 1.0,
            # Reserved for future functionality
            # "fused_optimizer" : False,
            # "enable_ddp_interop" : False,
            }

    """
    This function allows updating several options at a time without routing through
    __setattr__ checks, to avoid "you can't get there from here" scenarios.
    Currently not intended to be exposed; users are expected to select an opt_level
    and apply consistent modifications.
    """
    def _update_options_dict(self, new_options):
        for k, v in new_options:
            if k in self.options:
                self.options[k] = v
            else:
                raise ValueError("Tried to set unexpected option {}".format(k))
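
    # Hypothetical internal use only (this helper bypasses the per-option checks in
    # __setattr__ below):
    #   props._update_options_dict([("opt_level", "O2"), ("loss_scale", 128.0)])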
    """
    The members of "options" are not direct attributes of self, so access attempts
    will roll down to __getattr__.  This borrows from the logic in torch.nn.Module.
    """
    def __getattr__(self, name):
        if "options" in self.__dict__:
            options = self.__dict__["options"]
            if name in options:
                return options[name]
        raise AttributeError("'{}' object has no attribute '{}'".format(
            type(self).__name__, name))

    def __setattr__(self, name, value):
        if "options" in self.__dict__:
            if name in self.options:
                # print("setting {} {}".format(name, value))
                if name == "cast_model_type":
                    if self.opt_level == "O1" and value is not None:
                        if value is not torch.float32:
                            warn_or_err("O1 inserts casts around Torch functions rather than "
                                        "model weights, so with O1, the model weights themselves "
                                        "should remain FP32. If you wish to cast the model to a "
                                        "different type, use opt_level='O2' or 'O3'. " +
                                        "cast_model_type was {}".format(value))
                    self.options[name] = value
                elif name == "patch_torch_functions":
                    if self.opt_level != "O1" and value:
                        warn_or_err("Currently, patch_torch_functions=True should only be set by "
                                    "selecting opt_level='O1'.")
                    self.options[name] = value
                elif name == "keep_batchnorm_fp32":
                    if self.opt_level == "O1" and value is not None:
                        warn_or_err("With opt_level O1, batchnorm functions are automatically patched "
                                    "to run in FP32, so keep_batchnorm_fp32 should be None. " +
                                    "keep_batchnorm_fp32 was {}".format(value))
                    if value == "False":
                        self.options[name] = False
                    elif value == "True":
                        self.options[name] = True
                    else:
                        assert (value is True or value is False or value is None),\
                            "keep_batchnorm_fp32 must be a boolean, the string 'True' or 'False', "\
                            "or None, found keep_batchnorm_fp32={}".format(keep_batchnorm_fp32)
                        self.options[name] = value
                elif name == "master_weights":
                    if self.opt_level == "O1" and value is not None:
                        warn_or_err("It doesn't make sense to use master_weights with O1. "
                                    "With O1, your model weights themselves should be FP32.")
                    self.options[name] = value
                elif name == "loss_scale":
                    if value == "dynamic":
                        self.options[name] = value
                    else:
                        self.options[name] = float(value)
                else:
                    self.options[name] = value
        else:
            super(Properties, self).__setattr__(name, value)
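

# A minimal sketch (illustrative only, not executed in this module) of how Properties
# routes attribute access through the options dict:
#
#   props = Properties()
#   props.loss_scale = "128.0"   # handled by __setattr__, stored as the float 128.0
#   props.loss_scale             # handled by __getattr__, returns 128.0
#   props.nonexistent            # raises AttributeError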


""" O0-O3 are convenience wrappers to establish defaults for typically used mixed precision options. """

class O3:
    brief = "O3:  Pure FP16 training."
    more = "Calls .half() on your model, converting the entire model to FP16.\n"\
        "A casting operation is also inserted to cast incoming Tensors to FP16,\n"\
        "so you don't need to change your data pipeline.\n"\
        "This mode is useful for establishing a performance ceiling.\n"\
        "It's also possible training may 'just work' in this mode.\n"\
        "If not, try other optimization levels."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O3"
        properties.cast_model_type = torch.float16
        properties.patch_torch_functions = False
        properties.keep_batchnorm_fp32 = False
        properties.master_weights = False
        properties.loss_scale = 1.0
        # properties.fused_optimizer = False
        # properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary


class O2:
    brief = "O2:  FP16 training with FP32 batchnorm and FP32 master weights.\n"
    more = "Calls .half() on your model, converting the entire model (except for batchnorms)\n"\
        "to FP16.  Batchnorms are retained in FP32 for additional stability.\n"\
        "The forward pass is patched to cast incoming Tensors to FP16, so you don't need to change\n"\
        "your data pipeline.\n"\
        "O2 creates FP32 master weights outside the model and patches any optimizers to update\n"\
        "these master weights, then copy the master weights into the FP16 model weights.\n"\
        "Master weights can also improve convergence and stability."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O2"
        properties.cast_model_type = torch.float16
        properties.patch_torch_functions = False
        properties.keep_batchnorm_fp32 = True
        properties.master_weights = True
        properties.loss_scale = "dynamic"
        # properties.fused_optimizer = False
        # properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary


class O1:
    brief = "O1:  Insert automatic casts around Pytorch functions and Tensor methods.\n"
    more = "The type of your model's weights is not altered.  However, internally,\n"\
        "Pytorch functions are patched to cast any Tensor Core-friendly ops to FP16 for speed,\n"\
        "while operations that might benefit from the additional stability of FP32 are patched\n"\
        "to cast their inputs to fp32.\n"\
        "O1 is the safest way to try mixed precision training, and is recommended when\n"\
        "trying mixed precision training for the first time."

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O1"
        properties.cast_model_type = None
        properties.patch_torch_functions = True
        properties.keep_batchnorm_fp32 = None
        properties.master_weights = None
        properties.loss_scale = "dynamic"
        # properties.fused_optimizer = False
        # properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary


class O0:
    brief = "O0:  Pure FP32 training.\n"
    more = "Your models are checked to make sure parameters are FP32, but otherwise the\n"\
        "types of weights and internal Pytorch operations are not altered.  This mode disables any\n"\
        "FP16 arithmetic, although other optimizations like DDP interop may still be requested.\n"

    def __call__(self, properties):
        properties.enabled = True
        properties.opt_level = "O0"
        properties.cast_model_type = torch.float32
        properties.patch_torch_functions = False
        properties.keep_batchnorm_fp32 = None
        properties.master_weights = False
        properties.loss_scale = 1.0
        # properties.fused_optimizer = False
        # properties.enable_ddp_interop = False
        return properties # modified in place so this isn't really necessary


opt_levels = {"O3": O3(),
              "O2": O2(),
              "O1": O1(),
              "O0": O0()}


# allow user to directly pass Properties struct as well?
def initialize(
    models,
    optimizers,
    enabled=True,
    opt_level=None,
    cast_model_type=None,
    patch_torch_functions=None,
    keep_batchnorm_fp32=None,
    master_weights=None,
    loss_scale=None,
    verbosity=1,
    ):
    """
    Initialize your models, optimizers, and the Torch tensor and functional namespace according to the
    chosen ``opt_level`` and overridden properties, if any.

    ``amp.initialize`` must be called **after** you have finished constructing your model(s) and
    optimizer(s), but **before** you send your model through any DistributedDataParallel wrapper.
    See `Distributed training`_ in the Imagenet example.

    Any property keyword argument that is not ``None`` will be interpreted as a manual override.

    To prevent having to rewrite anything else in your script, name the returned models/optimizers
    to replace the passed models/optimizers, as in the Usage below.

    Args:
        models (torch.nn.Module or list of torch.nn.Modules):  Models to modify/cast.
        optimizers (torch.optim.Optimizer or list of torch.optim.Optimizers):  Optimizers to modify/cast.
        enabled (bool, optional, default=True):  If False, renders all Amp calls no-ops, so your script
            should run as if Amp were not present.
        opt_level (str, required):  Pure or mixed precision optimization level.  Accepted values are
            "O0", "O1", "O2", and "O3", explained in detail above.
        cast_model_type (``torch.dtype``, optional, default=None):  Optional property override, see
            above.
        patch_torch_functions (bool, optional, default=None):  Optional property override.
        keep_batchnorm_fp32 (bool or str, optional, default=None):  Optional property override.  If
            passed as a string, must be the string "True" or "False".
        master_weights (bool, optional, default=None):  Optional property override.
        loss_scale (float or str, default=None):  Optional property override.  If passed as a string,
            must be a string representing a number, e.g., "128.0", or the string "dynamic".
        verbosity (int, default=1):  Set to 0 to suppress Amp-related output.

    Returns:
        Model(s) and optimizer(s) modified according to the ``opt_level``.
        If either the ``models`` or ``optimizers`` args were lists, the corresponding return value will
        also be a list.

    Usage::

        model, optim = amp.initialize(model, optim,...)
        model, [optim1, optim2] = amp.initialize(model, [optim1, optim2],...)
        [model1, model2], optim = amp.initialize([model1, model2], optim,...)
        [model1, model2], [optim1, optim2] = amp.initialize([model1, model2], [optim1, optim2],...)

        # This is not an exhaustive list of the cross product of options that are possible,
        # just a set of examples.
        model, optim = amp.initialize(model, optim, opt_level="O0")
        model, optim = amp.initialize(model, optim, opt_level="O0", loss_scale="dynamic"|128.0|"128.0")

        model, optim = amp.initialize(model, optim, opt_level="O1") # uses "loss_scale="dynamic" default
        model, optim = amp.initialize(model, optim, opt_level="O1", loss_scale=128.0|"128.0")

        model, optim = amp.initialize(model, optim, opt_level="O2") # uses "loss_scale="dynamic" default
        model, optim = amp.initialize(model, optim, opt_level="O2", loss_scale=128.0|"128.0")
        model, optim = amp.initialize(model, optim, opt_level="O2", keep_batchnorm_fp32=True|False|"True"|"False")

        model, optim = amp.initialize(model, optim, opt_level="O3") # uses loss_scale=1.0 default
        model, optim = amp.initialize(model, optim, opt_level="O3", loss_scale="dynamic"|128.0|"128.0")
        model, optim = amp.initialize(model, optim, opt_level="O3", keep_batchnorm_fp32=True|False|"True"|"False")

    The `Imagenet example`_ demonstrates live use of various opt_levels and overrides.

    .. _`Distributed training`:
        https://github.com/NVIDIA/apex/tree/master/examples/imagenet#distributed-training

    .. _`Imagenet example`:
        https://github.com/NVIDIA/apex/tree/master/examples/imagenet
    """
    _amp_state.opt_properties = Properties()
    _amp_state.opt_properties.verbosity = verbosity

    if not enabled:
        return models, optimizers

    if opt_level not in opt_levels:
        raise RuntimeError(
            "Unexpected optimization level {}. ".format(opt_level) +
            "Options are 'O0', 'O1', 'O2', 'O3'.")
    else:
        _amp_state.opt_properties = opt_levels[opt_level](_amp_state.opt_properties)
        maybe_print("Selected optimization level {}".format(opt_levels[opt_level].brief), True)
        maybe_print("Defaults for this optimization level are:", True)
        for k, v in _amp_state.opt_properties.options.items():
            maybe_print("{:22} : {}".format(k, v), True)

    maybe_print("Processing user overrides (additional kwargs that are not None)...", True)
    # I chose to have the keyword arguments listed directly in the argument list,
    # instead of **kwargs, so I can't use kwargs.items() here.
    if enabled is not None:
        _amp_state.opt_properties.enabled = enabled
    if opt_level is not None:
        _amp_state.opt_properties.opt_level = opt_level
    if cast_model_type is not None:
        _amp_state.opt_properties.cast_model_type = cast_model_type
    if patch_torch_functions is not None:
        _amp_state.opt_properties.patch_torch_functions = patch_torch_functions
    if keep_batchnorm_fp32 is not None:
        _amp_state.opt_properties.keep_batchnorm_fp32 = keep_batchnorm_fp32
    if master_weights is not None:
        _amp_state.opt_properties.master_weights = master_weights
    if loss_scale is not None:
        _amp_state.opt_properties.loss_scale = loss_scale

    maybe_print("After processing overrides, optimization options are:", True)
    for k, v in _amp_state.opt_properties.options.items():
        maybe_print("{:22} : {}".format(k, v), True)

    return _initialize(models, optimizers, _amp_state.opt_properties)
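

# Typical end-to-end usage (a sketch of the common Apex training-loop pattern; assumes a
# standard model/optimizer/criterion and `amp.scale_loss` from apex.amp):
#
#   from apex import amp
#   model, optimizer = amp.initialize(model, optimizer, opt_level="O1")
#   for input, target in data_loader:
#       optimizer.zero_grad()
#       loss = criterion(model(input), target)
#       with amp.scale_loss(loss, optimizer) as scaled_loss:
#           scaled_loss.backward()
#       optimizer.step()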


# TODO:  is this necessary/useful?
# def check_option_consistency(enabled=True,
#                              opt_level=None,
#                              cast_model_type=None,
#                              patch_torch_functions=None,
#                              keep_batchnorm_fp32=None,
#                              master_weights=None,
#                              loss_scale=None,
#                              enable_ddp_interop=None,
#                              hard_override=False):
#     """
#     Utility function that enables users to quickly check if the option combination they intend
#     to use is permitted.  ``check_option_consistency`` does not require models or optimizers
#     to be constructed, and can be called at any point in the script.  ``check_option_consistency``
#     is totally self-contained; it does not set any amp global state or affect anything outside
#     of itself.
#     """
#
#     if not enabled:
#         return
#
#     if opt_level not in opt_levels:
#         raise RuntimeError("Unexpected optimization level.  Options are 'O0', 'O1', 'O2', 'O3'.")
#     else:
#         opt_properties = opt_levels[opt_level](Properties())
#         print("Selected optimization level {}", opt_levels[opt_level].brief)
#         print("Defaults for this optimization level are:")
#         for k, v in opt_properties.options:
#             print("{:22} : {}".format(k, v))
#
#     print("Processing user overrides (additional kwargs that are not None)...")
#     for k, v in kwargs:
#         if k not in _amp_state.opt_properties.options:
#             raise RuntimeError("Unexpected kwarg {}".format(k))
#         if v is not None:
#             setattr(opt_properties, k, v)
#
#     print("After processing overrides, optimization options are:")
#     for k, v in opt_properties.options:
#         print("{:22} : {}".format(k, v))