parallel_state.py 22.3 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4
5

"""Model and data parallel groups."""

import torch
6
from typing import Optional
7

8
9
from .utils import GlobalMemoryBuffer

10
# Process-group handles. All of them stay None until
# initialize_model_parallel() has been called.
# Tensor (intra-layer) model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline (inter-layer) model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both tensor and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

# Virtual (interleaved) pipeline bookkeeping and the encoder/decoder split point.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly: when one of them
# is not None the matching getter returns it instead of asking torch.distributed.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of
# the source rank when broadcasting weights from src to all other data
# parallel ranks.
_DATA_PARALLEL_GLOBAL_RANKS = None

# Memory buffers to avoid dynamic memory allocation.
_GLOBAL_MEMORY_BUFFER = None
49

50

51
52
53
54
55
56
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
) -> None:
    """
    Initialize model and data parallel groups.

    Arguments:
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
            For models with both an encoder and decoder, the rank in
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.

    Raises:
        AssertionError: if torch.distributed is not initialized, or if one
            of the groups built here has already been initialized.
        RuntimeError: if the world size is not divisible by
            tensor_model_parallel_size x pipeline_model_parallel_size, or
            if an interleaved schedule is requested with a
            pipeline_model_parallel_size that is not greater than 2.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    global _DATA_PARALLEL_GROUP
    global _DATA_PARALLEL_GLOBAL_RANKS
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized(), \
        'torch.distributed is not initialized'
    world_size: int = torch.distributed.get_world_size()

    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
            f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    data_parallel_size: int = world_size // (tensor_model_parallel_size *
                                             pipeline_model_parallel_size)

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size is not None:
        if not pipeline_model_parallel_size > 2:
            raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
                               "interleaved schedule")
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size

    if pipeline_model_parallel_split_rank is not None:
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank

    rank = torch.distributed.get_rank()

    # NOTE: every rank executes every new_group() call below, in the same
    # order, including for groups it is not a member of. Do not reorder or
    # make these calls conditional on membership.

    # Build the data-parallel groups.
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        # Each pipeline stage owns a contiguous block of global ranks.
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            # Ranks sharing the same tensor-parallel offset within the stage.
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group
                _DATA_PARALLEL_GLOBAL_RANKS = ranks

    # Build the model-parallel groups.
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
    for i in range(data_parallel_size):
        # The i-th member of every data-parallel group forms one model replica.
        ranks = [data_parallel_group_ranks[i]
                 for data_parallel_group_ranks in all_data_parallel_group_ranks]
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
        'tensor model parallel group is already initialized'
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, \
        'pipeline model parallel group is already initialized'
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks

        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank is not None:
                # The first decoder stage joins both groups unless it is
                # already a member.
                split_global_rank = ranks[pipeline_model_parallel_split_rank]
                if split_global_rank not in embedding_ranks:
                    embedding_ranks = [ranks[0], split_global_rank, ranks[-1]]
                if split_global_rank not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0], split_global_rank]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            # Recorded for every rank of the pipeline, member or not.
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()

235
236
237

def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized.

    Returns True only when the tensor model parallel, pipeline model
    parallel and data parallel groups are all set.
    """
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)


def get_model_parallel_group():
    """Return the model parallel group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, 'model parallel group is not initialized'
    return group


252
253
254
def get_tensor_model_parallel_group():
    """Return the tensor model parallel group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _TENSOR_MODEL_PARALLEL_GROUP
    assert group is not None, 'intra_layer_model parallel group is not initialized'
    return group
257
258


259
260
261
262
263
def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, 'pipeline_model parallel group is not initialized'
    return group
264
265


266
267
268
269
270
271
272
def get_data_parallel_group():
    """Return the data parallel group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _DATA_PARALLEL_GROUP
    assert group is not None, 'data parallel group is not initialized'
    return group


273
274
275
276
277
278
279
def get_embedding_group():
    """Return the embedding group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _EMBEDDING_GROUP
    assert group is not None, 'embedding group is not initialized'
    return group


Vijay Korthikanti's avatar
Vijay Korthikanti committed
280
281
282
283
284
285
286
def get_position_embedding_group():
    """Return the position embedding group the caller rank belongs to.

    Raises AssertionError if the group has not been initialized.
    """
    group = _POSITION_EMBEDDING_GROUP
    assert group is not None, 'position embedding group is not initialized'
    return group


