distrib_optimizer.py 43.2 KB
Newer Older
liangjing's avatar
v1  
liangjing committed
1
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Megatron distributed optimizer."""


liangjing's avatar
v1  
liangjing committed
6
from apex.optimizers import FusedAdam as Adam
7
import math
8
import torch
9
10

from megatron import get_args
11
from megatron import get_timers
Lawrence McAfee's avatar
Lawrence McAfee committed
12
from megatron import print_rank_0
13
from megatron.core import mpu, tensor_parallel
14
from megatron.model.module import param_is_not_shared
15
16

from .optimizer import MixedPrecisionOptimizer, _zero_grad_group_helper
17

Lawrence McAfee's avatar
Lawrence McAfee committed
18

19
class Range:
    """
    A range represents the start and end points for indexing a shard
    from a full tensor.

    The interval is half-open, i.e. `[start, end)`: elsewhere in this
    file the bounds are used directly as slice limits.

    Attributes:
        start: index of the first element covered by the range.
        end: index one past the last element covered by the range.
        size: number of covered elements (`end - start`), computed once
            at construction time.
    """

    def __init__(self, start, end):
        self.start = start
        self.end = end
        self.size = end - start

    def normalize(self, start = 0):
        """Return a new Range with the same size, shifted to begin at `start`."""
        return Range(start, start + self.size)

    def __str__(self):
        # Rendered as "<start>,<end> [<size>]".
        return "%d,%d [%d]" % (self.start, self.end, self.size)

    def __len__(self):
        # Recomputed from the bounds rather than read from `self.size`.
        return self.end - self.start
34
35


36
class DistributedOptimizer(MixedPrecisionOptimizer):
37
    """Distributed optimizer, for all data types (fp16, bf16, and fp32).
Lawrence McAfee's avatar
Lawrence McAfee committed
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63

    Arguments:
        optimizer: base optimizer such as Adam or SGD
        clip_grad: clip gradients with this global L2 norm. Note
            that clipping is ignored if clip_grad == 0
        log_num_zeros_in_grad: return number of zeros in the gradients.
        params_have_main_grad: flag indicating if parameters have
            a `main_grad` field. If this is set, we are assuming
            that the model gradients are stored in the `main_grad`
            field instead of the typical `grad` field. This happens
            for the DDP cases where there is a continuous buffer
            holding the gradients. For example for bfloat16, we want
            to do gradient accumulation and all-reduces in float32
            and as a result we store those gradients in the main_grad.
            Note that main grad is not necessarily in float32.
        use_contiguous_buffers_in_local_ddp: if true, the local DDP model
            is using a contiguous buffer to hold the model grads.
        fp16: if true, the model is running in fp16.
        bf16: if true, the model is running in bfloat16.
        grad_scaler: used for scaling gradients. Note that this can be
            None. This case happens when `bf16 = True` and we don't
            use any loss scale. Note that for `bf16 = True`, we can have
            a constant gradient scaler. Also for `bf16 = False`, we
            always require a grad scaler.
        models: list of models (i.e., the virtual pipelining models). This
            is used by the distributed optimizer for mapping parameters.
