optimizer.py 15.3 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron optimizer."""
mohammad's avatar
mohammad committed
17
18
19
20
21
22
23
24
25

from abc import ABC
from abc import abstractmethod

import torch

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

mohammad's avatar
mohammad committed
26
27
from megatron import get_timers
from megatron import mpu
mohammad's avatar
mohammad committed
28
29
from megatron import print_rank_0

Rewon Child's avatar
Rewon Child committed
30
from .clip_grads import clip_grad_norm_fp32, count_zeros_fp32
mohammad's avatar
mohammad committed
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47


def _zero_grad_group_helper(group, set_to_none):
    """Zero out the gradient for a group of parameters.
    Note: copied from torch.optim.optimizer."""
    for param in group:
        if param.grad is not None:
            if set_to_none:
                param.grad = None
            else:
                if param.grad.grad_fn is not None:
                    param.grad.detach_()
                else:
                    param.grad.requires_grad_(False)
                param.grad.zero_()


48
49
50
51
52
53
54
55
56
57
58
59
def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
    """Use multi-tensor-applier to copy values from one list to another."""
    if overflow_buf:
        overflow_buf.fill_(0)
    else:
        overflow_buf = torch.cuda.IntTensor([0])
    # Scaling with factor `1.0` is equivalent to copy.
    multi_tensor_applier(amp_C.multi_tensor_scale,
                         overflow_buf,
                         [this, that],
                         1.0)

mohammad's avatar
mohammad committed
60
61
62
63
64
65
66
67

class MegatronOptimizer(ABC):
    """Abstract wrapper around a base optimizer.

    Subclasses implement gradient zeroing, loss scaling, stepping and
    (de)serialization; this class provides the shared helpers and exposes
    the wrapped optimizer's `state` and `param_groups`.
    """

    def __init__(self, optimizer):
        """Input optimizer is the base optimizer for example Adam."""
        self.optimizer = optimizer
        assert self.optimizer, 'no optimizer is provided.'

    def get_parameters(self):
        """Return a flat list of every parameter held in the base
        optimizer's param groups, in group order."""
        return [param
                for param_group in self.optimizer.param_groups
                for param in param_group['params']]

    def clip_grad_norm(self, clip_grad):
        """Run `clip_grad_norm_fp32` over all parameters with the given
        `clip_grad` value and return its result."""
        params = self.get_parameters()
        return clip_grad_norm_fp32(params, clip_grad)

    def count_zeros(self):
        """Run `count_zeros_fp32` over all parameters and return its
        result."""
        params = self.get_parameters()
        return count_zeros_fp32(params)

    @abstractmethod
    def zero_grad(self, set_to_none=True):
        """Clear the gradients; `set_to_none` selects dropping the
        gradients instead of zeroing them in place."""
        pass

    @abstractmethod
    def get_loss_scale(self):
        """The output should be a cuda tensor of size 1."""
        pass

    def scale_loss(self, loss):
        """Simple scaling."""
        return self.get_loss_scale() * loss

    @abstractmethod
    def step(self):
        """Perform one optimization step."""
        pass

    @abstractmethod
    def reload_model_params(self):
        """Refreshes any internal state from the current model parameters.
        Call whenever the parameters are changed outside of the optimizer.
        For example, when we load a model from a checkpoint without loading
        the optimizer, the model parameters are updated but for fp16 optimizer
        with main parameters, the main parameters need to also be updated."""
        pass

    @abstractmethod
    def state_dict(self):
        """Return the state to be stored in a checkpoint."""
        pass

    @abstractmethod
    def load_state_dict(self, state_dict):
        """Restore the state produced by `state_dict`."""
        pass

    # Promote state so it can be retrieved or set via
    # "optimizer_instance.state"
    def _get_state(self):
        return self.optimizer.state

    def _set_state(self, value):
        self.optimizer.state = value

    state = property(_get_state, _set_state)

    # Promote param_groups so it can be retrieved or set via
    # "optimizer_instance.param_groups"
    # (for example, to adjust the learning rate)
    def _get_param_groups(self):
        return self.optimizer.param_groups

    def _set_param_groups(self, value):
        self.optimizer.param_groups = value

    param_groups = property(_get_param_groups, _set_param_groups)



