# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.


"""Model and data parallel groups."""

import torch

from .utils import ensure_divisibility


# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

# Interleaved (virtual) pipeline state; only assigned by
# initialize_model_parallel() when a virtual pipeline size is given.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank at which an encoder-decoder model switches from encoder to
# decoder stages; None when no split rank was given.
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None



def is_unitialized():
    """Report whether the data-parallel group has not been built yet.

    Useful for code segments that may be accessed with or without mpu
    initialization.
    """
    initialized = _DATA_PARALLEL_GROUP is not None
    return not initialized


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None,
                              pipeline_model_parallel_split_rank_=None):
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size_: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size_: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size_: number of virtual stages (interleaved
                                               pipeline).
        pipeline_model_parallel_split_rank_: for models with both encoder and decoder,
                                             rank in pipeline with split point.

    Both sizes are clamped to the global world size; the world size must be
    divisible by their product, and the quotient is the data-parallel size.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Every rank must call this function with the same arguments: each
    torch.distributed.new_group() call below is issued by every rank in the
    same order, including for groups the calling rank does not belong to.
    """
    # Check this before any other torch.distributed call so that a missing
    # default process group fails here rather than inside get_rank().
    assert torch.distributed.is_initialized(), \
        'torch.distributed must be initialized before initialize_model_parallel()'
    world_size = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()

    if rank == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size_))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size_))

    # Clamp the requested sizes and ensure they tile the world evenly.
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    ensure_divisibility(world_size,
                        tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    if pipeline_model_parallel_split_rank_ is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    # Build the data-parallel groups. Pipeline stage i owns the contiguous
    # block [i * num_pipeline_model_parallel_groups, (i + 1) * ...); inside
    # that block, ranks with the same tensor-parallel offset form one group.
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups: the i-th member of every data-parallel
    # group together forms one tensor x pipeline model replica.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups: blocks of consecutive ranks.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                # For encoder-decoder models the stage at the split rank
                # (first decoder stage) also joins both groups.
                split_rank = ranks[pipeline_model_parallel_split_rank_]
                if split_rank not in embedding_ranks:
                    embedding_ranks = [ranks[0], split_rank, ranks[-1]]
                if split_rank not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], split_rank]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized.

    Returns:
        True once the tensor, pipeline and data parallel groups all exist;
        the combined model-parallel and embedding groups are not consulted.
    """
    return not (_TENSOR_MODEL_PARALLEL_GROUP is None or
                _PIPELINE_MODEL_PARALLEL_GROUP is None or
                _DATA_PARALLEL_GROUP is None)


def get_model_parallel_group():
    """Return the combined (tensor + pipeline) model-parallel group of this rank.

    Raises:
        AssertionError: if the group has not been created.
    """
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, \
        'model parallel group is not initialized'
    return group


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the group has not been created (or was cleared).
    """
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Return the pipeline model-parallel group containing the calling rank.

    Raises:
        AssertionError: if the group has not been created.
    """
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, \
        'pipeline_model parallel group is not initialized'
    return group


def get_data_parallel_group():
    """Return the data-parallel group containing the calling rank.

    Raises:
        AssertionError: if the group has not been created.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, \
        'data parallel group is not initialized'
    return group


def get_embedding_group():
    """Return the embedding group containing the calling rank.

    Only ranks listed in some embedding rank list are given a group; any
    other rank hits the assertion.

    Raises:
        AssertionError: if no embedding group is set for this rank.
    """
    group = _EMBEDDING_GROUP
    assert group is not None, \
        'embedding group is not initialized'
    return group


def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to.

    Only ranks listed in some position-embedding rank list are given a
    group; any other rank hits the assertion.

    Raises:
        AssertionError: if no position embedding group is set for this rank.
    """
    assert _POSITION_EMBEDDING_GROUP is not None, \
        'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Override the value reported by get_tensor_model_parallel_world_size().

    Passing None restores querying the tensor model-parallel group.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size.

    The value overrides what get_pipeline_model_parallel_world_size()
    reports; passing None restores querying the pipeline group.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return the number of ranks in this rank's tensor model-parallel group.

    A value installed with set_tensor_model_parallel_world_size() takes
    precedence over querying torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group.

    A value installed with set_pipeline_model_parallel_world_size() takes
    precedence over querying torch.distributed.
    """
    # Only read here, so no `global` declaration is needed.
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Override the value reported by get_tensor_model_parallel_rank().

    Passing None restores querying the tensor model-parallel group.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Override the value reported by get_pipeline_model_parallel_rank().

    Passing None restores querying the pipeline model-parallel group.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return this process's rank within its tensor model-parallel group.

    A rank installed with set_tensor_model_parallel_rank() wins over the
    value queried from torch.distributed.
    """
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return _MPU_TENSOR_MODEL_PARALLEL_RANK