64
    """
65
66

    @classmethod
67
    def build_model_gbuf_param_range_map(cls, model, dtype, gbuf_world_range):
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
        """
        Build mapping from param reference to grad buffer shard ranges.

        This method builds a mapping from parameter references to grad
        buffer shard ranges, specific to each data-parallel (DP) rank's
        set of 'owned' parameters. Each grad buffer (padded to be an even
        multiple of DP-world-size) is conceptually divided into DP-world-size
        contiguous regions, where each DP rank 'owns' a contiguous regions.
        Ownership in this sense means DP rank is responsible for reducing
        the relevant subset of grads, and updating the relevant subset of
        params.

        This conceptual partitioning of the grad buffer does NOT respect
        parameter boundaries, and as such it is assumed that each created
        range references a shard (or subset) of the full parameter. It is
        easiest to think of each DP rank as operating (i.e., reducing,
        gathering) purely on views into the grad buffer, for all model-to-
        main & main-to-model operations.

        This method creates three ranges:
        - The param's range within the entire grad buffer (i.e., world index).
        - The param's range within the DP rank's local view of the grad buffer.
        - The param's range within itself (i.e., its shard).
        """
92

93
        # Param range map.
94
        param_world_index_map = model._grad_buffer_param_index_map[dtype]
95
        param_range_map = {}
96
97
        for param, param_world_indexes in param_world_index_map.items():

98
            # Param range.
99
100
101
            param_world_start, param_world_end = param_world_indexes
            param_local_start = max(
                0,
102
                param_world_start - gbuf_world_range.start)
103
            param_local_end = min(
104
105
                gbuf_world_range.size,
                param_world_end - gbuf_world_range.start)
106

107
            # Add param, if within local gbuf range.
108
            if param_local_end > param_local_start:
109
110
111
112
113
114
115
116
117
                param_local_range = Range(param_local_start, param_local_end)
                param_world_range = param_local_range.normalize(
                    param_local_start + gbuf_world_range.start)
                sub_param_start = max(0, gbuf_world_range.start-param_world_start)
                sub_param_range = param_local_range.normalize(sub_param_start)
                param_range_map[param] = {
                    "gbuf_world" : param_world_range,
                    "gbuf_local" : param_local_range,
                    "param" : sub_param_range,
118
119
                }

120
        return param_range_map
121

Lawrence McAfee's avatar
Lawrence McAfee committed
122

123
    @classmethod
    def build_model_gbuf_range(cls, model, dtype):
        """
        Build mapping between params and their grad buffers.

        This method does the initial setup for the method above. This setup
        includes determining the shard ranges into the DDP's grad buffer for
        each data-parallel (DP) rank. Each DP rank keeps range info for
        all other DP ranks, for the purpose of creating args for
        reduce-scatter and all-gather.

        Arguments:
            model: object exposing `_grad_buffers[dtype]` (with a `numel`
                attribute) and the per-dtype param index map consumed by
                `build_model_gbuf_param_range_map`.
            dtype: key selecting the grad buffer to partition.

        Returns:
            A dict with the entries:
            - "local": this rank's range, shifted to start at 0.
            - "world": this rank's range within the whole grad buffer.
            - "world_all": list of the world ranges of every DP rank.
            - "param_map": per-param ranges for params owned by this rank.
            - "max_range_size": ceil(buffer size / DP world size).
        """

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Grad buffer range.
        grad_buffer = model._grad_buffers[dtype]
        gbuf_size = grad_buffer.numel
        max_gbuf_range_size = int(math.ceil(gbuf_size / data_parallel_world_size))

        # All world ranges. (i.e., across all data parallel ranks)
        gbuf_world_all_ranges = []
        for r in range(data_parallel_world_size):
            gbuf_world_start = r * max_gbuf_range_size
            # The last range(s) are truncated at the end of the buffer.
            gbuf_world_end = min(gbuf_size, gbuf_world_start+max_gbuf_range_size)
            gbuf_world_range = Range(gbuf_world_start, gbuf_world_end)
            gbuf_world_all_ranges.append(gbuf_world_range)

        # Local DP's ranges.
        gbuf_world_range = gbuf_world_all_ranges[data_parallel_rank]
        gbuf_local_range = gbuf_world_range.normalize()

        # Get each param's ranges.
        param_range_map = cls.build_model_gbuf_param_range_map(model,
                                                               dtype,
                                                               gbuf_world_range)

        # Group into dict.
        data = {
            "local" : gbuf_local_range,
            "world" : gbuf_world_range,
            "world_all" : gbuf_world_all_ranges,
            "param_map" : param_range_map,
            "max_range_size" : max_gbuf_range_size,
        }

        return data

Lawrence McAfee's avatar
Lawrence McAfee committed
171

172
    @classmethod
173
    def build_model_gbuf_range_map(cls, model):
174
175
176
177
        """
        Create param-to-grad-buffer mappings, for grad buffer data types
        within a specific virtual model.
        """
178
        return {
179
            dtype : cls.build_model_gbuf_range(model, dtype)
180
181
182
            for dtype in model._grad_buffers
        }

Lawrence McAfee's avatar
Lawrence McAfee committed
183

184
    @classmethod
185
    def build_model_param_gbuf_map(cls, model_gbuf_ranges):
186
187
188
189
        """
        Create a reverse of the model_gbuf_ranges, for referencing in
        opposite direction.
        """
190
        param_gbuf_map = {}
191
192
193
        for model_index, model_gbuf_range_map in enumerate(model_gbuf_ranges):
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param, param_range_map in gbuf_range_map["param_map"].items():
194
195
196
                    param_gbuf_map[param] = (model_index, dtype)
        return param_gbuf_map

Lawrence McAfee's avatar
Lawrence McAfee committed
197

198
    @classmethod
199
    def build_optimizer_group_ranges(cls, param_groups, model_gbuf_ranges):
200
201
202
203
204
205
206
207
        """
        Create optimizer groups.

        Given the set of parameter shard ranges that are owned by the current
        data-parallel (DP) rank, gather the set of parameters that will be
        used (in the method below) to create the current DP's optimizer
        groups.
        """
208
209
210
211

        num_groups = len(param_groups)

        # Param group map.
liangjing's avatar
v1  
liangjing committed
212
213
214
215
216
217
218
        # World param group map.
        # - Store a mapping of <model_parameter:group_index> for all parameters
        #   across all DP ranks. This is necessary because it is our first
        #   cross reference between the DDP mappings and the optimizer group
        #   parameters. This mapping only for use in the next step of building
        #   the local mapping over this DP rank's parameters.
        world_param_group_map = {}
219
220
221
        for group_index, group in enumerate(param_groups):
            for param in group["params"]:
                assert param.requires_grad
liangjing's avatar
v1  
liangjing committed
222
223
224
225
226
227
228
229
                world_param_group_map[param] = group_index

        # Optimizer group ranges & param-group mapping.
        # - Build a mapping from groups to their contained parameters, and also
        #   from parameters to their containing group index and order within
        #   the group. The group index and order are particularly important for
        #   saving and loading checkpoints.
        local_param_group_map = {}
230
231
232
233
        group_ranges = [ {"params": []} for _ in param_groups ]
        for model_gbuf_range_map in model_gbuf_ranges:
            for dtype, gbuf_range_map in model_gbuf_range_map.items():
                for param in gbuf_range_map["param_map"]:
liangjing's avatar
v1  
liangjing committed
234
                    group_index = world_param_group_map[param]
235
236
                    group_range = group_ranges[group_index]
                    group_range["params"].append(param)
liangjing's avatar
v1  
liangjing committed
237
238
                    local_param_group_map[param] = \
                        (group_index, len(group_range["params"]) - 1)
239

240
241
242
        # Squeeze zero-size group ranges.
        for group_index, group_range in enumerate(group_ranges):
            group_range["orig_group"] = param_groups[group_index]
liangjing's avatar
v1  
liangjing committed
243
            group_range["orig_group_idx"] = param_groups[group_index]
244

liangjing's avatar
v1  
liangjing committed
245
        return local_param_group_map, group_ranges
246

247

248
    @classmethod
249
    def build_model_and_main_param_groups(cls,
250
251
252
253
254
255
256
257
258
259
260
261
262
                                          model_gbuf_ranges,
                                          param_gbuf_map,
                                          opt_group_ranges):
        """
        Create main parameter groups needed for the optimizer step.

        These groups encompass both: 1) groups used by this class, for
        reducing/gather, and 2) groups used by the inner optimizer for the
        parameter update. Given that the conceptual grad buffer partitioning
        (created in earlier method) doesn't respect parameter boundaries,
        the optimizer operates on shards of the model parameters, rather than
        the full parameters.
        """
263

Lawrence McAfee's avatar
Lawrence McAfee committed
264
265
266
267
268
269
270
271
        # Parameter groups:
        #   model_float16_groups: original float16 parameters
        #   model_fp32_groups: original fp32 parameters
        #   shard_float16_groups: shards of original float16 parameters
        #   shard_fp32_groups: shards of original fp32 parameters
        #   shard_fp32_from_float16_groups: fp32 copy of float16 parameters
        model_float16_groups = []
        model_fp32_groups = []
272
273
274
275
        shard_float16_groups = []
        shard_fp32_groups = []
        shard_fp32_from_float16_groups = []

Lawrence McAfee's avatar
Lawrence McAfee committed
276
        # Allocate (or slice) each group's param shard.
277
278
279
        for group_index, group_range in enumerate(opt_group_ranges):

            # Params of this group.
Lawrence McAfee's avatar
Lawrence McAfee committed
280
281
            model_float16_params_this_group = []
            model_fp32_params_this_group = []
282
283
284
            shard_float16_params_this_group = []
            shard_fp32_params_this_group = []
            shard_fp32_from_float16_params_this_group = []
Lawrence McAfee's avatar
Lawrence McAfee committed
285
286
            model_float16_groups.append(model_float16_params_this_group)
            model_fp32_groups.append(model_fp32_params_this_group)
287
288
289
290
291
292
293
            shard_float16_groups.append(shard_float16_params_this_group)
            shard_fp32_groups.append(shard_fp32_params_this_group)
            shard_fp32_from_float16_groups.append(
                shard_fp32_from_float16_params_this_group)

            for model_param in group_range["params"]:

294
295
                assert model_param.requires_grad

296
297
298
                model_index, dtype = param_gbuf_map[model_param]
                gbuf_range = model_gbuf_ranges[model_index][dtype]
                param_range = gbuf_range["param_map"][model_param]["param"]
299
300

                # fp16, bf16 params.
301
302
303
304
                if model_param.type() in ['torch.cuda.HalfTensor',
                                          'torch.cuda.BFloat16Tensor']:

                    # Clone model -> main.
Lawrence McAfee's avatar
Lawrence McAfee committed
305
306
                    shard_model_param = model_param.detach().view(-1) \
                        [param_range.start:param_range.end]
307
                    shard_main_param = shard_model_param.clone().float()
308
                    tensor_parallel.copy_tensor_model_parallel_attributes(
309
                        shard_model_param, model_param)
310
                    tensor_parallel.copy_tensor_model_parallel_attributes(
311
312
313
314
315
316
                        shard_main_param, model_param)
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
                        shard_main_param.shared = model_param.shared

                    # Add to group.
Lawrence McAfee's avatar
Lawrence McAfee committed
317
                    model_float16_params_this_group.append(model_param)
318
319
                    shard_float16_params_this_group.append(shard_model_param)
                    shard_fp32_from_float16_params_this_group.append(shard_main_param)
320
321

                # fp32 params.
322
                elif model_param.type() == 'torch.cuda.FloatTensor':
Lawrence McAfee's avatar
Lawrence McAfee committed
323
324
                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
325
                    model_fp32_params_this_group.append(model_param)
326
                    shard_fp32_params_this_group.append(shard_model_param)
327
                    tensor_parallel.copy_tensor_model_parallel_attributes(
328
                        shard_model_param, model_param)
329
330
                    if hasattr(model_param, 'shared'):
                        shard_model_param.shared = model_param.shared
331
332
333
334
335
336

                else:
                    raise TypeError('Wrapped parameters must be one of '
                                    'torch.cuda.FloatTensor,  '
                                    'torch.cuda.HalfTensor, or '
                                    'torch.cuda.BFloat16Tensor. '
liangjing's avatar
v1  
liangjing committed
337
                                    'Received {}'.format(model_param.type()))
338

Lawrence McAfee's avatar
Lawrence McAfee committed
339
            # Update optimizer's params.
340
341
342
343
344
345
            group_range["orig_group"]["params"] = [
                *shard_fp32_params_this_group,
                *shard_fp32_from_float16_params_this_group,
            ]

        return (
Lawrence McAfee's avatar
Lawrence McAfee committed
346
347
            model_float16_groups,
            model_fp32_groups,
348
349
350
351
            shard_float16_groups,
            shard_fp32_groups,
            shard_fp32_from_float16_groups,
        )
352

Lawrence McAfee's avatar
Lawrence McAfee committed
353

354
355
    def __init__(self, optimizer, clip_grad, log_num_zeros_in_grad,
                 params_have_main_grad, use_contiguous_buffers_in_local_ddp,
                 fp16, bf16, params_dtype, grad_scaler, models):
        """
        See top of class definition for argument descriptions.

        The steps in this method create the core mapping between DDP grad
        buffers, parameters, and parameter shard ranges, that is needed for
        converting between model param indexes and main parameter shard
        indexes. This method also updates the optimizer parameter groups
        with the newly created shards.
        """

        super().__init__(
            optimizer, clip_grad, log_num_zeros_in_grad,
            params_have_main_grad, use_contiguous_buffers_in_local_ddp,
            fp16, bf16, params_dtype, grad_scaler, models)

        # Verify that contiguous buffers are being used.
        # - Note: this should already be checked in arguments.py.
        assert use_contiguous_buffers_in_local_ddp
        assert isinstance(optimizer, Adam), \
            "Only Adam currently supported, due to checkpointing requirements."

        # Model grad buffer ranges.
        # NOTE(review): `self.models` / `self.optimizer` are presumably set
        # by the parent constructor from `models` / `optimizer` — confirm.
        self.model_gbuf_ranges = [
            self.build_model_gbuf_range_map(model)
            for model in self.models
        ]
        self.model_param_gbuf_map = \
            self.build_model_param_gbuf_map(self.model_gbuf_ranges)

        # Optimizer ranges.
        self.model_param_group_index_map, self.opt_group_ranges = \
            self.build_optimizer_group_ranges(self.optimizer.param_groups,
                                              self.model_gbuf_ranges)

        # Allocate main param shards.
        (
            self.model_float16_groups,
            self.model_fp32_groups,
            self.shard_float16_groups,
            self.shard_fp32_groups,
            self.shard_fp32_from_float16_groups,
        ) = self.build_model_and_main_param_groups(self.model_gbuf_ranges,
                                                   self.model_param_gbuf_map,
                                                   self.opt_group_ranges)

        # Initialize param buffers.
        # - These are views on the DDP model's grad buffers, that share
        #   storage & have their own dtype. This is safe because the param
        #   dtype size is always <= grad dtype size.
        self.param_buffers = []
        for model in self.models:
            current_param_buffers = {}
            for dtype, grad_buffer in model._grad_buffers.items():

                # Handle older/newer method for getting untyped storage.
                # Only a missing `_untyped` attribute should trigger the
                # fallback; any other error (or KeyboardInterrupt) must
                # propagate instead of being swallowed.
                try:
                    storage = grad_buffer.data.storage()._untyped()
                except AttributeError:
                    storage = grad_buffer.data.storage().untyped()

                # Typed param buffer.
                param_buffer = torch.tensor(
                    storage,
                    dtype = params_dtype,
                    device = grad_buffer.data.device)
                param_buffer = param_buffer[:grad_buffer.numel_padded]
                current_param_buffers[dtype] = param_buffer
            self.param_buffers.append(current_param_buffers)

        # Update optimizer groups.
        # - Also, leverage state_dict() and load_state_dict() to
        #   recast preexisting per-param state tensors.
        self.optimizer.param_groups = \
            [ g["orig_group"] for g in self.opt_group_ranges ]
        self.optimizer.load_state_dict(self.optimizer.state_dict())

432

433
    def get_model_param_range_map(self, param):
434
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
435
436
        Given a model param, get the index sub-range of the param that this
        data-parallel rank owns.
437
        """