class FP16OptimizerWithFP16Params(MegatronOptimizer):
    """Optimizer wrapper that keeps fp32 "main" copies of fp16 parameters.

    The base optimizer is pointed at the fp32 copies. On every step the
    model gradients are copied to the main gradients, unscaled and checked
    for inf/nan, optionally clipped, and after the base optimizer step the
    main parameters are copied back into the fp16 model parameters.
    """

    def __init__(self, optimizer, grad_scaler, clip_grad, log_num_zeros_in_grad):
        """Wrap `optimizer` and build the fp16 / fp32 parameter groups.

        Arguments:
            optimizer: base optimizer, for example Adam.
            grad_scaler: object providing `scale`, `inv_scale`,
                `update(found_inf)`, `state_dict()` and `load_state_dict()`.
            clip_grad: gradient clipping value; clipping is applied in
                `step` only when it is greater than 0.
            log_num_zeros_in_grad: if true, `step` also counts the zeros
                in the gradients.

        Raises:
            TypeError: if a trainable parameter is neither a
                torch.cuda.HalfTensor nor a torch.cuda.FloatTensor.
        """
        super(FP16OptimizerWithFP16Params, self).__init__(optimizer)

        self.grad_scaler = grad_scaler
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad

        # Tensor used to determine if a nan/inf has happened.
        # Any non-zero value indicates inf/nan.
        self.found_inf = torch.cuda.FloatTensor([0.0])

        # Dummy tensor needed for apex multi-apply tensor.
        self._dummy_overflow_buf = torch.cuda.IntTensor([0])

        # ======================
        # main parameter stuff
        # ======================

        # Three groups of parameters:
        #   fp16_groups: original fp16 parameters
        #   fp32_from_fp16_groups: fp32 copy of fp16 parameters
        #   fp32_from_fp32_groups: original fp32 parameters
        self.fp16_groups = []
        self.fp32_from_fp16_groups = []
        self.fp32_from_fp32_groups = []

        # For all the groups in the original optimizer:
        for param_group in self.optimizer.param_groups:
            fp16_params_this_group = []
            fp32_params_this_group = []
            fp32_from_fp16_params_this_group = []
            # For all the parameters in this group:
            for i, param in enumerate(param_group['params']):
                if param.requires_grad:

                    # fp16 params:
                    if param.type() == 'torch.cuda.HalfTensor':
                        fp16_params_this_group.append(param)
                        # Create a copy
                        main_param = param.detach().clone().float()
                        # Store grads
                        main_param.requires_grad = True
                        # Copy tensor model parallel attributes.
                        mpu.copy_tensor_model_parallel_attributes(main_param,
                                                                  param)
                        if hasattr(param, 'shared'):
                            main_param.shared = param.shared
                        # Replace the optimizer params with the new fp32 copy.
                        param_group['params'][i] = main_param
                        fp32_from_fp16_params_this_group.append(main_param)
                        # Reset existing state dict key to the new main param.
                        if param in self.optimizer.state:
                            self.optimizer.state[main_param] \
                                = self.optimizer.state.pop(param)

                    # fp32 params.
                    elif param.type() == 'torch.cuda.FloatTensor':
                        fp32_params_this_group.append(param)
                        param_group['params'][i] = param

                    else:
                        raise TypeError("Wrapped parameters must be either "
                                        "torch.cuda.FloatTensor or "
                                        "torch.cuda.HalfTensor. "
                                        "Received {}".format(param.type()))

            self.fp16_groups.append(fp16_params_this_group)
            self.fp32_from_fp16_groups.append(fp32_from_fp16_params_this_group)
            self.fp32_from_fp32_groups.append(fp32_params_this_group)

        # Leverage state_dict() and load_state_dict() to
        # recast preexisting per-param state tensors
        self.optimizer.load_state_dict(self.optimizer.state_dict())


    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
                fp16_groups & fp32_from_fp32_groups."""
        for group in self.fp16_groups:
            _zero_grad_group_helper(group, set_to_none)
        for group in self.fp32_from_fp32_groups:
            _zero_grad_group_helper(group, set_to_none)


    def get_loss_scale(self):
        """Return the current loss scale held by the grad scaler."""
        return self.grad_scaler.scale


    def _copy_model_grads_to_main_grads(self):
        """Copy the gradients of the fp16 model parameters into the
        gradients of their fp32 main copies."""
        # This only needs to be done for the fp16 group.
        model_grads = []
        main_grads = []
        for model_group, main_group in zip(self.fp16_groups,
                                           self.fp32_from_fp16_groups):
            for model_param, main_param in zip(model_group, main_group):
                if model_param.grad is not None:
                    # Allocate the destination gradient on first use.
                    if main_param.grad is None:
                        main_param.grad = torch.empty_like(main_param)
                    model_grads.append(model_param.grad.data)
                    main_grads.append(main_param.grad.data)
        _multi_tensor_copy_this_to_that(this=model_grads, that=main_grads,
                                        overflow_buf=self._dummy_overflow_buf)


    def _unscale_main_grads_and_check_for_nan(self):
        """Unscale all main gradients in place and return True if an
        inf/nan was found on any rank of the model parallel group."""
        main_grads = []
        # fp32 params from fp16 ones.
        for main_group in self.fp32_from_fp16_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Append fp32 parameters.
        for main_group in self.fp32_from_fp32_groups:
            for main_param in main_group:
                if main_param.grad is not None:
                    main_grads.append(main_param.grad.data)
        # Reset found inf.
        self.found_inf.fill_(0.0)
        # Unscale and set found inf/nan
        torch._amp_foreach_non_finite_check_and_unscale_(
            main_grads, self.found_inf, self.grad_scaler.inv_scale)
        # Update across all model parallel instances.
        torch.distributed.all_reduce(self.found_inf,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=mpu.get_model_parallel_group())

        # Check for nan.
        found_inf_flag = (self.found_inf.item() > 0)
        return found_inf_flag


    def _get_model_and_main_params_data_fp16(self):
        """Return two parallel lists: the `.data` of every fp16 model
        parameter and the `.data` of its fp32 main copy."""
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.fp16_groups,
                                           self.fp32_from_fp16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data


    def _copy_main_params_to_model_params(self):
        """Copy the fp32 main parameters into the fp16 model parameters."""
        # Only needed for the fp16 params.
        model_data, main_data = self._get_model_and_main_params_data_fp16()
        _multi_tensor_copy_this_to_that(this=main_data, that=model_data,
                                        overflow_buf=self._dummy_overflow_buf)


    def _copy_model_params_to_main_params(self):
        """Copy the fp16 model parameters into the fp32 main parameters."""
        # Only needed for the fp16 params.
        model_data, main_data = self._get_model_and_main_params_data_fp16()
        _multi_tensor_copy_this_to_that(this=model_data, that=main_data,
                                        overflow_buf=self._dummy_overflow_buf)


    def reload_model_params(self):
        """Refresh the fp32 main parameters from the current fp16 model
        parameters."""
        self._copy_model_params_to_main_params()


    @torch.no_grad()
    def step(self):
        """Run one optimization step.

        Returns a tuple `(success, grad_norm, num_zeros_in_grad)`. When an
        inf/nan is found in the gradients the update is skipped and
        `(False, None, None)` is returned. `grad_norm` is None when
        clipping is disabled and `num_zeros_in_grad` is None when zero
        counting is disabled.
        """

        timers = get_timers()

        # Copy gradients from model params to main params.
        timers('optimizer-copy-to-main-grad').start()
        self._copy_model_grads_to_main_grads()
        timers('optimizer-copy-to-main-grad').stop()

        # Unscale and check for inf/nan.
        timers('optimizer-unscale-and-check-inf').start()
        found_inf_flag = self._unscale_main_grads_and_check_for_nan()
        timers('optimizer-unscale-and-check-inf').stop()

        # We are done with scaling gradients
        # so we can update the loss scale.
        self.grad_scaler.update(found_inf_flag)

        # If we found inf/nan, skip the update.
        if found_inf_flag:
            return False, None, None

        # Clip the main gradients.
        timers('optimizer-clip-main-grad').start()
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)
        timers('optimizer-clip-main-grad').stop()

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None

        # Step the optimizer.
        self.optimizer.step()

        # Update params from main params.
        timers('optimizer-copy-main-to-model-params').start()
        self._copy_main_params_to_model_params()
        timers('optimizer-copy-main-to-model-params').stop()

        # Successful update.
        return True, grad_norm, num_zeros_in_grad


    def state_dict(self):
        """Return a dict with the base optimizer state ('optimizer'), the
        grad scaler state ('grad_scaler') and the fp32 main parameter
        groups ('fp32_from_fp16_params')."""
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['fp32_from_fp16_params'] = self.fp32_from_fp16_groups
        return state_dict


    def load_state_dict(self, state_dict):
        """Restore the optimizer, grad scaler and fp32 main parameters.

        Older checkpoints are accepted: the optimizer state may be stored
        under 'optimizer_state_dict', the main parameters under
        'fp32_from_fp16', and the grad scaler entry may be absent (in which
        case the grad scaler is left untouched).
        """
        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            self.grad_scaler.load_state_dict(state_dict['grad_scaler'])

        # Copy data for the main params.
        fp32_from_fp16_params_key = 'fp32_from_fp16_params'
        if fp32_from_fp16_params_key not in state_dict:
            fp32_from_fp16_params_key = 'fp32_from_fp16'
        for current_group, saved_group in zip(
                self.fp32_from_fp16_groups,
                state_dict[fp32_from_fp16_params_key]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)



mohammad's avatar
mohammad committed
382
383
class FP32Optimizer(MegatronOptimizer):
    """Thin wrapper around a base optimizer working directly on the
    parameters, without loss scaling or main-parameter copies."""

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad):
        """Wrap `optimizer`.

        Arguments:
            optimizer: base optimizer, for example Adam.
            clip_grad: gradient clipping value; clipping is applied in
                `step` only when it is greater than 0.
            log_num_zeros_in_grad: if true, `step` also counts the zeros
                in the gradients.
        """

        super(FP32Optimizer, self).__init__(optimizer)
        self.clip_grad = clip_grad
        self.log_num_zeros_in_grad = log_num_zeros_in_grad
        # Constant loss scale of 1.0 returned by get_loss_scale().
        self._scale = torch.cuda.FloatTensor([1.0])


    def zero_grad(self, set_to_none=True):
        """Copied from torch.optim.optimizer"""
        for group in self.optimizer.param_groups:
            _zero_grad_group_helper(group['params'], set_to_none)


    def get_loss_scale(self):
        """FP32 optimizer does not do any scaling."""
        return self._scale


    @torch.no_grad()
    def step(self):
        """Clip gradients (if needed) and step the base optimizer.
        Always return successful since there is no overflow.

        Returns a tuple `(True, grad_norm, num_zeros_in_grad)`; `grad_norm`
        is None when clipping is disabled and `num_zeros_in_grad` is None
        when zero counting is disabled."""

        # Clip gradients.
        grad_norm = None
        if self.clip_grad > 0.0:
            grad_norm = self.clip_grad_norm(self.clip_grad)

        # count the zeros in the grads
        num_zeros_in_grad = self.count_zeros() if self.log_num_zeros_in_grad else None

        # Update parameters.
        self.optimizer.step()

        # No overflow for FP32 optimizer.
        return True, grad_norm, num_zeros_in_grad


    def reload_model_params(self):
        """Nothing to refresh: no state is kept besides the base
        optimizer."""
        pass


    def state_dict(self):
        """Return the base optimizer's state dict."""
        return self.optimizer.state_dict()


    def load_state_dict(self, state_dict):
        """Load `state_dict` into the base optimizer."""
        self.optimizer.load_state_dict(state_dict)