287
288
289
290
def set_tensor_model_parallel_world_size(world_size):
    """Override the tensor model parallel world size.

    While the stored value is not None,
    get_tensor_model_parallel_world_size() returns it instead of querying
    torch.distributed.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
291
292


293
294
295
296
def set_pipeline_model_parallel_world_size(world_size):
    """Override the pipeline model parallel world size.

    While the stored value is not None,
    get_pipeline_model_parallel_world_size() returns it instead of querying
    torch.distributed.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
297
298


299
300
301
302
303
304
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group.

    A value installed through set_tensor_model_parallel_world_size() takes
    precedence; otherwise the size of the tensor model parallel group is
    queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override
305
306


307
308
309
310
311
312
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group.

    A value installed through set_pipeline_model_parallel_world_size() takes
    precedence; otherwise the size of the pipeline model parallel group is
    queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
313
314


315
316
317
318
def set_tensor_model_parallel_rank(rank):
    """Override the tensor model parallel rank.

    While the stored value is not None, get_tensor_model_parallel_rank()
    returns it instead of querying torch.distributed.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
319
320


321
322
323
324
def set_pipeline_model_parallel_rank(rank):
    """Override the pipeline model parallel rank.

    While the stored value is not None, get_pipeline_model_parallel_rank()
    returns it instead of querying torch.distributed.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
325
326


327
328
def set_pipeline_model_parallel_split_rank(rank):
    """Set the pipeline model parallel split rank.

    The stored value is what is_pipeline_stage_before_split() and
    is_pipeline_stage_after_split() compare pipeline ranks against.
    """
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
331
332


333
334
335
336
337
338
def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group.

    A value installed through set_tensor_model_parallel_rank() takes
    precedence (including 0); otherwise the rank within the tensor model
    parallel group is queried from torch.distributed.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override
339
340


341
342
343
344
345
346
def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group.

    A value installed through set_pipeline_model_parallel_rank() takes
    precedence (including 0); otherwise the rank within the pipeline model
    parallel group is queried from torch.distributed.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override
347
348


349
350
351
352
353
def get_pipeline_model_parallel_split_rank():
    """Return the pipeline model parallel split rank (None when unset)."""
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank

354

355
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Unless ignore_virtual is set, a configured virtual pipeline must also be
    at virtual rank 0 for the stage to count as first.
    """
    if not ignore_virtual:
        on_later_virtual_stage = (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        )
        if on_later_virtual_stage:
            return False
    return get_pipeline_model_parallel_rank() == 0
362
363


364
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline model-parallel stage, False otherwise.

    Unless ignore_virtual is set, a configured virtual pipeline must also be
    at its last virtual rank for the stage to count as last.
    """
    if not ignore_virtual:
        virtual_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_size is not None:
            last_virtual_rank = virtual_size - 1
            if get_virtual_pipeline_model_parallel_rank() != last_virtual_rank:
                return False
    last_pipeline_rank = get_pipeline_model_parallel_world_size() - 1
    return get_pipeline_model_parallel_rank() == last_pipeline_rank
375
376


377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return True if current rank is in embedding group, False otherwise.

    With ignore_virtual set, only membership in the recorded embedding ranks
    is checked. Otherwise the first and last recorded ranks additionally
    have to be on the first / last pipeline stage (virtual stage included).
    """
    rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in embedding_ranks
    if rank not in embedding_ranks:
        return False
    if rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    # A middle entry of the recorded ranks always counts as a member.
    return True


Vijay Korthikanti's avatar
Vijay Korthikanti committed
393
394
395
396
397
398
399
def is_rank_in_position_embedding_group():
    """Return True if current rank is in position embedding group, False otherwise."""
    position_embedding_ranks = _POSITION_EMBEDDING_GLOBAL_RANKS
    return torch.distributed.get_rank() in position_embedding_ranks


400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.

    `rank` defaults to the caller's pipeline model parallel rank. The answer
    is always True when the pipeline world size is 1 or no split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank < split_rank


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    `rank` defaults to the caller's pipeline model parallel rank. The answer
    is always True when the pipeline world size is 1 or no split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or rank >= split_rank


def is_pipeline_stage_at_split():
    """Return True if this pipeline stage executes the encoder block and the
    next stage executes the decoder block, for a model with both encoder and
    decoder.

    Because both underlying checks return True when the pipeline world size
    is 1 or no split rank is set, this function also returns True in those
    cases.
    """
    rank = get_pipeline_model_parallel_rank()
    if not is_pipeline_stage_before_split(rank):
        return False
    return is_pipeline_stage_after_split(rank + 1)


439
440
441
442
443
444
445
446
447
448
449
450
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank (None when unset)."""
    virtual_rank = _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    return virtual_rank


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank.

    The stored value is what is_pipeline_first_stage() and
    is_pipeline_last_stage() consult when virtual stages are not ignored.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