438
439
440
441
442
443
        model_index, dtype = self.model_param_gbuf_map[param]
        gbuf_range_map = self.model_gbuf_ranges[model_index][dtype]
        param_range_map = gbuf_range_map["param_map"][param]
        return param_range_map


444
    def get_model_parallel_group(self):
445
446
447
448
        """
        With the distributed optimizer, the model parallel group is the
        entire world.
        """
449
450
        return None

451
452

    def state_dict(self):
453
        """
liangjing's avatar
v1  
liangjing committed
454
455
456
457
458
        The state dict contains all non-DP-rank-dependent (i.e., non-parameter-
        related) optimizer variables. The returned state dict can be stored in
        the standard model/RNG checkpoint file. The parameter and dependent
        optimizer state (e.g., exp_avg, exp_avg_sq) are stored in a separate
        checkpoint file by calling 'save_parameter_state()'.
459
        """
liangjing's avatar
v1  
liangjing committed
460

461
        state_dict = {}
liangjing's avatar
v1  
liangjing committed
462
463
464
465
466
467
468
469
470
471
472

        # Optimizer state (do not store parameter state here).
        state_dict['optimizer'] = {
            k : v
            for k, v in self.optimizer.state_dict().items()
            if k != "state"
        }
        for param_group in state_dict["optimizer"]["param_groups"]:
            del param_group["params"]

        # Grad scaler state.
