parallel_state.py 25.9 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Model and data parallel groups."""

liangjing's avatar
v1  
liangjing committed
5
import os
6
from typing import Optional
7

liangjing's avatar
v1  
liangjing committed
8
9
import torch

10
11
from .utils import GlobalMemoryBuffer

12
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Group over the same ranks as the data parallel group, created with the
# "gloo" backend.
_DATA_PARALLEL_GROUP_GLOO = None
# FP8 amax reduction group.
_AMAX_REDUCTION_GROUP = None

# Virtual (interleaved) pipeline rank / world size; both stay None unless a
# virtual pipeline size is configured.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank at which the decoder starts for encoder-decoder models.
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
54

55

56
57
58
59
60
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
    use_fp8: bool = False,
    use_sharp: bool = False,
) -> None:
    """Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

        use_fp8 (bool, default = False):
            Construct GPU groups needed for FP8 training, namely for
            amax reduction across the product of the data-parallel and
            tensor-parallel groups.

        use_sharp (bool, default = False):
            Set the use of SHARP for the collective communications of
            data-parallel process groups. When `True`, run barrier
            within each data-parallel process group, which specifies
            the SHARP application target groups.

    Raises:
        AssertionError: if torch.distributed is not initialized, or if any
            of the groups built here has already been initialized.
        RuntimeError: if the world size is not divisible by
            tensor_model_parallel_size * pipeline_model_parallel_size, or
            if a virtual pipeline size is given while
            pipeline_model_parallel_size is not greater than 2.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    """
    # All module-level state written by this function.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GROUP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    global _AMAX_REDUCTION_GROUP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    data_parallel_size: int = world_size // (
        tensor_model_parallel_size * pipeline_model_parallel_size
    )

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 2:
            raise RuntimeError(
                "pipeline-model-parallel size should be greater than 2 with interleaved schedule"
            )
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # NOTE: every rank must issue the new_group() calls below in the same
    # order, including for groups it is not a member of, so the order of the
    # loops in this function must not be changed.

    # Build the data-parallel groups.
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        # Ranks [start_rank, end_rank) form pipeline stage i.
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            # Same tensor-parallel offset j within the stage, strided by TP size.
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GROUP_GLOO = group_gloo
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Apply SHARP to DP process groups
    if use_sharp:
        if rank == 0:
            print(
                "The number of process groups to use SHARP with depends on the type "
                "of the network switch. Nvidia QM1 switch supports SHARP up to 8 "
                "process groups and QM2 supports up to 256 process groups. We apply "
                "SHARP to the communications of the data-parallel domain. If the "
                "number of data-parallel process groups is larger than the max "
                "process groups that the network switch supports, the communication "
                "will fall back to non-SHARP operators. To enable SHARP, "
                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
            )
        torch.distributed.barrier(
            group=get_data_parallel_group(), device_ids=[torch.cuda.current_device()]
        )
        # Set `NCCL_SHARP_DISABLE=1` to restrict SHARP application to DP process groups
        os.environ["NCCL_SHARP_DISABLE"] = "1"

    # Build the model-parallel groups.
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for i in range(data_parallel_size):
        # Take the i-th member of every data-parallel group.
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_data_parallel_group_ranks
        ]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), 'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), 'pipeline model parallel group is already initialized'
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                # The first decoder stage also joins both embedding groups.
                if ranks[pipeline_model_parallel_split_rank] not in embedding_ranks:
                    embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank],
                        ranks[-1],
                    ]
                if ranks[pipeline_model_parallel_split_rank] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Build the FP8 groups.
    assert _AMAX_REDUCTION_GROUP is None, 'FP8 amax reduction group is already initialized'
    if use_fp8:
        # Each group is a contiguous block of tensor-parallel x data-parallel ranks.
        amax_group_size: int = tensor_model_parallel_size * data_parallel_size
        num_amax_groups: int = world_size // amax_group_size
        for i in range(num_amax_groups):
            start_rank = i * amax_group_size
            end_rank = (i + 1) * amax_group_size
            ranks = range(start_rank, end_rank)
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _AMAX_REDUCTION_GROUP = group

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()

294

Abhinav Khattar's avatar
Abhinav Khattar committed
295
296
297
298
299
def is_unitialized():
    """Return True while the data parallel group has not been created.

    Useful for code segments that may be accessed with or without mpu
    initialization. Only the data parallel group is inspected.
    """
    data_parallel_group = _DATA_PARALLEL_GROUP
    return data_parallel_group is None


300
301
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized.

    Returns:
        bool: True only when the tensor model-parallel, pipeline
        model-parallel and data-parallel groups are all set.
    """
    if (
        _TENSOR_MODEL_PARALLEL_GROUP is None
        or _PIPELINE_MODEL_PARALLEL_GROUP is None
        or _DATA_PARALLEL_GROUP is None
    ):
        return False
    return True


def get_model_parallel_group():
    """Get the model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the model parallel group is not initialized.
    """
    assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