451
452
453
454
455
456
def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size (None when unset)."""
    virtual_world_size = _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    return virtual_world_size


457
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    The global rank is rounded down to the nearest multiple of the tensor
    model parallel world size.
    """
    my_global_rank = torch.distributed.get_rank()
    tensor_parallel_size = get_tensor_model_parallel_world_size()
    return my_global_rank - my_global_rank % tensor_parallel_size

464

465
466
def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Raises AssertionError if the data parallel ranks have not been recorded.
    """
    data_parallel_ranks = _DATA_PARALLEL_GLOBAL_RANKS
    assert data_parallel_ranks is not None, "Data parallel group is not initialized"
    return data_parallel_ranks[0]
471
472


473
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group.

    Raises AssertionError if the pipeline ranks have not been recorded.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[0]

480

481
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group.

    Raises AssertionError if the pipeline ranks have not been recorded.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[get_pipeline_model_parallel_world_size() - 1]
488

489
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline.

    The pipeline wraps around: the last stage's successor is the first stage.
    Raises AssertionError if the pipeline ranks have not been recorded.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    position = get_pipeline_model_parallel_rank()
    pipeline_size = get_pipeline_model_parallel_world_size()
    successor = (position + 1) % pipeline_size
    return pipeline_ranks[successor]

497

498
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline.

    The pipeline wraps around: the first stage's predecessor is the last stage.
    Raises AssertionError if the pipeline ranks have not been recorded.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    position = get_pipeline_model_parallel_rank()
    pipeline_size = get_pipeline_model_parallel_world_size()
    predecessor = (position - 1) % pipeline_size
    return pipeline_ranks[predecessor]
505

506

507
508
509
510
511
512
513
514
515
def get_data_parallel_world_size():
    """Return world size for the data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=data_parallel_group)


def get_data_parallel_rank():
    """Return my rank for the data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_rank(group=data_parallel_group)

516
517
518
519
520
521
522
def _set_global_memory_buffer():
    """Create the module-wide GlobalMemoryBuffer.

    Raises AssertionError if a buffer has already been created.
    """
    global _GLOBAL_MEMORY_BUFFER
    already_set = _GLOBAL_MEMORY_BUFFER is not None
    assert not already_set, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object.

    Raises AssertionError if the buffer has not been initialized.
    """
    memory_buffer = _GLOBAL_MEMORY_BUFFER
    assert memory_buffer is not None, 'global memory buffer is not initialized'
    return memory_buffer


528
529
def destroy_model_parallel():
    """Reset all parallel state held by this module to None.

    Clears the process-group handles, the virtual pipeline state, the
    pipeline split rank, the on-the-fly world-size/rank overrides, the
    recorded global-rank lists and the global memory buffer. Only this
    module's references are dropped; nothing is torn down in
    torch.distributed here.
    """
    # Process-group handles.
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None

    # Virtual pipeline state and encoder/decoder split point. The split rank
    # must be cleared too, otherwise a later initialization without a split
    # rank would still see the previous value.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

    # On-the-fly overrides.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    # Recorded global-rank lists. Leaving these set would let the
    # "is not None" checks in the *_src_rank / first / last / next / prev
    # rank getters pass with stale data after teardown.
    global _EMBEDDING_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _DATA_PARALLEL_GLOBAL_RANKS
    _DATA_PARALLEL_GLOBAL_RANKS = None

    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None