473
474
        if self.grad_scaler:
            state_dict['grad_scaler'] = self.grad_scaler.state_dict()
liangjing's avatar
v1  
liangjing committed
475

476
477
478
        return state_dict


479
    def load_state_dict(self, state_dict):
liangjing's avatar
v1  
liangjing committed
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
        """Load the state dict.

        As detailed in state_dict(), the state dict contains all non-
        parameter-related variables. This method is notably longer than
        state_dict(), because the Torch optimizers state has yet to be
        allocated at this point, and so we must do a cross referencing between
        the optimizers state (and the ordering it expects for parameter state)
        and this DP rank's shards. The optimizer at this point does not contain
        any tensor dimension information, so we must get these dimensions from
        the DP shards mapped during DistributedOptimizer.__init__().

        The tensor parameter state is loaded via load_parameter_state(), and
        so this method also must populate the loaded state dict with dummy
        tensor data (i.e., via torch.empty() below). This will be overwritten
        during load_parameter_state().

        ** Note: Torch optimizer's state structure. **
        The Torch optimizer stores its state in two levels. The top level is a
        list of groups, where each group contains a list of integer indexes
        (corresponding to parameters) that index into a master parameter list
        that is shared by all groups. As such, three values are necessary for
        maintaining this ordering:

        - group_index : The group to which a parameter belongs.
        - group_order : The index of a parameter within its group.
        - state_order : The index of a parameter within the shared parameter
            list.
507
        """
