# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math
import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron import print_rank_0
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper


class Range:
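    """A start/end index pair (with size) for addressing a contiguous shard
    of a flattened buffer or parameter; normalize() re-bases the range onto
    a new start offset."""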

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start
    def normalize(self, start = 0):
        return Range(start, start + self.size)
    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


class DistributedOptimizer(MixedPrecisionOptimizer):

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
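        """Map each model param overlapping this rank's grad buffer shard
        (gbuf_world_range) to its index ranges: within the world grad buffer,
        within the local shard, and within the flattened param itself."""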
        # Param range map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Param range.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Add param, if within local gbuf range.
            if param_local_end > param_local_start:
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                sub_param_start = max(0, gbuf_world_range.start-param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
                }

        return param_range_map

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
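        """Split the model's grad buffer of this dtype into (roughly) equal,
        contiguous shards across data parallel ranks, and collect this rank's
        local/world ranges plus the per-param ranges inside its shard. For
        example, a 10-element buffer split across 4 ranks yields world ranges
        [0,3), [3,6), [6,9), [9,10)."""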

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges. (i.e., across all data parallel ranks)
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_ranges.append(gbuf_world_range)

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Altogether.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data

    @classmethod
    def build_model_gbuf_range_map(cls, model):
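        """Create a grad buffer range map (see build_model_gbuf_range) for
        each dtype grad buffer of the given model."""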
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        '''Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.'''
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
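        """Regroup the optimizer's param groups so each group contains only
        the params that overlap this rank's grad buffer shards; groups that
        end up empty are dropped."""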

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)

        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

        return group_ranges

    @classmethod
    def build_model_and_main_param_groups(cls,
                                        model_gbuf_ranges,
                                        param_gbuf_map,
                                        opt_group_ranges):
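        """For each optimizer group, collect the original model params and
        build the corresponding shards: float16/fp32 views of each param's
        local slice, plus fp32 copies ("main" params) of the float16 slices
        for the optimizer to update."""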

        # Five groups of parameters:
        #   full_float16_groups: original float16 parameters
        #   full_fp32_groups: original fp32 parameters
        #   shard_float16_groups: float16 shards (views) of the original
        #       float16 parameters
        #   shard_fp32_groups: fp32 shards (views) of the original fp32
        #       parameters
        #   shard_fp32_from_float16_groups: fp32 copies ("main" params) of
        #       the float16 parameter shards
        full_float16_groups = []
        full_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
            full_float16_params_this_group = []
            full_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            full_float16_groups.append(full_float16_params_this_group)
            full_fp32_groups.append(full_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                assert model_param.requires_grad

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    full_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    full_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor, '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(model_param.type()))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            full_float16_groups,
            full_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):
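        """Build grad buffer shard ranges, regroup the optimizer's param
        groups around this rank's shards, allocate the main param shards,
        and point the wrapped optimizer at the sharded groups."""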

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used
        # - Note: this should already be checked in arguments.py
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges.
        self.model_gbuf_ranges = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.full_float16_groups,
            self.full_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

    def get_model_param_range_map(self, param):
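        """Given a model param, get the index ranges of the slice of this
        param that the local data parallel rank owns."""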
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map


    def get_model_parallel_group(self):
        return None


    def state_dict(self):
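        """Collect the wrapped optimizer's state, the grad scaler state (if
        any), and the fp32-from-float16 main param shards."""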
        state_dict = {}
        state_dict['optimizer'] = self.optimizer.state_dict()
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
        state_dict['shard_fp32_from_float16_groups'] = \
            self.shard_fp32_from_float16_groups
        return state_dict


    def load_state_dict(self, state_dict):
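        """Load the wrapped optimizer's state and the grad scaler (warning
        on older checkpoint formats), then copy the saved fp32 main param
        shard values into the current shards."""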

        # Optimizer.
        optimizer_key = 'optimizer'
        if optimizer_key not in state_dict:
            optimizer_key = 'optimizer_state_dict'
            print_rank_0('***WARNING*** loading optimizer from '
                         'an old checkpoint ...')
        self.optimizer.load_state_dict(state_dict[optimizer_key])

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
            print_rank_0('***WARNING*** found an old checkpoint, will not '
                         'load grad scaler ...')
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** found the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

        # Copy data for the main params.
        for current_group, saved_group in zip(
                self.shard_fp32_from_float16_groups,
                state_dict["shard_fp32_from_float16_groups"]):
            for current_param, saved_param in zip(current_group, saved_group):
                current_param.data.copy_(saved_param.data)


    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        float16_groups & fp32_groups. We additionally zero
        fp32_from_float16_groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point."""
        for groups in (
                self.full_float16_groups,
                self.full_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)


    def get_model_grad_buffer_dp_views(self):
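        """For each model grad buffer, return (model_index, dtype, buffer,
        views), where 'views' slices the padded buffer into equal contiguous
        shards, one per data parallel rank."""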

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items

    def reduce_model_grads(self, args, timers):
        '''Reduce-scatter model grads across data-parallel ranks.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        '''

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter setup.
        timers('backward-params-all-reduce').start()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size'.
        for model in self.models:
            for dtype, gbuf in model._grad_buffers.items():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
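        """All-gather each grad buffer so every rank sees the updated param
        values (written into the buffer via each param's '.main_grad' view),
        then copy those values into the model params."""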

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, with grad buffer padding that is done
        #   in distributed.py. Thus, all sub-views will have consistent start/end
        #   indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()

    def _collect_main_grad_data_for_unscaling(self):
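        """Return the grad data of every param in the optimizer's (sharded)
        param groups, for unscaling / inf checks."""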
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]


    def _get_model_and_main_params_data_float16(self):
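        """Return paired lists of float16 shard param data and their fp32
        main param data."""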
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data


    def _copy_model_grads_to_main_grads(self):
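        """Copy each param's local '.main_grad' slice into the '.grad' of
        the corresponding shard main param (casting float16 grads to fp32)."""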
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(full_model_groups,
                                                          shard_main_groups):
                for full_model_param, shard_main_param in zip(full_model_group,
                                                              shard_main_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        copy_group_grads(self.full_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.full_fp32_groups,
                         self.shard_fp32_groups)


    def _copy_main_params_to_model_params(self):
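        """Copy updated shard main params back into each param's local
        '.main_grad' slice (i.e., into the grad buffer), which
        gather_model_params() then all-gathers and copies into the params."""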

        def copy_group_params(shard_main_groups, full_model_groups):
            for shard_main_group, full_model_group in zip(shard_main_groups,
                                                          full_model_groups):
                for shard_main_param, full_model_param in zip(shard_main_group,
                                                              full_model_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]

                    shard_model_grad.data.copy_(shard_main_param)

        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.full_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.full_fp32_groups)