parallel_state.py 21.9 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Model and data parallel groups."""

import torch
6
from typing import Optional
7

8
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

# Virtual (interleaved) pipeline rank and world size. Both stay None unless a
# virtual pipeline size is passed to initialize_model_parallel().
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank of the split point for models with both encoder and decoder;
# None when no split rank was given.
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
# When one of them is not None, the matching getter returns it instead of
# querying torch.distributed.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None



47
48
49
50
51
def is_unitialized():
    """Report whether the parallel state has not been set up yet.

    Useful for code segments that may be accessed with or without mpu
    initialization. Only the data parallel group is inspected.

    Returns:
        bool: True when the data parallel group is still None.
    """
    data_parallel_group = _DATA_PARALLEL_GROUP
    return data_parallel_group is None


52
53
54
55
56
57
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
                                              pipeline).
        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
                                            rank in pipeline with split point.

    Raises:
        RuntimeError: if the world size is not divisible by
            tensor_model_parallel_size x pipeline_model_parallel_size.
        AssertionError: if torch.distributed is not initialized or if any of
            the groups built here has already been initialized.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    data_parallel_size: int = world_size // (tensor_model_parallel_size *
                                             pipeline_model_parallel_size)

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # NOTE: every new_group() call below is made on every rank, including
    # ranks that are not members of the group being created; only the group
    # containing this rank is stored.

    # Build the data-parallel groups.
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups: the i-th member of every data-parallel
    # group forms one model-parallel group.
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            # The split-point stage joins both embedding groups unless it is
            # already one of their members.
            if pipeline_model_parallel_split_rank is not None:
                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                    embedding_ranks = [ranks[0],
                                       ranks[pipeline_model_parallel_split_rank],
                                       ranks[-1]]
                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0],
                                                ranks[pipeline_model_parallel_split_rank]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
Vijay Korthikanti's avatar
Vijay Korthikanti committed
197

198
199
200

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized.

    Returns:
        bool: True only when the tensor model parallel, pipeline model
        parallel and data parallel groups are all set; False otherwise.
    """
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Return the model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the model parallel group is not initialized.
    """
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, 'model parallel group is not initialized'
    return group


215
216
217
def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the tensor model parallel group is not initialized.
    """
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'intra_layer_model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP
220
221


222
223
224
225
226
def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the pipeline model parallel group is not initialized.
    """
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, 'pipeline_model parallel group is not initialized'
    return group
227
228


229
230
231
232
233
234
235
def get_data_parallel_group():
    """Return the data parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the data parallel group is not initialized.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, 'data parallel group is not initialized'
    return group


236
237
238
239
240
241
242
def get_embedding_group():
    """Return the embedding group the caller rank belongs to.

    Raises:
        AssertionError: if the embedding group is not initialized.
    """
    group = _EMBEDDING_GROUP
    assert group is not None, 'embedding group is not initialized'
    return group


Vijay Korthikanti's avatar
Vijay Korthikanti committed
243
244
245
246
247
248
249
def get_position_embedding_group():
    """Return the position embedding group the caller rank belongs to.

    Raises:
        AssertionError: if the position embedding group is not initialized.
    """
    group = _POSITION_EMBEDDING_GROUP
    assert group is not None, 'position embedding group is not initialized'
    return group


250
251
252
253
def set_tensor_model_parallel_world_size(world_size):
    """Override the tensor model parallel world size.

    Args:
        world_size: value that get_tensor_model_parallel_world_size() will
            return from now on; None removes the override.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE

    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
254
255


256
257
258
259
def set_pipeline_model_parallel_world_size(world_size):
    """Override the pipeline model parallel world size.

    Args:
        world_size: value that get_pipeline_model_parallel_world_size() will
            return from now on; None removes the override.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
260
261


262
263
264
265
266
267
def get_tensor_model_parallel_world_size():
    """Return the world size of the tensor model parallel group.

    A value set through set_tensor_model_parallel_world_size() takes
    precedence; otherwise torch.distributed is queried for the group size.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override
268
269


270
271
272
273
274
275
def get_pipeline_model_parallel_world_size():
    """Return the world size of the pipeline model parallel group.

    A value set through set_pipeline_model_parallel_world_size() takes
    precedence; otherwise torch.distributed is queried for the group size.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