liangjing's avatar
v1  
liangjing committed
317
def get_tensor_model_parallel_group(check_initialized=True):
    """Get the tensor model parallel group the caller rank belongs to.

    Arguments:
        check_initialized (bool, default = True): when True, assert that
            the group has been created; when False, the stored value is
            returned as-is and may be None.

    Raises:
        AssertionError: if check_initialized is True and the group is not
            initialized.
    """
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP
324
325


326
327
def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the pipeline model parallel group is not
            initialized.
    """
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is not None
    ), 'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP
332
333


334
335
def get_data_parallel_group():
    """Get the data parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the data parallel group is not initialized.
    """
    assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
    return _DATA_PARALLEL_GROUP


liangjing's avatar
v1  
liangjing committed
340
341
342
343
344
345
def get_data_parallel_group_gloo():
    """Return the gloo-backed data parallel group of the calling rank.

    Raises:
        AssertionError: if that group has not been initialized.
    """
    gloo_group = _DATA_PARALLEL_GROUP_GLOO
    assert gloo_group is not None, 'data parallel group-gloo is not initialized'
    return gloo_group


346
347
def get_embedding_group():
    """Get the embedding group the caller rank belongs to.

    Raises:
        AssertionError: if the embedding group is not initialized.
    """
    assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
    return _EMBEDDING_GROUP


Vijay Korthikanti's avatar
Vijay Korthikanti committed
352
353
def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to.

    Raises:
        AssertionError: if the position embedding group is not initialized.
    """
    assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


liangjing's avatar
v1  
liangjing committed
358
359
360
361
362
363
def get_amax_reduction_group():
    """Return the FP8 amax reduction group of the calling rank.

    Raises:
        AssertionError: if the FP8 amax reduction group is not initialized.
    """
    amax_group = _AMAX_REDUCTION_GROUP
    assert amax_group is not None, 'FP8 amax reduction group is not initialized'
    return amax_group


364
365
366
367
def set_tensor_model_parallel_world_size(world_size):
    """Override the tensor model parallel world size.

    Arguments:
        world_size: value stored in the module-level override that
            get_tensor_model_parallel_world_size() returns when it is
            not None.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE

    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
368
369


370
371
372
373
def set_pipeline_model_parallel_world_size(world_size):
    """Override the pipeline model parallel world size.

    Arguments:
        world_size: value stored in the module-level override that
            get_pipeline_model_parallel_world_size() returns when it is
            not None.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
374
375


liangjing's avatar
v1  
liangjing committed
376
377
378
379
380
381
def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline model parallel world size.

    Arguments:
        world_size: value later returned by
            get_virtual_pipeline_model_parallel_world_size().
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


382
383
384
385
386
387
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group.

    The value set through set_tensor_model_parallel_world_size() takes
    precedence; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override
388
389


390
391
392
393
394
395
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group.

    The value set through set_pipeline_model_parallel_world_size() takes
    precedence; otherwise the size is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
396
397


398
399
400
401
def set_tensor_model_parallel_rank(rank):
    """Override the tensor model parallel rank.

    Arguments:
        rank: value stored in the module-level override that
            get_tensor_model_parallel_rank() returns when it is not None.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK

    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
402
403


404
405
406
407
def set_pipeline_model_parallel_rank(rank):
    """Override the pipeline model parallel rank.

    Arguments:
        rank: value stored in the module-level override that
            get_pipeline_model_parallel_rank() returns when it is not None.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK

    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
408
409


410
411
def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline model parallel split rank.

    Arguments:
        rank: pipeline rank stored as the split rank; the before/after
            split helpers compare pipeline ranks against this value.
    """
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
414
415


416
417
418
419
420
421
def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group.

    The value set through set_tensor_model_parallel_rank() takes
    precedence; otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override
422
423


424
425
426
427
428
429
def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group.

    The value set through set_pipeline_model_parallel_rank() takes
    precedence; otherwise the rank is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override
430
431


432
433
434
435
436
def get_pipeline_model_parallel_split_rank():
    """Return the stored pipeline model parallel split rank (may be None)."""
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank

437

438
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Arguments:
        ignore_virtual (bool, default = False): when False and a virtual
            pipeline world size is set, the virtual pipeline rank must
            also be 0 for this to return True.
    """
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0
447
448


449
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Arguments:
        ignore_virtual (bool, default = False): when False and a virtual
            pipeline world size is set, the virtual pipeline rank must
            also be the last one (world size - 1) for this to return True.
    """
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = (
            get_virtual_pipeline_model_parallel_world_size()
        )
        if (
            virtual_pipeline_model_parallel_world_size is not None
            and get_virtual_pipeline_model_parallel_rank()
            != (virtual_pipeline_model_parallel_world_size - 1)
        ):
            return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
460
461


462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    Arguments:
        ignore_virtual (bool, default = False): when True, only membership
            in the embedding rank list is checked. When False, the first
            and last listed ranks additionally have to be in the first /
            last pipeline stage (virtual stages included).
    """
    rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    if rank not in embedding_ranks:
        return False
    if ignore_virtual:
        return True
    # First and last entries defer to the (virtual-aware) stage checks;
    # any other listed rank always counts as a member.
    if rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True