def get_pipeline_model_parallel_rank():
    """Return this process's rank within its pipeline model-parallel group.

    A rank installed with set_pipeline_model_parallel_rank() wins over the
    value queried from torch.distributed.
    """
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return _MPU_PIPELINE_MODEL_PARALLEL_RANK


def get_num_layers(args, is_encoder_and_decoder_model, is_decoder=False):
    """Compute the number of transformer layers resident on the current rank.

    Arguments:
        args: argument namespace. Fields read: pipeline_model_parallel_split_rank,
            standalone_embedding_stage, transformer_pipeline_model_parallel_size,
            num_layers, encoder_num_layers and decoder_num_layers.
        is_encoder_and_decoder_model: whether the pipeline is split into
            encoder stages and decoder stages at the split rank.
        is_decoder: only consulted without pipeline parallelism; selects
            decoder_num_layers instead of encoder_num_layers.

    Returns:
        Number of transformer layers for this rank; 0 on pipeline rank 0 when
        pipeline parallelism is used with a standalone embedding stage.
    """
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            # Fail with a readable message instead of a ZeroDivisionError (or a
            # negative layer count) when the split leaves one side without ranks.
            assert num_ranks_in_encoder > 0, \
                'pipeline split rank leaves no ranks for the encoder (%d)' % num_ranks_in_encoder
            assert num_ranks_in_decoder > 0, \
                'pipeline split rank leaves no ranks for the decoder (%d)' % num_ranks_in_decoder
            assert args.encoder_num_layers % num_ranks_in_encoder == 0, \
                    'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.encoder_num_layers, num_ranks_in_encoder)
            assert args.decoder_num_layers % num_ranks_in_decoder == 0, \
                    'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.decoder_num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0 else
                    args.encoder_num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.decoder_num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers == args.encoder_num_layers
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        if not is_decoder:
            num_layers = args.encoder_num_layers
        else:
            num_layers = args.decoder_num_layers
    return num_layers


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Arguments:
        ignore_virtual: if False and a virtual pipeline world size is set,
            additionally require virtual pipeline rank 0.
    """
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Arguments:
        ignore_virtual: if False and a virtual pipeline world size is set,
            additionally require the last virtual pipeline rank.
    """
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    Arguments:
        ignore_virtual: if True, report plain membership in this rank's
            embedding rank list. Otherwise the first listed rank only counts
            while it is in the first (virtual) pipeline stage, the last listed
            rank only while it is in the last one, and any middle entry (the
            encoder-decoder split rank) always counts.

    Raises:
        AssertionError: if the embedding ranks have not been recorded yet.
    """
    # Check first: `rank in None` would otherwise raise an opaque TypeError.
    assert _EMBEDDING_GLOBAL_RANKS is not None, \
        'embedding group is not initialized'
    rank = torch.distributed.get_rank()
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise.

    Raises:
        AssertionError: if the position embedding ranks have not been recorded yet.
    """
    # Check first: `rank in None` would otherwise raise an opaque TypeError.
    assert _POSITION_EMBEDDING_GLOBAL_RANKS is not None, \
        'position embedding group is not initialized'
    rank = torch.distributed.get_rank()
    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
    """Return True if the pipeline stage runs the encoder of an encoder-decoder model.

    Arguments:
        rank: pipeline model-parallel rank to test; defaults to the caller's.

    Without pipeline parallelism, or when no split rank is set, every
    stage is treated as being before the split.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank < split_rank


def is_pipeline_stage_after_split(rank=None):
    """Return True if the pipeline stage runs the decoder of an encoder-decoder model.

    Arguments:
        rank: pipeline model-parallel rank to test; defaults to the caller's.

    Without pipeline parallelism, or when no split rank is set, every
    stage is treated as being after the split.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank >= split_rank


