# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

# Megatron-internal imports.
from megatron import mpu


# item() is a recent addition, so this helps with backward compatibility.
def to_python_float(t):
    """Return the Python scalar held by a one-element tensor-like ``t``.

    Objects exposing ``.item()`` (modern tensors) are converted with it;
    anything else is assumed to be indexable and its first element is
    returned as-is.
    """
    if not hasattr(t, 'item'):
        # Legacy path: no ``.item()`` available, fall back to indexing.
        return t[0]
    return t.item()

class LossScaler:
    """
    Class that manages a static loss scale.  This class is intended to interact with
    :class:`FP16_Optimizer`, and should not be directly manipulated by the user.

    Use of :class:`LossScaler` is enabled via the ``static_loss_scale`` argument to
    :class:`FP16_Optimizer`'s constructor.

    The scale never changes: :meth:`has_overflow` always reports ``False`` and
    :meth:`update_scale` does nothing.

    Args:
        scale (float, optional, default=1):  The loss scale.
    """

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        """Static scaling never checks gradients; always returns ``False``."""
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Static scaling never inspects tensors; always returns ``False``.

        Declared as a staticmethod so it works when called on the class or on
        an instance (as a plain function without ``self``, an instance call
        raised ``TypeError``).
        """
        return False

    def update_scale(self, overflow):
        """No-op: a static loss scale is never adjusted."""
        pass

    @property
    def loss_scale(self):
        """The current (constant) loss scale."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Multiply every tensor in ``grad_in`` by the loss scale and return it.

        The signature matches a module backward hook
        ``(module, grad_input, grad_output)``; ``module`` and ``grad_out`` are
        unused.
        """
        # Flag buffer required by multi_tensor_applier; its value is not read here.
        _overflow_buf = torch.cuda.IntTensor([0])
        # grad_in is passed as both the source and destination list, so the
        # scaled values are written back into the tensors of grad_in.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             _overflow_buf,
                             [grad_in, grad_in],
                             self.loss_scale)
        return grad_in

    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
        """Run autograd backward from ``output_tensor``.

        Args:
            output_tensor: tensor to backpropagate from.  When
                ``output_tensor_grad`` is ``None`` it is multiplied by the
                loss scale before backpropagation.
            retain_graph: forwarded to :func:`torch.autograd.backward`.
            output_tensor_grad: optional gradient w.r.t. ``output_tensor``,
                forwarded as ``grad_tensors``.  When given, ``output_tensor``
                is NOT scaled here (presumably the supplied gradient already
                carries the scale -- confirm with callers).
        """
        if output_tensor_grad is None:
            scaled_output_tensor = output_tensor * self.loss_scale
        else:
            scaled_output_tensor = output_tensor
        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)


class DynamicLossScaler:
    """
    Class that manages dynamic loss scaling.  It is recommended to use :class:`DynamicLossScaler`
    indirectly, by supplying ``dynamic_loss_scale=True`` to the constructor of
    :class:`FP16_Optimizer`.  However, it's important to understand how :class:`DynamicLossScaler`
    operates, because the default options can be changed using the
    ``dynamic_loss_args`` argument to :class:`FP16_Optimizer`'s constructor.

    Loss scaling is designed to combat the problem of underflowing gradients encountered at long
    times when training fp16 networks.  Dynamic loss scaling begins by attempting a very high loss
    scale.  Ironically, this may result in OVERflowing gradients.  If overflowing gradients are
    encountered, :class:`DynamicLossScaler` informs :class:`FP16_Optimizer` that an overflow has
    occurred.
    :class:`FP16_Optimizer` then skips the update step for this particular iteration/minibatch,
    and :class:`DynamicLossScaler` adjusts the loss scale to a lower value.
    If a certain number of iterations occur without overflowing gradients detected,
    :class:`DynamicLossScaler` increases the loss scale once more.
    In this way :class:`DynamicLossScaler` attempts to "ride the edge" of
    always using the highest loss scale possible without incurring overflow.

    Args:
        init_scale (float, optional, default=2**32):  Initial loss scale attempted by :class:`DynamicLossScaler`.
        scale_factor (float, optional, default=2.0):  Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``.  If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
        scale_window (int, optional, default=1000):  Number of consecutive iterations without an overflow to wait before increasing the loss scale.
        min_scale (float, optional, default=1):  Lower bound for the loss scale; decreases never go below it.
        delayed_shift (int, optional, default=1):  Hysteresis depth. The scale is only lowered once a hysteresis counter (starting at ``delayed_shift``) has been decremented down to 1 by earlier overflows; with the default of 1, every overflow lowers the scale.
        consecutive_hysteresis (bool, optional, default=False):  If True, the hysteresis counter is reset to ``delayed_shift`` after every overflow-free iteration; if False, it is only reset when the scale is increased.
    """

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000,
                 min_scale=1,
                 delayed_shift=1,
                 consecutive_hysteresis=False):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window
        self.min_scale = min_scale
        self.delayed_shift = delayed_shift
        self.cur_hysteresis = delayed_shift
        self.consecutive_hysteresis = consecutive_hysteresis

    # `params` is a list / generator of torch.Variable
    def has_overflow_serial(self, params):
        """Return True if any local parameter gradient contains inf or NaN.

        Parameters whose ``.grad`` is ``None`` are skipped; the scan stops at
        the first offending gradient.  No cross-process communication.
        """
        return any(p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data)
                   for p in params)

    def has_overflow(self, params):
        """Return True if the local check overflowed on any model-parallel rank.

        Each rank runs :meth:`has_overflow_serial` on its own ``params`` and
        the flags are combined with an all-reduce MAX over the model-parallel
        group, so all ranks agree.  Requires CUDA and an initialized
        ``torch.distributed`` process group.
        """
        overflow = self.has_overflow_serial(params)
        # Since each model parallel GPU carries only part of the model,
        # make sure overflow flag is synced across all the model parallel GPUs
        overflow_gpu = torch.cuda.ByteTensor([overflow])
        torch.distributed.all_reduce(overflow_gpu,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())
        overflow = overflow_gpu[0].item()
        return bool(overflow)

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        """Return True if tensor ``x`` contains an inf or NaN value.

        ``x`` is summed and the scalar result inspected: any inf/NaN element
        makes the sum inf or NaN.  Declared as a staticmethod so it can be
        called on the class or on an instance.

        Raises:
            RuntimeError: re-raised unless it is the "value cannot be
                converted" error, which is treated as an overflow.
        """
        try:
            # If x is half, .float() costs an extra copy, but it is needed when
            # PyTorch's .sum() returns a one-element tensor of x's dtype (true for
            # some recent PyTorch versions) -- presumably so the fp16 sum itself
            # cannot overflow.
            cpu_sum = float(x.float().sum())
        except RuntimeError as instance:
            # Only the conversion failure signals an overflow; any other
            # RuntimeError is a genuine error and must propagate.  str() is used
            # instead of args[0] so an argument-less RuntimeError is re-raised
            # instead of turning into an IndexError.
            if "value cannot be converted" not in str(instance):
                raise
            return True
        # +inf, -inf, or NaN (NaN is the only value not equal to itself).
        return cpu_sum in (float('inf'), -float('inf')) or cpu_sum != cpu_sum

    # `overflow` is boolean indicating whether the gradient overflowed
    def update_scale(self, overflow):
        """Adjust the loss scale after one iteration.

        On overflow: divide the scale by ``scale_factor`` (clamped at
        ``min_scale``) once the hysteresis counter is exhausted, otherwise just
        decrement the counter; remember this iteration as the last overflow.
        Without overflow: multiply the scale by ``scale_factor`` whenever the
        number of iterations since the last overflow is a multiple of
        ``scale_window``, resetting the hysteresis counter according to
        ``consecutive_hysteresis``.

        Args:
            overflow (bool): whether this iteration's gradients overflowed.
        """
        # Backfill attributes missing on instances created by older versions of
        # this class (presumably restored from old serialized state -- confirm).
        # NOTE(review): the consecutive_hysteresis fallback (True) differs from
        # the constructor default (False); kept unchanged for compatibility.
        if not hasattr(self, 'min_scale'):
            self.min_scale = 1
        if not hasattr(self, 'delayed_shift'):
            self.delayed_shift = 1
        if not hasattr(self, 'cur_hysteresis'):
            self.cur_hysteresis = 1
        if not hasattr(self, 'consecutive_hysteresis'):
            self.consecutive_hysteresis = True
        if overflow:
            if self.delayed_shift == 1 or self.cur_hysteresis == 1:
                self.cur_scale = max(self.cur_scale / self.scale_factor, self.min_scale)
            else:
                # Absorb this overflow without lowering the scale.
                self.cur_hysteresis -= 1
            self.last_overflow_iter = self.cur_iter
        else:
            if self.consecutive_hysteresis:
                self.cur_hysteresis = self.delayed_shift
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                if not self.consecutive_hysteresis:
                    self.cur_hysteresis = self.delayed_shift
                self.cur_scale *= self.scale_factor
        self.cur_iter += 1

    @property
    def loss_scale(self):
        """The current loss scale."""
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        """Multiply every tensor in ``grad_in`` by the loss scale and return it.

        The signature matches a module backward hook
        ``(module, grad_input, grad_output)``; ``module`` and ``grad_out`` are
        unused.
        """
        # Flag buffer required by multi_tensor_applier; its value is not read here.
        _overflow_buf = torch.cuda.IntTensor([0])
        # grad_in is passed as both the source and destination list, so the
        # scaled values are written back into the tensors of grad_in.
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             _overflow_buf,
                             [grad_in, grad_in],
                             self.loss_scale)
        return grad_in

    def backward(self, output_tensor, retain_graph=False, output_tensor_grad=None):
        """Run autograd backward from ``output_tensor``.

        Args:
            output_tensor: tensor to backpropagate from.  When
                ``output_tensor_grad`` is ``None`` it is multiplied by the
                current loss scale before backpropagation.
            retain_graph: forwarded to :func:`torch.autograd.backward`.
            output_tensor_grad: optional gradient w.r.t. ``output_tensor``,
                forwarded as ``grad_tensors``.  When given, ``output_tensor``
                is NOT scaled here (presumably the supplied gradient already
                carries the scale -- confirm with callers).
        """
        if output_tensor_grad is None:
            scaled_output_tensor = output_tensor * self.loss_scale
        else:
            scaled_output_tensor = output_tensor
        torch.autograd.backward(scaled_output_tensor, grad_tensors=output_tensor_grad,
                                retain_graph=retain_graph)


##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
"""
TO-DO separate out into an example.
if __name__ == "__main__":
    import torch
    from loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs.
    x = torch.randn(N, D_in)
    y = torch.randn(N, D_out)

    w1 = torch.randn(D_in, H, requires_grad=True)
    w2 = torch.randn(H, D_out, requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.item()))
        print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow. has_overflow() all-reduces across the
        # model-parallel group (needs CUDA + torch.distributed), so this
        # single-process CPU example uses the local check instead.
        has_overflow = loss_scaler.has_overflow_serial(parameters)

        # If no overflow, unscale grad and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- ie, skip iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)

"""