276
277


278
279
280
281
def set_tensor_model_parallel_rank(rank):
    """Override the tensor model parallel rank.

    Args:
        rank: value that get_tensor_model_parallel_rank() will return from
            now on; None removes the override.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK

    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
282
283


284
285
286
287
def set_pipeline_model_parallel_rank(rank):
    """Override the pipeline model parallel rank.

    Args:
        rank: value that get_pipeline_model_parallel_rank() will return from
            now on; None removes the override.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK

    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
288
289


290
291
292
293
294
295
def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group.

    A value set through set_tensor_model_parallel_rank() takes precedence;
    otherwise torch.distributed is queried for the rank in the group.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override
296
297


298
299
300
301
302
303
def get_pipeline_model_parallel_rank():
    """Return this process's rank within the pipeline model parallel group.

    A value set through set_pipeline_model_parallel_rank() takes precedence;
    otherwise torch.distributed is queried for the rank in the group.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override
304
305


306
307
308
309
310
def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank.

    Args:
        args: object exposing ``num_layers`` and, when the pipeline model
            parallel world size is greater than 1, also
            ``standalone_embedding_stage``,
            ``transformer_pipeline_model_parallel_size`` and (for
            encoder-decoder models) ``pipeline_model_parallel_split_rank``.
        is_encoder_and_decoder_model: True when the model has both an encoder
            and a decoder; the layers are then divided separately among the
            ranks before and after the split rank.

    Returns:
        The number of layers for this pipeline rank; ``args.num_layers``
        itself when the pipeline model parallel world size is 1.

    Raises:
        AssertionError: if the split rank is missing for an encoder-decoder
            model, or if ``args.num_layers`` is not divisible by the number
            of ranks it is divided among.
    """
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.num_layers % num_ranks_in_encoder == 0, \
                    'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
            assert args.num_layers % num_ranks_in_decoder == 0, \
                    'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                # Pipeline rank 0 holds no transformer layers when it is the
                # standalone embedding stage.
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0 else
                    args.num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        num_layers = args.num_layers
    return num_layers


354
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False (the default) and a virtual pipeline world
            size is set, the virtual pipeline rank must also be 0 for the
            stage to count as first; when True only the pipeline rank is
            checked.
    """
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0
361
362


363
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: when False (the default) and a virtual pipeline world
            size is set, the virtual pipeline rank must also be the last
            virtual rank for the stage to count as last; when True only the
            pipeline rank is checked.
    """
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)
374
375


376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
def is_rank_in_embedding_group(ignore_virtual=False):
    """Tell whether the current global rank takes part in the embedding group.

    Args:
        ignore_virtual: when True, membership in the embedding rank list is
            the only criterion. When False, the first listed rank must also
            be a first pipeline stage and the last listed rank a last
            pipeline stage (virtual pipeline rank taken into account); any
            other listed rank always counts.

    Returns:
        True if the rank is in the embedding group, False otherwise.
    """
    rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    if rank not in embedding_ranks:
        return False
    if ignore_virtual:
        return True
    if rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True


Vijay Korthikanti's avatar
Vijay Korthikanti committed
392
393
394
395
396
397
398
def is_rank_in_position_embedding_group():
    """Tell whether the current global rank is listed in the position embedding ranks.

    Returns:
        True if the rank is in the position embedding group, False otherwise.
    """
    return torch.distributed.get_rank() in _POSITION_EMBEDDING_GLOBAL_RANKS


399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
def is_pipeline_stage_before_split(rank=None):
    """Return True if the pipeline stage executes the encoder block of a
    model with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to the caller's own pipeline
            model parallel rank.

    Always True when the pipeline world size is 1 or when no split rank is
    set; otherwise True only for ranks strictly below the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is not None and not rank < split_rank:
        return False
    return True


def is_pipeline_stage_after_split(rank=None):
    """Return True if the pipeline stage executes the decoder block of a
    model with both encoder and decoder.

    Args:
        rank: pipeline rank to test; defaults to the caller's own pipeline
            model parallel rank.

    Always True when the pipeline world size is 1 or when no split rank is
    set; otherwise True only for ranks at or above the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is not None and not rank >= split_rank:
        return False
    return True