508

liangjing's avatar
v1  
liangjing committed
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
        # Get the Torch optimizer's state dict.
        # - This 'inner' optimizer at this point is unallocated, and only
        #   contains an integer odering of parameters within each group, and
        #   the ordering of parameters within its flattened parameter state
        #   list.
        inner_state_dict = self.optimizer.state_dict()
        state_dict_param_groups = [{
            **group,
            "params" : list(inner_state_dict["param_groups"][idx]["params"]),
        } for idx, group in enumerate(state_dict["optimizer"]["param_groups"])]

        # Allocate 'dummy' data for optimizer state (i.e., torch.empty() below)
        # - Real data is overwritten during load_parameter_state().
        state_dict_state = []
        for gbuf_range_maps in self.model_gbuf_ranges:
            for gbuf_range_map in gbuf_range_maps.values():
                for model_param, param_range_map in \
                    gbuf_range_map["param_map"].items():

                    # Get parameter ordering information (see method docstring
                    # for details).
                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    state_order = inner_state_dict["param_groups"] \
                        [group_index]["params"][group_order]

                    # Allocate dummy tensors.
                    numel = len(param_range_map["gbuf_world"])
                    init_shard = lambda : torch.empty(
                        (numel,),
                        dtype=torch.float32,
                        device=torch.cuda.current_device())

                    state_dict_state.append((state_order, {
                        "exp_avg" : init_shard(),
                        "exp_avg_sq" : init_shard(),
                    }))

        # Sort by state order (see method docstring for details).
        state_dict_state.sort(key = lambda s : s[0])
        state_dict_state = {s[0]:s[1] for s in state_dict_state}

551
        # Optimizer.
liangjing's avatar
v1  
liangjing committed
552
553
554
555
        self.optimizer.load_state_dict({
            "state" : state_dict_state,
            "param_groups" : state_dict_param_groups,
        })
556
557
558

        # Grad scaler.
        if 'grad_scaler' not in state_dict:
559
560
561
            if self.fp16:
                print_rank_0('***WARNING*** found an old checkpoint, will not '
                             'load grad scaler ...')
562
563
564
565
566
567
568
569
        else:
            if self.grad_scaler:
                self.grad_scaler.load_state_dict(state_dict['grad_scaler'])
            else:
                print_rank_0('***WARNING*** fould the grad scaler in the '
                             'checkpoint but it is None in the class. '
                             'Skipping loading grad scaler ...')

