# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math
import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper


class Range:
    """A [start, end) index range into a flattened tensor, with a cached size."""

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start=0):
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


class DistributedOptimizer(MixedPrecisionOptimizer):
    """Distributed optimizer: each data-parallel rank owns a contiguous shard
    of every grad buffer, along with the corresponding fp32 main-param shards,
    so optimizer state is partitioned across the data-parallel group."""

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """Map each model param that overlaps this rank's grad-buffer shard to
        its index sub-ranges (world, local, and within-param)."""

        # Param range map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Param range.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Add param, if within local gbuf range.
            if param_local_end > param_local_start:
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                sub_param_start = max(0, gbuf_world_range.start-param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
                }

        return param_range_map
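
    # A worked sketch of the mapping above (numbers are illustrative):
    # suppose this rank's gbuf_world_range is [100, 200) and a param occupies
    # world indexes [90, 150) of the grad buffer. Then:
    #
    #   "gbuf_local" = [0, 50)    (offset into this rank's shard)
    #   "gbuf_world" = [100, 150) (offset into the full grad buffer)
    #   "param"      = [10, 60)   (offset into the param's own flattened data)
    #
    # i.e. this rank owns the param's elements 10..59, which sit at the start
    # of its shard.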

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """Build this data-parallel rank's view of the model's grad buffer for
        the given dtype: local/world shard ranges plus the per-param map."""

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges (i.e., across all data parallel ranks).
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_ranges.append(gbuf_world_range)

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Altogether.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data
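
    # Sharding sketch for the method above (illustrative): with a grad buffer
    # of numel == 1000 and data_parallel_world_size == 4,
    #
    #   max_gbuf_range_size = ceil(1000 / 4) = 250
    #   world ranges: [0,250), [250,500), [500,750), [750,1000)
    #
    # and data-parallel rank 2 owns world range [500,750), whose local
    # (normalized) range is [0,250).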

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        '''Create a reverse of the model_gbuf_ranges, for referencing in
        the opposite direction.'''
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
        """Regroup the params that overlap this rank's grad-buffer shards into
        ranges that mirror the optimizer's original param groups."""

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)

        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

        return group_ranges
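
    # Shape of the result (illustrative): with two optimizer param groups, and
    # this rank's grad-buffer shards overlapping params p0, p1 (group 0) and
    # p2 (group 1), the returned group_ranges would be:
    #
    #   [ {"params": [p0, p1], "orig_group": param_groups[0]},
    #     {"params": [p2],     "orig_group": param_groups[1]} ]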

    @classmethod
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):

        # Five groups of parameters:
        #   full_float16_groups: original, full-size float16 parameters
        #   full_fp32_groups: original, full-size fp32 parameters
        #   shard_float16_groups: float16 views of this rank's param shards
        #   shard_fp32_groups: fp32 views of this rank's param shards
        #   shard_fp32_from_float16_groups: fp32 copies of the float16 shards
        #     (the main params that the optimizer actually updates)
        full_float16_groups = []
        full_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
            full_float16_params_this_group = []
            full_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            full_float16_groups.append(full_float16_params_this_group)
            full_fp32_groups.append(full_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    full_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    full_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor, '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(model_param.type()))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            full_float16_groups,
            full_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
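
    # A minimal sketch of the slicing above, using standalone tensors (the
    # tensor 'p' and the indexes are hypothetical):
    #
    #   >>> import torch
    #   >>> p = torch.zeros(6, dtype=torch.half)     # stand-in model param
    #   >>> shard = p.detach().view(-1)[2:5]         # view; shares storage with p
    #   >>> main = shard.clone().float()             # independent fp32 main copy
    #   >>> _ = main.fill_(1.0)                      # pretend optimizer update
    #   >>> _ = p.view(-1)[2:5].copy_(main)          # main -> model copy-back
    #   >>> p
    #   tensor([0., 0., 1., 1., 1., 0.], dtype=torch.float16)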

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges.
        self.model_gbuf_ranges = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.full_float16_groups,
            self.full_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_model_param_range_map(self, param):
        """Given a model param, return its range map ("gbuf_world",
        "gbuf_local", and "param" sub-ranges) for this data-parallel rank."""
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map


    def get_model_parallel_group(self):
        """With the distributed optimizer, the model parallel group is the
        entire world."""
        return None

    def state_dict(self):
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict


    def load_state_dict(self, state_dict):

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for groups in (
                self.full_float16_groups,
                self.full_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_model_grad_buffer_dp_views(self):

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items
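
    # Sketch of the view layout above (illustrative; assumes a padded buffer
    # of 8 elements and data_parallel_world_size == 4, so shard_size == 2):
    #
    #   >>> import torch
    #   >>> gbuf_data = torch.arange(8.)
    #   >>> views = [gbuf_data[r*2:(r+1)*2] for r in range(4)]
    #   >>> [v.tolist() for v in views]
    #   [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0], [6.0, 7.0]]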

    def reduce_model_grads(self, args, timers):
        '''Reduce-scatter model grads (and all-reduce embedding grads).

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        '''

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter setup.
        timers('backward-params-all-reduce').start()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size'.
        for model in self.models:
            for dtype, gbuf in model._grad_buffers.items():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()
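
    # Arithmetic sketch of the reduce-scatter above (illustrative): with
    # data_parallel_world_size == 2 and a 4-element grad buffer per rank,
    #
    #   rank 0 buffer: [a0, a1, a2, a3]     rank 1 buffer: [b0, b1, b2, b3]
    #
    # each buffer is first scaled by 1/2, then _reduce_scatter_base leaves
    # rank 0's view holding [(a0+b0)/2, (a1+b1)/2] and rank 1's view holding
    # [(a2+b2)/2, (a3+b3)/2], i.e. each rank owns the averaged grads for its
    # shard only.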

    def gather_model_params(self, args, timers):
        '''All-gather each rank's updated param shard (staged in the grad
        buffers) and copy the gathered values back into the model params.'''

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()
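
    # Collective sketch (illustrative): with data_parallel_world_size == 2 and
    # each rank's shard view holding its updated params,
    #
    #   rank 0 view: [p0, p1]     rank 1 view: [p2, p3]
    #
    # _all_gather_base leaves the full buffer [p0, p1, p2, p3] on every rank,
    # after which each model param copies its slice out of '.main_grad'.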

    def _collect_main_grad_data_for_unscaling(self):
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):

        # Utility method for copying group grads.
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(full_model_groups,
                                                          shard_main_groups):
                for full_model_param, shard_main_param in zip(full_model_group,
                                                              shard_main_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.full_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.full_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):

        # Utility method for copying group params. The updated shard main
        # params are written into the corresponding slice of the model's
        # main_grad buffer, which gather_model_params() then all-gathers.
        def copy_group_params(shard_main_groups, full_model_groups):
            for shard_main_group, full_model_group in zip(shard_main_groups,
                                                          full_model_groups):
                for shard_main_param, full_model_param in zip(shard_main_group,
                                                              full_model_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]

                    shard_model_grad.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.full_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.full_fp32_groups)