distrib_optimizer.py
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Megatron distributed optimizer."""


import math
import torch

from megatron import get_args
from megatron import get_timers
from megatron import mpu
from megatron.model.module import param_is_not_shared
from megatron.mpu.layers import param_is_not_tensor_parallel_duplicate

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper

# >>>
from lutil import pax, tp, print_seq
# <<<


class Range:
    """A range represents the [start, end) indices of a shard within a
    flattened buffer."""

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start = 0):
        return Range(start, start + self.size)

    def __str__(self):
        return "%d,%d [%d]" % (self.start, self.end, self.size)


class DistributedOptimizer(MixedPrecisionOptimizer):
    """Distributed optimizer: each data-parallel rank owns a contiguous shard
    of every model grad buffer, along with the corresponding main (fp32) param
    shards and optimizer state. Grads are reduce-scattered before the step,
    and updated params are all-gathered afterward."""

    @classmethod
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
        """Map each model param to the sub-range of its indices that overlaps
        this rank's grad-buffer shard ('gbuf_world_range')."""

        # Param range map.
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
        param_range_map = {}
        for param, param_world_indexes in param_world_index_map.items():

            # Param range.
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
                param_world_start - gbuf_world_range.start)
            param_local_end = min(
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)

            # Add param, if within local gbuf range.
            if param_local_end > param_local_start:
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                sub_param_start = max(0, gbuf_world_range.start - param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
                }

        return param_range_map
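
    # Illustrative sketch (not part of the original class): suppose this rank's
    # shard is gbuf_world_range = Range(100, 200) and a param occupies world
    # indices [180, 250). Then the entries above become:
    #   "gbuf_local" : Range(80, 100)   # slice of this rank's shard
    #   "gbuf_world" : Range(180, 200)  # same slice, in world coordinates
    #   "param"      : Range(0, 20)     # offset within the param itself
    # i.e. only the param's first 20 elements live on this rank.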

    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """Shard the model's grad buffer for 'dtype' across data-parallel
        ranks, and record this rank's shard along with its per-param
        sub-ranges."""

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges (i.e., across all data parallel ranks).
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            gbuf_world_end = min(gbuf_size, gbuf_world_start + max_gbuf_range_size)
            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_ranges.append(gbuf_world_range)

        # Local DP rank's range.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Altogether.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data
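
    # Illustrative sketch (not part of the original class): with gbuf_size == 10
    # and data_parallel_world_size == 4, max_gbuf_range_size == ceil(10/4) == 3,
    # giving per-rank world ranges:
    #   rank 0 -> Range(0, 3), rank 1 -> Range(3, 6),
    #   rank 2 -> Range(6, 9), rank 3 -> Range(9, 10)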

    @classmethod
    def build_model_gbuf_range_map(cls, model):
        return {
            dtype : cls.build_model_gbuf_range(model, dtype)
            for dtype in model._grad_buffers
        }

    @classmethod
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
        '''Create a reverse of the model_gbuf_ranges, for referencing in the
        opposite direction.'''
        param_gbuf_map = {}
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

    @classmethod
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
        """Collect, for each optimizer param group, the model params that
        appear in this rank's grad-buffer range maps."""

        num_groups = len(param_groups)

        # Param group map.
        param_group_map = {}
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
                param_group_map[param] = group_index

        # Optimizer group ranges.
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
                    group_index = param_group_map[param]
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)

        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
        group_ranges = [ g for g in group_ranges if len(g["params"]) > 0 ]

        return group_ranges

    @classmethod
    def build_model_and_main_param_groups(cls,
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """Create the main (fp32) param shards for each optimizer group, along
        with the model param shards and full model params they map to."""

        # Five groups of parameters:
        #   full_float16_groups: original, full-size float16 parameters
        #   full_fp32_groups: original, full-size fp32 parameters
        #   shard_float16_groups: this rank's float16 param shards (views)
        #   shard_fp32_groups: this rank's fp32 param shards (views)
        #   shard_fp32_from_float16_groups: fp32 copies of the float16 shards
        full_float16_groups = []
        full_fp32_groups = []
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

        # Allocate (or slice) each group's param shard.
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
            full_float16_params_this_group = []
            full_fp32_params_this_group = []
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
            full_float16_groups.append(full_float16_params_this_group)
            full_fp32_groups.append(full_fp32_params_this_group)
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]

                # fp16, bf16 params.
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param = shard_model_param.clone().float()
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_model_param, model_param)
                    mpu.copy_tensor_model_parallel_attributes(
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
                    full_float16_params_this_group.append(model_param)
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)

                # fp32 params.
                elif model_param.type() == 'torch.cuda.FloatTensor':
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    full_fp32_params_this_group.append(model_param)
                    shard_fp32_params_this_group.append(shard_model_param)

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor, '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
                                    'Received {}'.format(model_param.type()))

            # Update optimizer's params.
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
            full_float16_groups,
            full_fp32_groups,
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )

    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, grad_scaler, models):

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp

        # Model grad buffer ranges.
        self.model_gbuf_ranges = []
        for model_index, model in enumerate(self.models):
            self.model_gbuf_ranges.append(self.build_model_gbuf_range_map(model))
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.opt_group_ranges = self.build_optimizer_group_ranges(
            self.optimizer.param_groups,
            self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.full_float16_groups,
            self.full_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())
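
    # Illustrative sketch (not part of the original class): after construction,
    # self.optimizer only holds this rank's shard params, so one training
    # iteration drives this class roughly as follows:
    #   1. reduce_model_grads()   -> reduce-scatter grads into this rank's shard
    #   2. (inherited step)       -> unscale/clip and update the shard main params
    #   3. gather_model_params()  -> all-gather updated params back into the model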

    def get_model_param_range_map(self, param):
        """Given a model param, get the index sub-range of the param that this
        data-parallel rank owns."""
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map

    def get_model_parallel_group(self):
        """With the distributed optimizer, the model parallel group is the
        entire world."""
        return None

    # >>>
    # def state_dict(self):
    #     state_dict = {}
    #     state_dict['optimizer'] = self.optimizer.state_dict()
    #     if self.grad_scaler:
    #         state_dict['grad_scaler'] = self.grad_scaler.state_dict()
    #     state_dict['groups'] = [g['params'] for g in self.optimizer.param_groups]
    #     return state_dict
    def state_dict(self):
        raise NotImplementedError("distributed optimizer state_dict() is not "
                                  "implemented yet.")
    # <<<


    # >>>
    # def load_state_dict(self, state_dict):
    #     # Optimizer.
    #     optimizer_key = 'optimizer'
    #     if optimizer_key not in state_dict:
    #         optimizer_key = 'optimizer_state_dict'
    #         print_rank_0('***WARNING*** loading optimizer from '
    #                      'an old checkpoint ...')
    #     self.optimizer.load_state_dict(state_dict[optimizer_key])

    #     # Grad scaler.
    #     if 'grad_scaler' not in state_dict:
    #         print_rank_0('***WARNING*** found an old checkpoint, will not '
    #                      'load grad scaler ...')
    #     else:
    #         if self.grad_scaler:
    #             self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
    #         else:
    #             print_rank_0('***WARNING*** found the grad scaler in the '
    #                          'checkpoint but it is None in the class. '
    #                          'Skipping loading grad scaler ...')

    #     # Copy data for the main params.
    #     current_groups = [ g["params"] for g in self.optimizer.param_groups ]
    #     assert "groups" in state_dict, "key 'groups' not in state_dict."
    #     for current_group, saved_group in zip(current_groups, state_dict["groups"]):
    #         for current_param, saved_param in zip(current_group, saved_group):
    #             current_param.data.copy_(saved_param.data)
    def load_state_dict(self, state_dict):
        raise NotImplementedError("distributed optimizer load_state_dict() is "
                                  "not implemented yet.")
    # <<<

    def zero_grad(self, set_to_none=True):
        """We only need to zero the model related parameters, i.e.,
        full_float16_groups & full_fp32_groups. We additionally zero
        the shard groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by these fields can be safely deallocated at this point."""
        for groups in (
                self.full_float16_groups,
                self.full_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups,
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)

    def get_model_grad_buffer_dp_views(self):
        """View each model grad buffer as 'data_parallel_world_size' equal
        shards, one per data-parallel rank."""

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer views.
        gbuf_view_items = []
        for model_index, model in enumerate(self.models):
            for dtype, gbuf in model._grad_buffers.items():

                assert gbuf.numel_padded % data_parallel_world_size == 0
                shard_size = int(gbuf.numel_padded / data_parallel_world_size)
                gbuf_views = [gbuf.data[(r*shard_size):((r+1)*shard_size)]
                              for r in range(data_parallel_world_size)]
                gbuf_view_items.append((model_index, dtype, gbuf.data, gbuf_views))

        return gbuf_view_items
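
    # Illustrative sketch (not part of the original class): with two data
    # parallel ranks and a single fp16 grad buffer of numel_padded == 8, the
    # list above would hold one item, roughly
    #   (0, torch.float16, gbuf.data, [gbuf.data[0:4], gbuf.data[4:8]]),
    # so rank r reduce-scatters into / all-gathers from gbuf_views[r].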

    def reduce_model_grads(self, args, timers):
        '''Reduce-scatter model grads.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) all grads, 2) embedding
        grads.
        '''

        # All-reduce embedding grads.
        timers('backward-embedding-all-reduce').start()
        self.allreduce_embedding_grads(args)
        timers('backward-embedding-all-reduce').stop()

        # Reduce-scatter setup.
        timers('backward-params-all-reduce').start()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size'.
        for model in self.models:
            for dtype, gbuf in model._grad_buffers.items():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('backward-params-all-reduce').stop()

    def gather_model_params(self, args, timers):
        '''All-gather updated model params from all data-parallel ranks.'''

        timers('backward-params-all-gather').start()

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All grad buffer views are guaranteed to have the same number of
        #   elements across all data parallel ranks, due to the grad buffer
        #   padding done in distributed.py. Thus, all sub-views will have
        #   consistent start/end indexes across data parallel ranks.
        gbuf_view_items = self.get_model_grad_buffer_dp_views()
        for index, (model_index, dtype, gbuf, gbuf_views) in enumerate(gbuf_view_items):
            torch.distributed._all_gather_base(
                gbuf,
                gbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Each model param now contains its updated values in its
        # '.main_grad' field.
        for model in self.models:
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                for param in param_map:
                    param.detach().copy_(param.main_grad)

        timers('backward-params-all-gather').stop()
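
    # Illustrative sketch (not part of the original class): the all-gather above
    # is the inverse of the reduce-scatter in reduce_model_grads(). Assuming
    # 'buf' is a padded flat buffer and 'views' its per-rank shards:
    #   torch.distributed._reduce_scatter_base(views[rank], buf, group=group)
    #   ... this rank's updated main params are copied into views[rank] ...
    #   torch.distributed._all_gather_base(buf, views[rank], group=group)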

    def _collect_main_grad_data_for_unscaling(self):
        return [
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]

    def _get_model_and_main_params_data_float16(self):
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data

    def _copy_model_grads_to_main_grads(self):

        # Utility method for copying group grads.
        def copy_group_grads(full_model_groups, shard_main_groups):
            for full_model_group, shard_main_group in zip(full_model_groups,
                                                          shard_main_groups):
                for full_model_param, shard_main_param in zip(full_model_group,
                                                              shard_main_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.grad = shard_model_grad.float()

        # Copy model groups to shard groups.
        copy_group_grads(self.full_float16_groups,
                         self.shard_fp32_from_float16_groups)
        copy_group_grads(self.full_fp32_groups,
                         self.shard_fp32_groups)

    def _copy_main_params_to_model_params(self):

        # Utility method for copying group params.
        def copy_group_params(shard_main_groups, full_model_groups):
            for shard_main_group, full_model_group in zip(shard_main_groups,
                                                          full_model_groups):
                for shard_main_param, full_model_param in zip(shard_main_group,
                                                              full_model_group):

                    param_range_map = self.get_model_param_range_map(full_model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    # Note: the updated main param shard is written into the
                    # model's grad buffer (via '.main_grad'), which doubles as
                    # the communication buffer for gather_model_params().
                    full_model_grad = full_model_param.main_grad
                    shard_model_grad = full_model_grad.view(-1) \
                        [param_range.start:param_range.end]
                    shard_model_grad.data.copy_(shard_main_param)

        # Copy shard groups to model groups.
        copy_group_params(self.shard_fp32_from_float16_groups,
                          self.full_float16_groups)
        copy_group_params(self.shard_fp32_groups,
                          self.full_fp32_groups)