liangjing's avatar
v1  
liangjing committed
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743

    def save_parameter_state(self, filename):
        """Save parameter state (i.e., parameter & optimizer tensors).

        The work happens in three stages:
        - Every DP rank packs its main-param, exp_avg and exp_avg_sq shards
          into one contiguous fp32 CPU buffer per kind.
        - Those buffers are gathered (over the gloo group) onto the first
          rank of the data-parallel group and concatenated into world
          buffers.
        - DP rank 0 writes the collected world buffers to `filename`.
        """

        # Data-parallel topology.
        dp_world_size = mpu.get_data_parallel_world_size()
        dp_rank = mpu.get_data_parallel_rank()
        dp_group_gloo = mpu.get_data_parallel_group_gloo()
        dp_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS)
        is_dp_root = dp_rank == 0

        state = {}
        for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):

            assert len(gbuf_range_maps) == 1, "single dtype supported, for now."

            # One entry per grad buffer data type.
            dtype_state = {}
            for dtype, gbuf_range_map in gbuf_range_maps.items():

                # Per-rank shard length: the padded buffer split evenly
                # across the DP ranks.
                padded_numel = self.models[model_idx] \
                    ._grad_buffers[dtype].numel_padded
                shard_numel = int(padded_numel/dp_world_size)

                def new_cpu_shard():
                    return torch.empty((shard_numel,),
                                       dtype=torch.float32,
                                       device="cpu")

                local_shards = {}
                for key in ("param", "exp_avg", "exp_avg_sq"):
                    local_shards[key] = new_cpu_shard()

                # Pack every locally owned param (and its optimizer state)
                # into the contiguous shards at its local buffer offset.
                for model_param, param_range_map in \
                    gbuf_range_map["param_map"].items():

                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    group_params = \
                        self.optimizer.param_groups[group_index]["params"]
                    main_param = group_params[group_order]

                    tensors = {
                        "param" : main_param,
                        **self.optimizer.state[main_param],
                    }

                    local_range = param_range_map["gbuf_local"]
                    for key, shard in local_shards.items():
                        shard[local_range.start:local_range.end] \
                            .data.copy_(tensors[key].detach().cpu())

                # Gather the packed shards, one key at a time; only the
                # DP root supplies receive buffers and keeps the result.
                world_tensors = {}
                for key, send_tensor in local_shards.items():

                    recv_tensors = None
                    if is_dp_root:
                        recv_tensors = [
                            new_cpu_shard() for _ in range(dp_world_size)
                        ]

                    torch.distributed.gather(
                        send_tensor,
                        recv_tensors,
                        dp_global_ranks[0],
                        dp_group_gloo,
                    )

                    if is_dp_root:
                        world_tensors[key] = torch.cat(recv_tensors)

                dtype_state[dtype] = world_tensors

            state[model_idx] = dtype_state

        # Only the DP root holds the gathered tensors, so only it saves.
        if is_dp_root:
            torch.save(state, filename)


    def load_parameter_state(self, filename):
        """Load parameter state (i.e., parameter & optimizer tensors).

        This method performs the reverse of save_parameter_state():
        - Load world buffers from disk (i.e., distrib_opt.pt).
        - Scatter contiguous buffers from DP rank 0 to each DP rank (each DP
          rank receives its relevant subset of the world buffers).
        - For each DP rank, copy param & optimizer shards from contiguous CPU
          buffers. (e.g., one buffer each for main_param, exp_avg, and
          exp_avg_sq).

        Arguments:
            filename: path passed to torch.load(); only read on data-parallel
                rank 0.
        """

        # Data parallelism variables.
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group_gloo = mpu.get_data_parallel_group_gloo()
        data_parallel_global_ranks = list(mpu._DATA_PARALLEL_GLOBAL_RANKS)

        # Load on DP rank 0 only; every other rank receives its shards through
        # the scatter below and never touches 'loaded_state'.
        # NOTE(review): torch.load() unpickles the file, so only trusted
        # checkpoints should be passed in here.
        if data_parallel_rank == 0:
            loaded_state = torch.load(filename)

        # Scatter tensors to all DP ranks.
        for model_idx, gbuf_range_maps in enumerate(self.model_gbuf_ranges):
            model = self.models[model_idx]
            for dtype, gbuf_range_map in gbuf_range_maps.items():

                # Compute local DP contiguous shard's size. Integer division
                # keeps the result exact (no round trip through float).
                gbuf_world_numel = model._grad_buffers[dtype].numel_padded
                gbuf_local_numel = gbuf_world_numel // data_parallel_world_size

                # Contiguous local shards (received from DP rank 0).
                local_shards = {key:torch.empty((gbuf_local_numel,),
                                                dtype=torch.float32,
                                                device="cpu")
                                for key in ("param", "exp_avg", "exp_avg_sq")}

                # Scatter local shards from DP rank 0.
                for key, recv_tensor in local_shards.items():

                    # Scatter tensor list (only the source rank provides one).
                    if data_parallel_rank == 0:
                        world_tensor = loaded_state[model_idx][dtype][key]
                        gbuf_start_idxs = \
                            range(0, gbuf_world_numel, gbuf_local_numel)
                        send_tensors = [world_tensor[i:(i+gbuf_local_numel)]
                                        for i in gbuf_start_idxs]
                    else:
                        send_tensors = None

                    # Scatter.
                    torch.distributed.scatter(
                        recv_tensor,
                        send_tensors,
                        data_parallel_global_ranks[0],
                        data_parallel_group_gloo,
                    )

                # Copy local contiguous shards to param/optim shards.
                for model_param, param_range_map in \
                    gbuf_range_map["param_map"].items():

                    # Main param & optimizer states.
                    group_index, group_order = \
                        self.model_param_group_index_map[model_param]
                    main_param = self.optimizer.param_groups \
                        [group_index]["params"][group_order]
                    optim_state = self.optimizer.state[main_param]

                    tensors = {
                        "param" : main_param,
                        **optim_state,
                    }

                    # Copy contiguous shard slices back into the states.
                    gbuf_local_start = param_range_map["gbuf_local"].start
                    gbuf_local_end = param_range_map["gbuf_local"].end
                    for key in local_shards:
                        tensors[key].data.copy_(
                            local_shards[key][gbuf_local_start:gbuf_local_end])
Lawrence McAfee's avatar
Lawrence McAfee committed
744

745

746
    def zero_grad(self, set_to_none=True):
        """
        Zero grads.

        We only need to zero the model related parameters, i.e.,
        model_float16_groups & model_fp32_groups. We additionally zero
        the remaining groups as a memory optimization to reduce
        fragmentation; in the case of set_to_none==True, the space
        used by this field can be safely deallocated at this point.

        Arguments:
            set_to_none: forwarded unchanged to _zero_grad_group_helper for
                every group.
        """
        for groups in (
                self.model_float16_groups,
                self.model_fp32_groups,
                self.shard_float16_groups, # grad empty/unused here?
                self.shard_fp32_groups, # throws grad-access warning
                self.shard_fp32_from_float16_groups):
            for group in groups:
                _zero_grad_group_helper(group, set_to_none)
764

765

