# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


"""Model and data parallel groups."""

import torch

from .utils import ensure_divisibility


# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

def is_unitialized():
    """Useful for code segments that may be accessed with or without mpu initialization"""
    return _DATA_PARALLEL_GROUP is None


def initialize_model_parallel(tensor_model_parallel_size_=1,
                              pipeline_model_parallel_size_=1,
                              virtual_pipeline_model_parallel_size_=None,
                              pipeline_model_parallel_split_rank_=None):
    """
    Initialize model and data parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model parallelism.
        virtual_pipeline_model_parallel_size: number of virtual stages (interleaved
                                              pipeline).
        pipeline_model_parallel_split_rank: for models with both encoder and decoder,
                                            rank in pipeline with split point.


    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    if torch.distributed.get_rank() == 0:
        print('> initializing tensor model parallel with size {}'.format(
            tensor_model_parallel_size_))
        print('> initializing pipeline model parallel with size {}'.format(
            pipeline_model_parallel_size_))
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
    ensure_divisibility(world_size,
                        tensor_model_parallel_size * pipeline_model_parallel_size)
    data_parallel_size = world_size // (tensor_model_parallel_size *
                                        pipeline_model_parallel_size)

    num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
    num_data_parallel_groups = world_size // data_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_

    if pipeline_model_parallel_split_rank_ is not None:
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    rank = torch.distributed.get_rank()

    # Build the data-parallel groups.
    global _DATA_PARALLEL_GROUP
    assert _DATA_PARALLEL_GROUP is None, \
        'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank,
                          tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

    # Build the model-parallel groups.
    global _MODEL_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is None, \
        'model parallel group is already initialized'
    for i in range(data_parallel_size):
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size,
                      num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                    embedding_ranks = [ranks[0],
                                       ranks[pipeline_model_parallel_split_rank_],
                                       ranks[-1]]
                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0],
                                                ranks[pipeline_model_parallel_split_rank_]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
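

# Illustrative usage sketch (not part of the original module). It assumes the
# default process group has already been created on a 16-GPU job, for example
# via torch.distributed.launch with the NCCL backend. With tensor model parallel
# size 2 and pipeline model parallel size 4, the resulting groups match the
# example in the docstring above:
#
#   torch.distributed.init_process_group(backend='nccl')
#   initialize_model_parallel(tensor_model_parallel_size_=2,
#                             pipeline_model_parallel_size_=4)
#   # Global rank 5 then falls in data-parallel group [g5, g7], tensor
#   # model-parallel group [g4, g5], and pipeline group [g1, g5, g9, g13].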


def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
    if _TENSOR_MODEL_PARALLEL_GROUP is None or \
        _PIPELINE_MODEL_PARALLEL_GROUP is None or \
        _DATA_PARALLEL_GROUP is None:
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to."""
    assert _MODEL_PARALLEL_GROUP is not None, \
        'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
        'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, \
        'pipeline model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to."""
    assert _DATA_PARALLEL_GROUP is not None, \
        'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


def get_embedding_group():
    """Get the embedding group the caller rank belongs to."""
    assert _EMBEDDING_GROUP is not None, \
        'embedding group is not initialized'
    return _EMBEDDING_GROUP


def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to."""
    assert _POSITION_EMBEDDING_GROUP is not None, \
        'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline model parallel size"""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())


def set_tensor_model_parallel_rank(rank):
    """Set tensor model parallel rank."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Set pipeline model parallel rank."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())


def get_num_layers(args, is_encoder_and_decoder_model):
    """Compute the number of transformer layers resident on the current rank."""
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
            # With a standalone embedding stage, the first pipeline rank holds no
            # transformer layers and is therefore not counted toward the encoder ranks.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embed_stage else
                args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.num_layers % num_ranks_in_encoder == 0, \
                    'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (args.num_layers, num_ranks_in_encoder)
            assert args.num_layers % num_ranks_in_decoder == 0, \
                    'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (args.num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = args.num_layers // num_ranks_in_encoder
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'
            num_layers = (
                0
                if args.standalone_embed_stage
                and get_pipeline_model_parallel_rank() == 0 else
                args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        num_layers = args.num_layers
    return num_layers
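

# Illustrative arithmetic for get_num_layers (not part of the original module),
# assuming a decoder-only model with args.num_layers = 24 and
# args.transformer_pipeline_model_parallel_size = 4. Without a standalone
# embedding stage, every pipeline stage holds 24 // 4 = 6 transformer layers.
# With args.standalone_embed_stage = True and 5 pipeline stages, stage 0 holds
# the embedding only (0 transformer layers) and stages 1-4 hold 6 layers each.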


def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        if get_virtual_pipeline_model_parallel_world_size() is not None and \
            get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise."""
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = \
            get_virtual_pipeline_model_parallel_world_size()
        if virtual_pipeline_model_parallel_world_size is not None and \
            get_virtual_pipeline_model_parallel_rank() != (
                virtual_pipeline_model_parallel_world_size - 1):
            return False
    return get_pipeline_model_parallel_rank() == (
        get_pipeline_model_parallel_world_size() - 1)


def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False
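

# Illustrative example (not part of the original module). With the 16-GPU layout
# from the initialize_model_parallel docstring and no pipeline split rank, the
# pipeline group [g1, g5, g9, g13] produces _EMBEDDING_GLOBAL_RANKS = [1, 13],
# so only the first and last pipeline stages take part in the embedding
# gradient exchange.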


def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    assert isinstance(rank, (type(None), int)), \
        'rank must be None or an int, got <%s>' % type(rank).__name__
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank < _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder."""
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    assert isinstance(rank, (type(None), int)), \
        'rank must be None or an int, got <%s>' % type(rank).__name__
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if _PIPELINE_MODEL_PARALLEL_SPLIT_RANK is None:
        return True
    if rank >= _PIPELINE_MODEL_PARALLEL_SPLIT_RANK:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes decoder block and next
    stage executes encoder block for a model with both encoder and
    decoder."""
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and \
            is_pipeline_stage_after_split(rank+1)
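

# Illustrative example (not part of the original module). For an encoder-decoder
# model with 4 pipeline stages and _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = 2,
# stages 0 and 1 run encoder layers (before the split), stages 2 and 3 run
# decoder layers (after the split), and stage 1 is the stage at the split
# because stage 2, which follows it, is past the split point.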


def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    return (global_rank // local_world_size) * local_world_size
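

# Illustrative example (not part of the original module). With a tensor model
# parallel world size of 2, global rank 5 maps to source rank (5 // 2) * 2 = 4,
# i.e. the first rank of its tensor model-parallel group [g4, g5].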


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    data_parallel_size = get_data_parallel_world_size()
    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
    return global_rank % num_data_parallel_groups
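

# Illustrative example (not part of the original module), assuming a pipeline
# model parallel size of 1. With 8 GPUs and tensor model parallel size 2, the
# data-parallel groups are [g0, g2, g4, g6] and [g1, g3, g5, g7], so
# num_data_parallel_groups is 2 and global rank 5 maps to source rank
# 5 % 2 = 1, the first rank of its data-parallel group.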


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]

def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the next stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the previous stage in the caller's pipeline group."""
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
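

# Illustrative example (not part of the original module). In the pipeline group
# [g1, g5, g9, g13], the next rank after g13 wraps around to g1 and the previous
# rank before g1 wraps around to g13, since both helpers index
# _PIPELINE_GLOBAL_RANKS modulo the pipeline world size.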


def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    return torch.distributed.get_world_size(group=get_data_parallel_group())


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    return torch.distributed.get_rank(group=get_data_parallel_group())


def destroy_model_parallel():
    """Set the groups to none."""
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None