def is_pipeline_stage_at_split():
    """Return True if this pipeline stage executes the encoder block and the
    next stage executes the decoder block, for a model with both encoder and
    decoder.

    Also True on every stage when pipeline parallelism is off or no split
    rank is set, since both underlying predicates then return True.
    """
    rank = get_pipeline_model_parallel_rank()
    if not is_pipeline_stage_before_split(rank):
        return False
    return is_pipeline_stage_after_split(rank + 1)


def get_virtual_pipeline_model_parallel_rank():
    """Return the current virtual (interleaved) pipeline rank, or None if unset."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Store `rank` as the current virtual (interleaved) pipeline rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the number of virtual pipeline stages, or None when not interleaving."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    Relies on tensor model-parallel groups being blocks of consecutive
    global ranks, as initialize_model_parallel() builds them.
    """
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    # Round down to the start of this rank's block of consecutive ranks.
    return (global_rank // local_world_size) * local_world_size

def get_data_parallel_src_rank():
    """Return the global rank of the first member of this rank's data-parallel group.

    Raises:
        AssertionError: if the data-parallel rank list has not been recorded.
    """
    group_ranks = _DATA_PARALLEL_GLOBAL_RANKS
    assert group_ranks is not None, \
        "Data parallel group is not initialized"
    first_rank = group_ranks[0]
    return first_rank


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in this rank's pipeline group.

    Raises:
        AssertionError: if the pipeline rank list has not been recorded.
    """
    stage_ranks = _PIPELINE_GLOBAL_RANKS
    assert stage_ranks is not None, \
        "Pipeline parallel group is not initialized"
    first_rank = stage_ranks[0]
    return first_rank

def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in this rank's pipeline group.

    The index comes from get_pipeline_model_parallel_world_size(), so a
    size override installed via set_pipeline_model_parallel_world_size()
    is honored.

    Raises:
        AssertionError: if the pipeline rank list has not been recorded.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[get_pipeline_model_parallel_world_size() - 1]

def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in this rank's pipeline group.

    Wraps around, so the last stage's next rank is the first stage.

    Raises:
        AssertionError: if the pipeline rank list has not been recorded.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in this rank's pipeline group.

    Wraps around, so the first stage's previous rank is the last stage.

    Raises:
        AssertionError: if the pipeline rank list has not been recorded.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    prev_stage = (stage - 1) % num_stages
    return _PIPELINE_GLOBAL_RANKS[prev_stage]

def get_data_parallel_world_size():
    """Return the number of ranks in this rank's data-parallel group."""
    group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=group)


def get_data_parallel_rank():
    """Return this process's rank within its data-parallel group."""
    group = get_data_parallel_group()
    return torch.distributed.get_rank(group=group)


def destroy_model_parallel():
    """Set the groups to none.

    Also resets the bookkeeping that initialize_model_parallel() records
    alongside the groups: the per-group global rank lists, the virtual
    pipeline rank/world size and the pipeline split rank.
    initialize_model_parallel() only assigns the virtual and split values
    when the corresponding argument is given. Without this reset, a later
    re-initialization without them would silently inherit stale values.
    The _MPU_* size/rank overrides are left untouched.
    """
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _EMBEDDING_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS
    _DATA_PARALLEL_GLOBAL_RANKS = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None