766
767
    @staticmethod
    def get_model_buffer_dp_views(model_buffers):
        """
        Get shard views of each of the DDP's param/grad buffers.

        In this nested list, the top level is grouped by the virtual model
        index and the buffer's data type. The sub-level is a list of
        shards of that buffer, where each shard in the list represents
        a contiguous view of the buffer, that is owned by a data-parallel
        rank. The shard boundary does not respect parameter boundaries, and
        so the elements of some parameters are split across data parallel
        ranks.

        Additionally, return references to the entire buffers, for use
        in _reduce_scatter_base and _all_gather_base.

        Arguments:
            model_buffers: iterable of {dtype: buffer} dicts; each buffer
                must have a numel() divisible by the data-parallel world
                size.

        Returns:
            List of (model_index, dtype, buf, buf_views) tuples, where
            buf_views holds one equally sized slice of buf per data-parallel
            rank.
        """

        data_parallel_world_size = mpu.get_data_parallel_world_size()

        # Buffer views.
        view_items = []
        for model_index, buffers in enumerate(model_buffers):
            for dtype, buf in buffers.items():

                assert buf.numel() % data_parallel_world_size == 0
                # Exact integer shard size (divisibility asserted above).
                shard_size = buf.numel() // data_parallel_world_size
                buf_views = [buf[(r*shard_size):((r+1)*shard_size)]
                             for r in range(data_parallel_world_size)]
                view_items.append((model_index, dtype, buf, buf_views))

        return view_items
797

798
799
800
801
802
803
804
805
806
807

    def get_model_grad_buffer_dp_views(self):
        """
        Get data-parallel shard views of every model's grad buffers.

        One {dtype: buffer} dict is built per model (rather than one dict
        per (model, dtype) pair), so that the index assigned by
        get_model_buffer_dp_views() is the model's index even when a model
        holds grad buffers of more than one dtype. The order of the
        resulting views (by model, then by dtype) is unchanged.
        """
        return self.get_model_buffer_dp_views([
            {dtype : mem_buffer.data
             for dtype, mem_buffer in model._grad_buffers.items()}
            for model in self.models])


    def get_model_param_buffer_dp_views(self):
        """
        Get data-parallel shard views of this optimizer's param buffers.

        Thin wrapper that hands self.param_buffers to
        get_model_buffer_dp_views() and returns its result unchanged.
        """
        param_buffers = self.param_buffers
        views = self.get_model_buffer_dp_views(param_buffers)
        return views
808

Lawrence McAfee's avatar
Lawrence McAfee committed
809