Vijay Korthikanti's avatar
Vijay Korthikanti committed
478
479
480
481
482
483
484
def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise."""
    current_rank = torch.distributed.get_rank()
    position_embedding_ranks = _POSITION_EMBEDDING_GLOBAL_RANKS
    return current_rank in position_embedding_ranks


485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.

    Arguments:
        rank (int, optional): pipeline rank to test; defaults to the
            caller's pipeline model parallel rank.

    Always True when the pipeline world size is 1 or no split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None or rank < split_rank:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    Arguments:
        rank (int, optional): pipeline rank to test; defaults to the
            caller's pipeline model parallel rank.

    Always True when the pipeline world size is 1 or no split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None or rank >= split_rank:
        return True
    return False


def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes encoder block and next
    stage executes decoder block for a model with both encoder and
    decoder.

    That is, the caller's pipeline rank is before the split while
    rank + 1 is after it.
    """
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
521
522


523
524
525
526
527
528
529
530
531
532
533
534
def get_virtual_pipeline_model_parallel_rank():
    """Return the stored virtual pipeline-parallel rank (may be None)."""
    virtual_rank = _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return virtual_rank


def set_virtual_pipeline_model_parallel_rank(rank):
    """Store the virtual pipeline-parallel rank.

    Arguments:
        rank: value later returned by
            get_virtual_pipeline_model_parallel_rank().
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK

    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


535
536
537
538
539
540
def get_virtual_pipeline_model_parallel_world_size():
    """Return the stored virtual pipeline-parallel world size (may be None)."""
    virtual_world_size = _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return virtual_world_size


541
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    local_world_size = get_tensor_model_parallel_world_size()
    # Round the global rank down to a multiple of the tensor-parallel size.
    return (global_rank // local_world_size) * local_world_size

548

549
550
def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Raises:
        AssertionError: if the data parallel ranks are not initialized.
    """
    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
    return _DATA_PARALLEL_GLOBAL_RANKS[0]
554
555


556
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group.

    Raises:
        AssertionError: if the pipeline ranks are not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    return _PIPELINE_GLOBAL_RANKS[0]

562

563
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group.

    Raises:
        AssertionError: if the pipeline ranks are not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    # Indexed via the pipeline world size, so an overridden world size is honoured.
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
569

liangjing's avatar
v1  
liangjing committed
570

571
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline.

    The last stage wraps around to the first one.

    Raises:
        AssertionError: if the pipeline ranks are not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]

578

579
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline.

    The first stage wraps around to the last one.

    Raises:
        AssertionError: if the pipeline ranks are not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
585

586

587
588
def get_data_parallel_world_size():
    """Return world size for the data parallel group.

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=get_data_parallel_group())
    else:
        return 0
593
594
595
596


def get_data_parallel_rank():
    """Return my rank for the data parallel group.

    Returns 0 when torch.distributed is unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_data_parallel_group())
    else:
        return 0

602

603
604
605
606
607
608
def _set_global_memory_buffer():
    """Create the module-wide GlobalMemoryBuffer.

    Raises:
        AssertionError: if a global memory buffer already exists.
    """
    global _GLOBAL_MEMORY_BUFFER

    already_initialized = _GLOBAL_MEMORY_BUFFER is not None
    assert not already_initialized, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

liangjing's avatar
v1  
liangjing committed
609

610
def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object.

    Raises:
        AssertionError: if the global memory buffer is not initialized.
    """
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER


liangjing's avatar
v1  
liangjing committed
616
617
618
619
620
621
def destroy_global_memory_buffer():
    """Drop the module's reference to the global memory buffer.

    Only the module-level variable is reset to None; nothing is called on
    the buffer object itself.
    """
    global _GLOBAL_MEMORY_BUFFER

    _GLOBAL_MEMORY_BUFFER = None


622
623
def destroy_model_parallel():
    """Set the groups to none.

    Resets every piece of module-level parallel state to None: the process
    groups (including the gloo data-parallel group), the cached global rank
    lists, the virtual pipeline and split-rank settings, the on-the-fly
    size/rank overrides and the global memory buffer. Clearing all of them
    keeps a later initialization from seeing values left by the previous
    one.

    Only the references held by this module are dropped; the underlying
    torch.distributed process groups are not destroyed here.
    """
    # Process groups.
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP_GLOO
    _DATA_PARALLEL_GROUP_GLOO = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _AMAX_REDUCTION_GROUP
    _AMAX_REDUCTION_GROUP = None

    # Cached global rank lists.
    global _EMBEDDING_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS
    _DATA_PARALLEL_GLOBAL_RANKS = None

    # Virtual pipeline / split-rank settings.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

    # On-the-fly size/rank overrides.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None