# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stable version of apex FP16 Optimizer"""
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn.parameter import Parameter
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

from .loss_scaler import DynamicLossScaler, LossScaler
from .fp16util import model_grads_to_master_grads, master_params_to_model_params, clip_grad_norm

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.module import MegatronModule
from megatron import mpu

FLOAT_TYPES = (torch.FloatTensor, torch.cuda.FloatTensor)
HALF_TYPES = (torch.HalfTensor, torch.cuda.HalfTensor)


def conversion_helper(val, conversion):
    """Apply conversion to val. Recursively apply conversion if `val` is a nested tuple/list structure."""
    if not isinstance(val, (tuple, list)):
        return conversion(val)
    rtn = [conversion_helper(v, conversion) for v in val]
    if isinstance(val, tuple):
        rtn = tuple(rtn)
    return rtn


def fp32_to_fp16(val):
    """Convert fp32 `val` to fp16"""
    def half_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, FLOAT_TYPES):
            val = val.half()
        return val
    return conversion_helper(val, half_conversion)


def fp16_to_fp32(val):
    """Convert fp16 `val` to fp32"""
    def float_conversion(val):
        val_typecheck = val
        if isinstance(val_typecheck, (Parameter, Variable)):
            val_typecheck = val.data
        if isinstance(val_typecheck, HALF_TYPES):
            val = val.float()
        return val
    return conversion_helper(val, float_conversion)

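# A minimal usage sketch (hypothetical tensors, not part of the original API):
# the helpers above recurse through nested tuples/lists and only convert
# fp32/fp16 CUDA tensors, so mixed structures are handled element-wise.
#
#     a = torch.randn(2, 4).cuda()          # torch.cuda.FloatTensor
#     b = torch.ones(2, 4).cuda().half()    # torch.cuda.HalfTensor
#     halved = fp32_to_fp16((a, [b]))       # a -> fp16, b left as-is
#     restored = fp16_to_fp32(halved)       # both elements -> fp32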

class FP16_Module(MegatronModule):
    def __init__(self, module):
        super(FP16_Module, self).__init__()
        self.add_module('module', module.half())

    def forward(self, *inputs, **kwargs):
        if mpu.is_pipeline_first_stage():
            inputs = fp32_to_fp16(inputs)
        outputs = self.module(*inputs, **kwargs)
        if mpu.is_pipeline_last_stage():
            outputs = fp16_to_fp32(outputs)
        return outputs

    def state_dict(self, destination=None, prefix='', keep_vars=False):
        return self.module.state_dict(destination, prefix, keep_vars)

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        return self.module.state_dict_for_save_checkpoint(destination, prefix,
                                                          keep_vars)

    def load_state_dict(self, state_dict, strict=True):
        self.module.load_state_dict(state_dict, strict=strict)

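# Usage sketch for FP16_Module (illustrative only; ``MyTransformer`` is a
# hypothetical module, and the pipeline-parallel state is assumed to be
# initialized so that mpu.is_pipeline_first_stage()/is_pipeline_last_stage()
# return meaningful values):
#
#     net = FP16_Module(MyTransformer(args).cuda())
#     output = net(tokens)   # inputs cast to fp16 on the first stage,
#                            # outputs cast back to fp32 on the last stage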
# TODO:  Update overflow check + downscale to use Carl's fused kernel.


class FP16_Optimizer(object):
    """
    :class:`FP16_Optimizer` is designed to wrap an existing PyTorch optimizer,
    and manage static or dynamic loss scaling and master weights in a manner transparent to the user.
    For standard use, only two lines must be changed:  creating the :class:`FP16_Optimizer` instance,
    and changing the call to ``backward``.

    Example::

        model = torch.nn.Linear(D_in, D_out).cuda().half()
        optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
        # Name the FP16_Optimizer instance to replace the existing optimizer
        # (recommended but not required):
        optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
        ...
        # loss.backward() becomes:
        optimizer.backward(loss)
        ...

    Example with dynamic loss scaling::

        ...
        optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
                                   # optional arg to control dynamic loss scaling behavior
                                   # dynamic_loss_args={'scale_window' : 500})
                                   # Usually, dynamic_loss_args is not necessary.

    Args:
        init_optimizer (torch.optim.optimizer):  Existing optimizer created with the parameters to optimize.  Internally, :class:`FP16_Optimizer` replaces the passed optimizer's fp16 parameters, if any, with fp32 master parameters copied from the original ones.  :class:`FP16_Optimizer` also stores references to the original fp16 parameters, and updates these fp16 parameters from the master fp32 copy at the end of each :attr:`step`.
        static_loss_scale (float, optional, default=1.0):  Loss scale used internally to scale gradients computed by the model.  Any fp16 gradients will be copied to fp32, then downscaled before being applied to the fp32 master params, so ``static_loss_scale`` should not affect learning rate.
        dynamic_loss_scale (bool, optional, default=False):  Use dynamic loss scaling.  If True, this will override any ``static_loss_scale`` option.
        dynamic_loss_args (dict, optional, default=None):  Dict of kwargs that will be forwarded to the internal :class:`DynamicLossScaler` instance's constructor.  Keys of this dict must match kwargs accepted by :class:`DynamicLossScaler`'s constructor.  If ``dynamic_loss_args`` is unspecified, :class:`DynamicLossScaler`'s defaults will be used.
        verbose (bool, optional, default=True):  By default, FP16_Optimizer's constructor prints out the parameters and parameter groups it is ingesting, as a sanity check.  If this becomes annoying (e.g. for large models), it can be disabled by passing ``verbose=False``.  ``verbose=False`` will not disable printing when the loss scale is readjusted during dynamic loss scaling.

    ``init_optimizer`` is expected to have been constructed in the ordinary way.
    It is recommended (although not required) that the newly constructed :class:`FP16_Optimizer` instance be
    named to replace ``init_optimizer``, for two reasons:
    First, it means that references to the same name
    later in the file will not have to change.
    Second, :class:`FP16_Optimizer` reserves the right (as an implementation detail) to
    modify ``init_optimizer``.  If you do choose a unique name for the new
    :class:`FP16_Optimizer` instance, you should only work with this new instance,
    because the preexisting optimizer might no longer behave as expected.

    ``init_optimizer`` may be any Pytorch optimizer.
    It may contain a mixture of fp16 and fp32 parameters organized into any number of
    ``param_groups`` with different hyperparameters.  The :class:`FP16_Optimizer` constructor will
    ingest these ``param_groups`` and remember them.

    Calls to ::

        loss.backward()

    must be replaced with ::

        optimizer.backward(loss)

    because :class:`FP16_Optimizer` requires ownership of the backward pass to implement
    loss scaling and copies to master gradients.

    .. note::
        Loss scaling, either static or dynamic, is orthogonal to learning rate, because gradients
        are downscaled before being applied.  This means that adjusting the loss scale, or using
        dynamic loss scaling, should not require retuning the learning rate or any other
        hyperparameters.


    **Advanced options**

    **Closures**:  :class:`FP16_Optimizer` can wrap a Pytorch optimizer that receives a closure.
    See docstring for :attr:`step`.

    **Gradient clipping**:  Use :attr:`clip_master_grads`.

    **Multiple losses**:  If your model accumulates gradients from multiple losses,
    this can be made more efficient by supplying ``update_master_grads=False``
    to :attr:`backward`.  See docstring for :attr:`backward`.

    **Manually adjusting loss scale**:  The current loss scale can be retrieved or set via ::

        print(optimizer.loss_scale)
        optimizer.loss_scale = new_loss_scale

    For static loss scaling, manually adjusting the loss scale over time is a reasonable
    thing to do.  During later epochs, gradients may become smaller, and a
    higher loss scale may be required, analogous to scheduling the learning rate.  Dynamic loss
    scaling is more subtle (see :class:`DynamicLossScaler`) and in this case, manually adjusting
    the loss scale is not recommended.

    **Multi_GPU training**:  If the wrapped ``init_optimizer`` was created from a model wrapped in
    Pytorch DistributedDataParallel or Apex DistributedDataParallel, :class:`FP16_Optimizer`
    should still work as intended.
    """

    def __init__(self,
                 init_optimizer,
                 static_loss_scale=1.0,
                 dynamic_loss_scale=False,
                 dynamic_loss_args=None,
                 verbose=False):
        if not torch.cuda.is_available():
            raise SystemError("Cannot use fp16 without CUDA.")

        self.verbose = verbose

        self.optimizer = init_optimizer
        # init_state_dict sets up an alternative way to cast per-param state tensors.
        # Stashing here in case https://github.com/pytorch/pytorch/issues/7733 makes it necessary.
        # init_state_dict = init_optimizer.state_dict()

        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []
        for i, param_group in enumerate(self.optimizer.param_groups):
            self.maybe_print("FP16_Optimizer processing param group {}:".format(i))
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:
                    if param.type() == 'torch.cuda.HalfTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.HalfTensor with {}"
                                         .format(param.size()))
                        fp16_params_this_group.append(param)
                        master_param = param.detach().clone().float()
                        master_param.requires_grad = True
                        # Copy the model parallel flag.
                        master_param.tensor_model_parallel = param.tensor_model_parallel
                        param_group['params'][i] = master_param
                        fp32_from_fp16_params_this_group.append(master_param)
                        # Reset existing state dict key to the new master param.
                        # We still need to recast per-param state tensors, if any, to FP32.
                        if param in self.optimizer.state:
                            self.optimizer.state[master_param] = self.optimizer.state.pop(param)
                    elif param.type() == 'torch.cuda.FloatTensor':
                        self.maybe_print("FP16_Optimizer received torch.cuda.FloatTensor with {}"
                                         .format(param.size()))
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param
                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())
        # alternative way to cast per-param state tensors:
        # self.optimizer.load_state_dict(init_state_dict)

        if dynamic_loss_scale:
            self.dynamic_loss_scale = True
            if dynamic_loss_args is not None:
                self.loss_scaler = DynamicLossScaler(**dynamic_loss_args)
            else:
                self.loss_scaler = DynamicLossScaler()
        else:
            self.dynamic_loss_scale = False
            self.loss_scaler = LossScaler(static_loss_scale)

        self.overflow = False
        self.first_closure_call_this_step = True

        self.clip_grad_norm = clip_grad_norm

    def maybe_print(self, msg):
        if self.verbose:
            print(msg)

    def __getstate__(self):
        raise RuntimeError("FP16_Optimizer should be serialized using state_dict().")

    def __setstate__(self, state):
        raise RuntimeError("FP16_Optimizer should be deserialized using load_state_dict().")

    def zero_grad(self, set_grads_to_None=False):
        """
        Zero fp32 and fp16 parameter grads.
        """
        # In principle, only the .grad attributes of the model params need to be zeroed,
        # because gradients are copied into the FP32 master params.  However, we zero
        # all gradients owned by the optimizer, just to be safe:
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if set_grads_to_None:
                    p.grad = None
                else:
                    if p.grad is not None:
                        p.grad.detach_()
                        p.grad.zero_()

        # Zero fp16 gradients owned by the model:
        for fp16_group in self.fp16_groups:
            for param in fp16_group:
                if set_grads_to_None:
                    param.grad = None
                else:
                    if param.grad is not None:
                        param.grad.detach_()  # as in torch.optim.optimizer.zero_grad()
                        param.grad.zero_()

    def _check_overflow(self):
        params = []
        for group in self.fp16_groups:
            for param in group:
                params.append(param)
        for group in self.fp32_from_fp32_groups:
            for param in group:
                params.append(param)
        self.overflow = self.loss_scaler.has_overflow(params)

    def _update_scale(self, has_overflow=False):
        self.loss_scaler.update_scale(has_overflow)

    def _master_params_to_model_params(self):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            master_params_to_model_params(fp16_group, fp32_from_fp16_group)

    def _model_params_to_master_params(self):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            master_params_to_model_params(fp32_from_fp16_group, fp16_group)

    # To consider:  Integrate distributed with this wrapper by registering a hook on each variable
    # that does the overflow check, gradient copy + downscale, and fp32
    # allreduce in a different stream.
    def _model_grads_to_master_grads(self):
        for fp16_group, fp32_from_fp16_group in zip(self.fp16_groups, self.fp32_from_fp16_groups):
            model_grads_to_master_grads(fp16_group, fp32_from_fp16_group)

    def _downscale_master(self):
        if self.loss_scale != 1.0:
            for group in self.optimizer.param_groups:
                grads = [p.grad for p in group['params'] if p.grad is not None]
                _overflow_buf = torch.cuda.IntTensor([0])
                multi_tensor_applier(amp_C.multi_tensor_scale,
                                     _overflow_buf,
                                     [grads, grads],
                                     1./self.loss_scale)
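                # Descriptive note (not executed): with [grads, grads] as the
                # src/dst lists, the fused call above is equivalent to scaling
                # every gradient in place,
                #     for g in grads:
                #         g.mul_(1. / self.loss_scale)
                # but does so in a single multi-tensor kernel launch.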
      
    def clip_master_grads(self, max_norm, norm_type=2):
        """
        Clips fp32 master gradients via the ``clip_grad_norm`` helper imported from ``fp16util``.

        Args:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
                infinity norm.

        Returns:
            Total norm of the current fp32 gradients (viewed as a single vector).

        .. warning::
            Returns -1 if the most recently computed fp16 gradients overflowed (that is, if ``self.overflow`` is ``True``).
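
        Example::

            # Illustrative sketch; the max_norm value is arbitrary and
            # ``optimizer`` is assumed to be an FP16_Optimizer instance.
            optimizer.backward(loss)
            grad_norm = optimizer.clip_master_grads(1.0)  # -1 if fp16 grads overflowed
            optimizer.step()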
        """
        if not self.overflow:
            fp32_params = []
            for param_group in self.optimizer.param_groups:
                for param in param_group['params']:
                    fp32_params.append(param)
            return self.clip_grad_norm(fp32_params, max_norm, norm_type)
        else:
            return -1

    def state_dict(self):
        """
        Returns a dict containing the current state of this :class:`FP16_Optimizer` instance.
        This dict contains attributes of :class:`FP16_Optimizer`, as well as the state_dict
        of the contained Pytorch optimizer.
        Example::

            checkpoint = {}
            checkpoint['model'] = model.state_dict()
            checkpoint['optimizer'] = optimizer.state_dict()
            torch.save(checkpoint, "saved.pth")
        """
        state_dict = {}
        state_dict['loss_scaler'] = self.loss_scaler
        state_dict['dynamic_loss_scale'] = self.dynamic_loss_scale
        state_dict['overflow'] = self.overflow
        state_dict['first_closure_call_this_step'] = self.first_closure_call_this_step
        state_dict['optimizer_state_dict'] = self.optimizer.state_dict()
        state_dict['fp32_from_fp16'] = self.fp32_from_fp16_groups
        return state_dict

    def load_state_dict(self, state_dict):
        """
        Loads a state_dict created by an earlier call to state_dict().
        If ``fp16_optimizer_instance`` was constructed from some ``init_optimizer``,
        whose parameters in turn came from ``model``, it is expected that the user
        will call ``model.load_state_dict()`` before
        ``fp16_optimizer_instance.load_state_dict()`` is called.

        Example::

            model = torch.nn.Linear(D_in, D_out).cuda().half()
            optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
            optimizer = FP16_Optimizer(optimizer, static_loss_scale = 128.0)
            ...
            checkpoint = torch.load("saved.pth")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
        """
        # I think it should actually be ok to reload the optimizer before the model.
        self.loss_scaler = state_dict['loss_scaler']
        self.dynamic_loss_scale = state_dict['dynamic_loss_scale']
        self.overflow = state_dict['overflow']
        self.first_closure_call_this_step = state_dict['first_closure_call_this_step']
        self.optimizer.load_state_dict(state_dict['optimizer_state_dict'])
        # At this point, the optimizer's references to the model's fp32 parameters are up to date.
        # The optimizer's hyperparameters and internal buffers are also up to date.
        # However, the fp32 master copies of the model's fp16 params stored by the optimizer are still
        # out of date.  There are two options.
        # 1:  Refresh the master params from the model's fp16 params.
        # This requires less storage but incurs precision loss.
        # 2:  Save and restore the fp32 master copies separately.
        # We choose option 2.
        #
        # Pytorch Optimizer.load_state_dict casts saved buffers (e.g. momentum) to the type and device
        # of their associated parameters, because it's possible those buffers might not exist yet in
        # the current optimizer instance.  In our case, as long as the current FP16_Optimizer has been
        # constructed in the same way as the one whose state_dict we are loading, the same master params
        # are guaranteed to exist, so we can just copy_() from the saved master params.
        for current_group, saved_group in zip(
                self.fp32_from_fp16_groups, state_dict['fp32_from_fp16']):
            for current, saved in zip(current_group, saved_group):
                current.data.copy_(saved.data)

    def step(self, closure=None):  # could add clip option.
        """
        If no closure is supplied, :attr:`step` should be called after
        ``fp16_optimizer_obj.backward(loss)``.
        :attr:`step` updates the fp32 master copy of parameters using the optimizer supplied to
        :class:`FP16_Optimizer`'s constructor, then copies the updated fp32 params into the fp16 params
        originally referenced by :class:`FP16_Optimizer`'s constructor, so the user may immediately run
        another forward pass using their model.

        If a closure is supplied, :attr:`step` may be called without a prior call to
        :attr:`backward(loss)`.
        This control flow is identical to `ordinary Pytorch optimizer use`_ with closures.
        However, the user should take care that any ``loss.backward()`` call within the closure
        has been replaced by ``fp16_optimizer_obj.backward(loss)``.

        Args:
           closure (optional):  Closure that will be supplied to the underlying optimizer originally passed to :class:`FP16_Optimizer`'s constructor.  closure should call :attr:`zero_grad()` on the :class:`FP16_Optimizer` object, compute the loss, call :attr:`backward(loss)`, and return the loss.

        Example with closure::

            # optimizer is assumed to be an FP16_Optimizer object, previously constructed from an
            # existing pytorch optimizer.
            for input, target in dataset:
                def closure():
                    optimizer.zero_grad()
                    output = model(input)
                    loss = loss_fn(output, target)
                    # loss.backward() becomes:
                    optimizer.backward(loss)
                    return loss
                optimizer.step(closure)

        .. warning::
            Currently, calling :attr:`step` with a closure is not compatible with dynamic loss scaling.

        .. _`ordinary Pytorch optimizer use`:
            http://pytorch.org/docs/master/optim.html#optimizer-step-closure
        """

        scale = self.loss_scaler.loss_scale
        self._update_scale(self.overflow)

        if self.overflow:
            self.maybe_print("OVERFLOW! Skipping step. Attempted loss scale: {}, reducing to {}"
                             .format(scale, self.loss_scale))
            return

        if closure is not None:
            retval = self._step_with_closure(closure)
        else:
            retval = self.optimizer.step()

        self._master_params_to_model_params()

        return retval

    def _step_with_closure(self, closure):
        def wrapped_closure():
            # helpful for debugging
            # print("Calling wrapped_closure, first_closure_call_this_step = {}"
            #       .format(self.first_closure_call_this_step))
            if self.first_closure_call_this_step:
                # We expect that the fp16 params are initially fresh on entering self.step(),
                # so _master_params_to_model_params() is unnecessary the first time wrapped_closure()
                # is called within self.optimizer.step().
                self.first_closure_call_this_step = False
            else:
                # If self.optimizer.step() internally calls wrapped_closure more than once,
                # it may update the fp32 params after each call.  However, self.optimizer
                # doesn't know about the fp16 params at all.  If the fp32 params get updated,
                # we can't rely on self.optimizer to refresh the fp16 params.  We need
                # to handle that manually:
                self._master_params_to_model_params()
            # Our API expects the user to give us ownership of the backward() call by
            # replacing all calls to loss.backward() with optimizer.backward(loss).
            # This requirement holds whether or not the call to backward() is made within a closure.
            # If the user is properly calling optimizer.backward(loss) within "closure,"
            # calling closure() here will give the fp32 master params fresh gradients
            # for the optimizer to play with, so all wrapped_closure needs to do is call
            # closure() and return the loss.
            temp_loss = closure()
            while self.overflow:
                scale = self.loss_scaler.loss_scale
                self._update_scale(self.overflow)
                self.maybe_print("OVERFLOW within closure! Skipping step. Attempted loss scale: {}, "
                                 "reducing to {}".format(scale, self.loss_scale))
                temp_loss = closure()
            return temp_loss

        retval = self.optimizer.step(wrapped_closure)

        self.first_closure_call_this_step = True

        return retval

    def backward(self, output_tensor, update_master_grads=True, retain_graph=False,
                 output_tensor_grad=None):
        """
        :attr:`backward` performs the following conceptual steps:

        1. fp32_loss = loss.float() (see first Note below)
        2. scaled_loss = fp32_loss*loss_scale
        3. scaled_loss.backward(), which accumulates scaled gradients into the ``.grad`` attributes of the model's leaves (which may be fp16, fp32, or a mixture, depending how your model was defined).
        4. fp16 grads are then copied to the master params' ``.grad`` attributes (see second Note), which are guaranteed to be fp32.
        5. Finally, master grads are divided by loss_scale.

        In this way, after :attr:`backward`, the master params have fresh gradients,
        and :attr:`step` may be called.

        .. note::
            :attr:`backward` internally converts the loss to fp32 before applying the loss scale.
            This provides some additional safety against overflow if the user has supplied an
            fp16 loss value.
            However, for maximum overflow safety, the user should
            compute the loss criterion (MSE, cross entropy, etc) in fp32 before supplying it to
            :attr:`backward`.

        .. warning::
            The gradients found in a model's leaves after the call to
            :attr:`backward` should not be regarded as valid in general,
            because it's possible
            they have been scaled (and in the case of dynamic loss scaling,
            the scale factor may change over time).
            If the user wants to inspect gradients after a call to :attr:`backward`,
            only the master gradients should be regarded as valid.  These can be retrieved via
            :attr:`inspect_master_grad_data()`.

        Args:
            output_tensor:  The loss (or pipeline-stage output tensor) produced by the user's model.  It may be either float or half (but see first Note above).
            update_master_grads (bool, optional, default=True):  Option to copy fp16 grads to fp32 grads on this call.  By setting this to False, the user can delay the copy, which is useful to eliminate redundant fp16->fp32 grad copies if :attr:`backward` is being called on multiple losses in one iteration.  If set to False, the user becomes responsible for calling :attr:`update_master_grads` before calling :attr:`step`.
            retain_graph (bool, optional, default=False):  Forwards the usual ``retain_graph=True`` option to the internal call to ``loss.backward``.  If ``retain_graph`` is being used to accumulate gradient values from multiple backward passes before calling ``optimizer.step``, passing ``update_master_grads=False`` is also recommended (see Example below).
            output_tensor_grad (optional, default=None):  Gradient with respect to ``output_tensor``, forwarded to the loss scaler's ``backward`` call; intended for the case where ``output_tensor`` is an intermediate pipeline output rather than the final scalar loss.

        Example::

            # Ordinary operation:
            optimizer.backward(loss)

            # Naive operation with multiple losses (technically valid, but less efficient):
            # fp32 grads will be correct after the second call,  but
            # the first call incurs an unnecessary fp16->fp32 grad copy.
            optimizer.backward(loss1)
            optimizer.backward(loss2)

            # More efficient way to handle multiple losses:
            # The fp16->fp32 grad copy is delayed until fp16 grads from all
            # losses have been accumulated.
            optimizer.backward(loss1, update_master_grads=False)
            optimizer.backward(loss2, update_master_grads=False)
            optimizer.update_master_grads()
        """
        # To consider:  try multiple backward passes using retain_grad=True to find
        # a loss scale that works.  After you find a loss scale that works, do a final dummy
        # backward pass with retain_graph=False to tear down the graph.  Doing this would avoid
        # discarding the iteration,  but probably wouldn't improve overall efficiency.
        self.loss_scaler.backward(output_tensor, retain_graph=retain_graph,
                                  output_tensor_grad=output_tensor_grad)
        if update_master_grads:
            self.update_master_grads()

    def update_master_grads(self):
        """
        Copy the ``.grad`` attribute from stored references to fp16 parameters to
        the ``.grad`` attribute of the fp32 master parameters that are directly
        updated by the optimizer.  :attr:`update_master_grads` only needs to be called if
        ``fp16_optimizer_obj.backward`` was called with ``update_master_grads=False``.
        """
        if self.dynamic_loss_scale:
            self._check_overflow()
            if self.overflow:
                return
        self._model_grads_to_master_grads()
        self._downscale_master()

    def inspect_master_grad_data(self):
        """
        When running with :class:`FP16_Optimizer`,
        ``.grad`` attributes of a model's fp16 leaves should not be
        regarded as truthful, because they might be scaled.
        After a call to :attr:`fp16_optimizer_obj.backward(loss)`, if no overflow was encountered,
        the fp32 master params' ``.grad``
        attributes will contain valid gradients properly divided by the loss scale.  However,
        because :class:`FP16_Optimizer` flattens some parameters, accessing them may be
        nonintuitive.  :attr:`inspect_master_grad_data`
        allows those gradients to be viewed with shapes corresponding to their associated model leaves.

        Returns:
            List of lists (one list for each parameter group).  The list for each parameter group
            is a list of the ``.grad.data`` attributes of the fp32 master params belonging to that group.
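
        Example::

            # Illustrative sketch; ``optimizer`` is an FP16_Optimizer instance.
            master_grads = optimizer.inspect_master_grad_data()
            if master_grads is not None:
                # gradient of the first master param in the first param group
                first_grad = master_grads[0][0]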
        """
        if self.overflow:
            print("Warning:  calling FP16_Optimizer.inspect_master_grad_data while in an overflow state.  "
                  "Gradients are currently invalid (may be inf, nan, or stale).  Returning None.")
            return None
        else:
            # The optimizer owns only references to master params.
            master_grads_data = []
            for param_group in self.optimizer.param_groups:
                master_grads_this_group = []
                for param in param_group['params']:
                    if param.grad is not None:
                        master_grads_this_group.append(param.grad.data)
                    else:
                        master_grads_this_group.append(None)
                master_grads_data.append(master_grads_this_group)
            return master_grads_data

    # Promote loss scale so it can be retrieved or set via "fp16_optimizer_instance.loss_scale"

    def _get_loss_scale(self):
        return self.loss_scaler.loss_scale

    def _set_loss_scale(self, value):
        self.loss_scaler.cur_scale = value

    loss_scale = property(_get_loss_scale, _set_loss_scale)

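    # Illustrative sketch (names hypothetical): because ``loss_scale`` is
    # promoted to a property, training scripts can read or override it directly:
    #
    #     print('current loss scale:', optimizer.loss_scale)
    #     optimizer.loss_scale = 2.**15   # per the class docstring, manual
    #                                     # adjustment only makes sense with
    #                                     # static loss scaling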
    # Promote state so it can be retrieved or set via "fp16_optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via "fp16_optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)