def is_pipeline_stage_at_split():
    """Return True if this pipeline stage is before the split (encoder side)
    while the next stage is after the split (decoder side), for a model with
    both encoder and decoder."""
    this_stage = get_pipeline_model_parallel_rank()
    next_stage = this_stage + 1
    return is_pipeline_stage_before_split(this_stage) and \
        is_pipeline_stage_after_split(next_stage)


438
439
440
441
442
443
444
445
446
447
448
449
def get_virtual_pipeline_model_parallel_rank():
    """Return the stored virtual pipeline-parallel rank (None when unset)."""
    virtual_rank = _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return virtual_rank


def set_virtual_pipeline_model_parallel_rank(rank):
    """Store the virtual pipeline-parallel rank.

    Args:
        rank: value that get_virtual_pipeline_model_parallel_rank() will
            return from now on.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


450
451
452
453
454
455
def get_virtual_pipeline_model_parallel_world_size():
    """Return the stored virtual pipeline-parallel world size (None when unset)."""
    virtual_world_size = _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return virtual_world_size


456
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    The global rank is rounded down to the nearest multiple of the tensor
    model parallel world size.
    """
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size

463

464
465
def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Raises:
        AssertionError: if the data parallel global ranks are not set.
    """
    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
        "Data parallel group is not initialized"
    return _DATA_PARALLEL_GLOBAL_RANKS[0]
470
471


472
473
474
475
476
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group.

    Raises:
        AssertionError: if the pipeline global ranks are not set.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[0]

477

478
479
480
481
482
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group.

    The last stage is located at index (pipeline world size - 1) of the
    pipeline global ranks.

    Raises:
        AssertionError: if the pipeline global ranks are not set.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[get_pipeline_model_parallel_world_size() - 1]
483

484
def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in the caller's pipeline group.

    The index wraps around, so the last stage's next rank is the first stage.

    Raises:
        AssertionError: if the pipeline global ranks are not set.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

491

492
493
494
495
496
497
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in the caller's pipeline group.

    The index wraps around, so the first stage's previous rank is the last stage.

    Raises:
        AssertionError: if the pipeline global ranks are not set.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    prev_index = (stage - 1) % num_stages
    return pipeline_ranks[prev_index]
498

499

500
501
502
503
504
505
506
507
508
509
510
def get_data_parallel_world_size():
    """Return the number of ranks in the caller's data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=data_parallel_group)


def get_data_parallel_rank():
    """Return the caller's rank within its data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_rank(group=data_parallel_group)

def destroy_model_parallel():
    """Set the groups to none.

    Resets every piece of module-level parallel state to None: the process
    group handles, the cached global-rank lists, the virtual pipeline and
    split-rank settings, and the mpu world-size/rank overrides. The process
    groups themselves are not destroyed here; only the references are dropped.
    """
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP
    global _EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GROUP
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    global _PIPELINE_GLOBAL_RANKS
    global _DATA_PARALLEL_GLOBAL_RANKS

    # Process group handles.
    _MODEL_PARALLEL_GROUP = None
    _TENSOR_MODEL_PARALLEL_GROUP = None
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _DATA_PARALLEL_GROUP = None
    _EMBEDDING_GROUP = None
    _POSITION_EMBEDDING_GROUP = None

    # Virtual pipeline and encoder/decoder split settings. The split rank is
    # cleared too, otherwise a later initialization without a split rank
    # would keep using the old one.
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

    # On-the-fly mpu overrides.
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    # Cached global-rank lists; clearing them makes the *_src_rank /
    # *_first_rank style accessors assert again instead of returning stale
    # ranks.
    _EMBEDDING_GLOBAL_RANKS = None
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    _PIPELINE_GLOBAL_RANKS = None
    _DATA_PARALLEL_GLOBAL_RANKS = None