810
    def reduce_model_grads(self, args, timers):
        """
        Reduce-scatter model grads.

        The DDP's grad buffer is used for the reduce-scatter, and thus no
        tensors are dynamically allocated.

        Note: this is a different order of reduction, versus the non-
        distributed optimizer, which reduces: 1) layernorm grads, 2) all
        grads, 3) embedding grads.

        Arguments:
            args: object whose 'barrier_with_L1_time' attribute is passed to
                each timer's start(); also forwarded to the layernorm and
                embedding all-reduce methods.
            timers: callable returning a named timer with start()/stop().
        """

        # All-reduce layer-norm grads (for sequence parallelism).
        timers('layernorm-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_layernorm_grads(args)
        timers('layernorm-grads-all-reduce').stop()

        # All-reduce embedding grads.
        timers('embedding-grads-all-reduce', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        self.allreduce_embedding_grads(args)
        timers('embedding-grads-all-reduce').stop()

        # Reduce-scatter setup.
        timers('grads-reduce-scatter', log_level=1).start(
            barrier=args.barrier_with_L1_time)
        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_world_size = mpu.get_data_parallel_world_size()
        data_parallel_group = mpu.get_data_parallel_group()

        # Scale grad buffers by '1 / data_parallel_world_size'.
        for model in self.models:
            for gbuf in model._grad_buffers.values():
                gbuf.data /= data_parallel_world_size

        # Reduce-scatter all grads. Each rank's output is its own view into
        # the full buffer, which is also the input.
        for _, _, gbuf, gbuf_views in self.get_model_grad_buffer_dp_views():
            torch.distributed._reduce_scatter_base(
                gbuf_views[data_parallel_rank],
                gbuf,
                group = data_parallel_group,
            )

        timers('grads-reduce-scatter').stop()
858

Lawrence McAfee's avatar
Lawrence McAfee committed
859

liangjing's avatar
v1  
liangjing committed
860

861
    def gather_model_params(self, args, timers):
        """
        All-gather updated model params.

        The DDP's param buffer is used for the all-gather, and thus no
        tensors are dynamically allocated. After the all-gather, the params
        can be copied from the param buffer to the param.

        Arguments:
            args: object whose 'barrier_with_L1_time' attribute is passed to
                the timer's start().
            timers: callable returning a named timer with start()/stop().
        """

        timers('params-all-gather', log_level=1).start(
            barrier=args.barrier_with_L1_time)

        data_parallel_rank = mpu.get_data_parallel_rank()
        data_parallel_group = mpu.get_data_parallel_group()

        # All-gather updated main params.
        # - All param buffer views are guaranteed to have the same num elements
        #   across all data parallel ranks, due to grad buffer padding that is
        #   done in distributed.py, and extended to the param buffers. Thus,
        #   all sub-views will have consistent start/end indexes across data
        #   parallel ranks.
        for _, _, pbuf, pbuf_views in self.get_model_param_buffer_dp_views():
            torch.distributed._all_gather_base(
                pbuf,
                pbuf_views[data_parallel_rank],
                group = data_parallel_group,
            )

        # Copy from param buffer to each param.
        for model_id, model in enumerate(self.models):
            for dtype, param_map in model._grad_buffer_param_index_map.items():
                # The buffer only depends on (model, dtype); look it up once.
                param_buf = self.param_buffers[model_id][dtype]
                for param, (buf_start, buf_end) in param_map.items():
                    param_buf_shard = param_buf[buf_start:buf_end]
                    param.view(-1).detach().copy_(param_buf_shard)

        timers('params-all-gather').stop()
901

Lawrence McAfee's avatar
Lawrence McAfee committed
902

903
    def _collect_main_grad_data_for_unscaling(self):
904
905
906
907
        """
        Note: this should be equivalent to the float-16 optimizer's method,
        but writtent differently, so the two should be combined.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
908
        return [
Lawrence McAfee's avatar
Lawrence McAfee committed
909
910
911
912
            param.grad.data
            for group in self.optimizer.param_groups
            for param in group["params"]
        ]
913
914


Lawrence McAfee's avatar
Lawrence McAfee committed
915
    def _get_model_and_main_params_data_float16(self):
916
917
918
        """
        Get aligned list of model and main params.
        """
Lawrence McAfee's avatar
Lawrence McAfee committed
919
920
921
922
923
924
925
926
        model_data = []
        main_data = []
        for model_group, main_group in zip(self.shard_float16_groups,
                                           self.shard_fp32_from_float16_groups):
            for model_param, main_param in zip(model_group, main_group):
                model_data.append(model_param.data)
                main_data.append(main_param.data)
        return model_data, main_data
927
928


929
    def _copy_model_grads_to_main_grads(self):
930
931
932
933
934
935
936
        """
        Copy model grads to main grads.

        Since this step follows a reduce-scatter through the DDP's grad
        buffer, this method is responsible for copying the updated grads
        from the grad buffer to the main shard's grad field.
        """
937

938
        # Utility method for copying group grads.
Lawrence McAfee's avatar
Lawrence McAfee committed
939
940
941
942
943
        def copy_group_grads(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
944

Lawrence McAfee's avatar
Lawrence McAfee committed
945
                    param_range_map = self.get_model_param_range_map(model_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
946
                    param_range = param_range_map["param"]
Lawrence McAfee's avatar
Lawrence McAfee committed
947
948
                    assert param_range.size == shard_main_param.nelement()

Lawrence McAfee's avatar
Lawrence McAfee committed
949
950
                    model_grad = model_param.main_grad
                    shard_model_grad = model_grad.view(-1) \
Lawrence McAfee's avatar
Lawrence McAfee committed
951
                        [param_range.start:param_range.end]
Lawrence McAfee's avatar
Lawrence McAfee committed
952
953
                    shard_main_param.grad = shard_model_grad.float()

954
        # Copy model groups to shard groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
955
        copy_group_grads(self.model_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
956
                         self.shard_fp32_from_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
957
        copy_group_grads(self.model_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
958
                         self.shard_fp32_groups)
959

960
961

    def _copy_main_params_to_model_params(self):
962
963
964
965
966
967
968
        """
        Copy main params to model params.

        Since this step is followed by an all-gather through the DDP's grad
        buffer, this method is responsible for copying the updated params
        from the main shards into the correct position in the grad buffer.
        """
969

970
        # Utility method for copying group params.
Lawrence McAfee's avatar
Lawrence McAfee committed
971
972
973
974
975
        def copy_group_params(shard_main_groups, model_groups):
            for shard_main_group, model_group in zip(shard_main_groups,
                                                     model_groups):
                for shard_main_param, model_param in zip(shard_main_group,
                                                         model_group):
Lawrence McAfee's avatar
Lawrence McAfee committed
976

Lawrence McAfee's avatar
Lawrence McAfee committed
977
                    param_range_map = self.get_model_param_range_map(model_param)
978
                    world_range = param_range_map["gbuf_world"]
Lawrence McAfee's avatar
Lawrence McAfee committed
979

980
981
982
983
984
985
986
                    assert world_range.size == shard_main_param.nelement()

                    model_id, dtype = self.model_param_gbuf_map[model_param]
                    model_param_buffer = self.param_buffers[model_id][dtype]

                    shard_model_param = model_param_buffer.view(-1) \
                        [world_range.start:world_range.end]
987

988
                    shard_model_param.data.copy_(shard_main_param)
Lawrence McAfee's avatar
Lawrence McAfee committed
989

990
        # Copy shard groups to model groups.
Lawrence McAfee's avatar
Lawrence McAfee committed
991
        copy_group_params(self.shard_fp32_from_float16_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
992
                          self.model_float16_groups)
Lawrence McAfee's avatar
Lawrence McAfee committed
993
        copy_group_params(self.shard_fp32_groups,
Lawrence McAfee's avatar
Lawrence McAfee committed
994
                          self.model_fp32_groups)
liangjing's avatar
v1  
liangjing committed
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025


    def _copy_model_params_to_main_params(self):
        """
        Copy model params to main params.

        During finetuning, this method is used to reload the main params from
        the model params. This copy does not make use of the grad buffer as
        an intermediary.
        """

        # Utility method for copying group params.
        def copy_group_params(model_groups, shard_main_groups):
            for model_group, shard_main_group in zip(model_groups,
                                                     shard_main_groups):
                for model_param, shard_main_param in zip(model_group,
                                                         shard_main_group):

                    param_range_map = self.get_model_param_range_map(model_param)
                    param_range = param_range_map["param"]
                    assert param_range.size == shard_main_param.nelement()

                    shard_model_param = model_param.view(-1) \
                        [param_range.start:param_range.end]
                    shard_main_param.data.copy_(shard_model_param)

        # Copy model groups to shard groups.
        copy_group_params(self.model_float16_groups,
                          self.shard_fp32_from_float16_groups)
        copy_group_params(self.model_fp32_groups,
                          self.shard_fp32_groups)