parallel_state.py 27.8 KB
Newer Older
Masaki Kozuki's avatar
Masaki Kozuki committed
1
# coding=utf-8
2
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
Masaki Kozuki's avatar
Masaki Kozuki committed
3
4
5
6
7
8
9
10
11
12
13
14
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
15
16
# TODO (mkozuki): Replace assert with RuntimeError.
# TODO (mkozuki): Sort the functions in the same order as megatron/mpu/initialize.py
Masaki Kozuki's avatar
Masaki Kozuki committed
17
"""Model and data parallel groups."""
18
from typing import Tuple, Optional
19
import warnings
20

Masaki Kozuki's avatar
Masaki Kozuki committed
21
22
import torch

23
24
25
26
27
from apex.transformer.log_util import get_transformer_logger


_logger = get_transformer_logger(__name__)

# N.B. (mkozuki): Diff btwn Megatron-LM & apex parallel_state
# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
# {
#     'get_num_layers',
# }


# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Relative position embedding groups (encoder side and decoder side).
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None

# Virtual (interleaved) pipeline rank and world size. Set by
# `initialize_model_parallel` when a virtual pipeline size is given;
# the rank can also be changed through its setter.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
# Pipeline rank at which an encoder-decoder model switches from encoder to decoder.
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the relative position embedding.
_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None


def is_unitialized():
    """Return True while the data parallel group has not been created yet.

    Handy for code paths that can run both before and after model parallel
    initialization.
    """
    initialized = _DATA_PARALLEL_GROUP is not None
    return not initialized


def initialize_model_parallel(
    tensor_model_parallel_size_: int = 1,
    pipeline_model_parallel_size_: int = 1,
    virtual_pipeline_model_parallel_size_: Optional[int] = None,
    pipeline_model_parallel_split_rank_: Optional[int] = None,
    *,
    default_backend: Optional[str] = None,
    p2p_backend: Optional[str] = None,
) -> None:
    """
    Initialize model data parallel groups.

    Arguments:
        tensor_model_parallel_size_: number of GPUs used to parallelize model tensor.
            Clamped to the world size.
        pipeline_model_parallel_size_: number of GPUs used to parallelize model pipeline.
            Clamped to the world size.
        virtual_pipeline_model_parallel_size_: number of virtual stages (interleaved pipeline).
        pipeline_model_parallel_split_rank_: for models with both encoder and decoder, rank in pipeline with split point.

    Keyword Arguments:
        default_backend: Backend of process groups except for pipeline parallel ones
            (data, model, tensor, embedding, position embedding and relative
            position embedding groups).
            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.
        p2p_backend: Backend of process groups for pipeline model parallel.
            If :obj:`None`, the backend specified in `torch.distributed.init_process_group` will be used.

    Raises:
        RuntimeError: if the world size is not divisible by the product of the
            (clamped) tensor and pipeline model parallel sizes.

    .. note::
        `torch_ucc <https://github.com/facebookresearch/torch_ucc>`_ is
        necessary for "ucc" backend.

    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        8 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        4 pipeline model-parallel groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    global _DATA_PARALLEL_GROUP
    global _MODEL_PARALLEL_GROUP
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS
    global _EMBEDDING_GROUP
    global _EMBEDDING_GLOBAL_RANKS
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    assert default_backend is None or default_backend in ("nccl", "ucc")
    assert p2p_backend is None or p2p_backend in ("nccl", "ucc")
    if "ucc" in (default_backend, p2p_backend):
        check_torch_ucc_availability()
        warnings.warn("`ucc` backend support is experimental", ExperimentalWarning)
    if default_backend == "ucc":
        warnings.warn("The UCC's functionality as `default_backend` is not well verified", ExperimentalWarning)

    world_size: int = torch.distributed.get_world_size()
    tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
    pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
    if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
        raise RuntimeError(
            f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )
    data_parallel_size: int = world_size // (
        tensor_model_parallel_size * pipeline_model_parallel_size
    )
    rank = torch.distributed.get_rank()
    if rank == 0:
        _logger.info(
            "> initializing tensor model parallel with size {}".format(
                tensor_model_parallel_size
            )
        )
        _logger.info(
            "> initializing pipeline model parallel with size {}".format(
                pipeline_model_parallel_size
            )
        )
        _logger.info(
            "> initializing data parallel with size {}".format(data_parallel_size)
        )

    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    # Equals tensor_model_parallel_size * data_parallel_size: the number of
    # consecutive global ranks that make up one pipeline stage.
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size

    if virtual_pipeline_model_parallel_size_ is not None:
        # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit
        # the root cause as we do see numerical mismatches with 2 stages and
        # the interleaved schedule
        assert pipeline_model_parallel_size_ > 2, (
            "pipeline-model-parallel size should be greater than 2 with "
            "interleaved schedule"
        )
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
            virtual_pipeline_model_parallel_size_
        )

    if pipeline_model_parallel_split_rank_ is not None:
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank_

    # NOTE: `torch.distributed.new_group` must be called by every process for
    # every group, in the same order, even for groups the process is not in.

    # Build the data-parallel groups.
    assert _DATA_PARALLEL_GROUP is None, "data parallel group is already initialized"
    all_data_parallel_group_ranks = []
    for i in range(pipeline_model_parallel_size):
        start_rank = i * num_pipeline_model_parallel_groups
        end_rank = (i + 1) * num_pipeline_model_parallel_groups
        for j in range(tensor_model_parallel_size):
            ranks = range(start_rank + j, end_rank, tensor_model_parallel_size)
            all_data_parallel_group_ranks.append(list(ranks))
            group = torch.distributed.new_group(ranks, backend=default_backend)
            if rank in ranks:
                _DATA_PARALLEL_GROUP = group

    # Build the model-parallel groups.
    assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
    for i in range(data_parallel_size):
        ranks = [
            data_parallel_group_ranks[i]
            for data_parallel_group_ranks in all_data_parallel_group_ranks
        ]
        group = torch.distributed.new_group(ranks, backend=default_backend)
        if rank in ranks:
            _MODEL_PARALLEL_GROUP = group

    # Build the tensor model-parallel groups.
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), "tensor model parallel group is already initialized"
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
        )
        group = torch.distributed.new_group(ranks, backend=default_backend)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), "pipeline model parallel group is already initialized"
    assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
    assert (
        _POSITION_EMBEDDING_GROUP is None
    ), "position embedding group is already initialized"
    # Fail if either relative position embedding group already exists.
    assert _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP is None and \
           _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP is None, \
        'relative position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks, backend=p2p_backend)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks
        # Setup embedding group (to exchange gradients between
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            encoder_relative_position_embedding_ranks = [ranks[0]]
            decoder_relative_position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                encoder_relative_position_embedding_ranks = \
                    ranks[:pipeline_model_parallel_split_rank_]
                decoder_relative_position_embedding_ranks = \
                    ranks[pipeline_model_parallel_split_rank_:]
                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                    embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank_],
                        ranks[-1],
                    ]
                if (
                    ranks[pipeline_model_parallel_split_rank_]
                    not in position_embedding_ranks
                ):
                    position_embedding_ranks = [
                        ranks[0],
                        ranks[pipeline_model_parallel_split_rank_],
                    ]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks
            encoder_relative_position_embedding_ranks = ranks
            decoder_relative_position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks, backend=default_backend)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks, backend=default_backend)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks

        # The encoder slice is empty when the split rank is 0; no group is built then.
        # Like every non-pipeline group, these honour `default_backend`.
        if encoder_relative_position_embedding_ranks:
            group = torch.distributed.new_group(
                encoder_relative_position_embedding_ranks, backend=default_backend
            )
            if rank in encoder_relative_position_embedding_ranks:
                _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
                encoder_relative_position_embedding_ranks

        if decoder_relative_position_embedding_ranks:
            group = torch.distributed.new_group(
                decoder_relative_position_embedding_ranks, backend=default_backend
            )
            if rank in decoder_relative_position_embedding_ranks:
                _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
                decoder_relative_position_embedding_ranks
312

313
def get_rank_info() -> Tuple[int, int, int, Optional[int]]:
    """Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger.

    Returns ``(0, 0, 0, 0)`` when the model parallel groups are not initialized.
    Otherwise the last element comes from ``get_virtual_pipeline_model_parallel_rank()``
    and may be ``None`` if no virtual pipeline rank has been set.
    """
    if model_parallel_is_initialized():
        return (
            get_data_parallel_rank(),
            get_tensor_model_parallel_rank(),
            get_pipeline_model_parallel_rank(),
            get_virtual_pipeline_model_parallel_rank(),
        )
    return (0, 0, 0, 0)
323
324


Masaki Kozuki's avatar
Masaki Kozuki committed
325
326
def model_parallel_is_initialized():
    """Report whether the tensor, pipeline and data parallel groups all exist."""
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)


def get_model_parallel_group():
    """Return the combined (tensor + pipeline) model parallel group of this rank."""
    group = _MODEL_PARALLEL_GROUP
    assert group is not None, "model parallel group is not initialized"
    return group


def get_tensor_model_parallel_group():
    """Return the tensor (intra-layer) model parallel group of this rank."""
    group = _TENSOR_MODEL_PARALLEL_GROUP
    assert group is not None, "intra_layer_model parallel group is not initialized"
    return group


def get_pipeline_model_parallel_group():
    """Return the pipeline (inter-layer) model parallel group of this rank."""
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, "pipeline_model parallel group is not initialized"
    return group


def get_data_parallel_group():
    """Return the data parallel group of this rank."""
    group = _DATA_PARALLEL_GROUP
    assert group is not None, "data parallel group is not initialized"
    return group


def get_embedding_group():
    """Return the word-embedding group this rank was assigned to."""
    group = _EMBEDDING_GROUP
    assert group is not None, "embedding group is not initialized"
    return group


370
371
372
373
374
375
376
def get_position_embedding_group():
    """Return the position-embedding group this rank was assigned to."""
    group = _POSITION_EMBEDDING_GROUP
    assert group is not None, "position embedding group is not initialized"
    return group

Perkz Zheng's avatar
Perkz Zheng committed
377
378
379
380
381
382
383
384
385
386
387
def get_encoder_relative_position_embedding_group():
    """Return the encoder-side relative position embedding group of this rank."""
    group = _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
    assert group is not None, 'encoder relative position embedding group is not initialized'
    return group

def get_decoder_relative_position_embedding_group():
    """Return the decoder-side relative position embedding group of this rank."""
    group = _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
    assert group is not None, 'decoder relative position embedding group is not initialized'
    return group
388

eqy's avatar
eqy committed
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return True if this process's global rank takes part in the embedding group.

    With ``ignore_virtual=False`` the first and last embedding ranks only count
    when they are currently on the first / last (virtual) pipeline stage;
    any rank in between (the encoder-decoder split rank) always counts.
    """
    rank = torch.distributed.get_rank()
    embedding_ranks = _EMBEDDING_GLOBAL_RANKS
    if ignore_virtual:
        return rank in embedding_ranks
    if rank not in embedding_ranks:
        return False
    if rank == embedding_ranks[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == embedding_ranks[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True


405
406
407
408
409
410
def is_rank_in_position_embedding_group():
    """Return True if this process's global rank holds a copy of the position embedding."""
    return torch.distributed.get_rank() in _POSITION_EMBEDDING_GLOBAL_RANKS

Perkz Zheng's avatar
Perkz Zheng committed
411
412
413
414
415
416
417
418
419
420
421
def is_rank_in_encoder_relative_position_embedding_group():
    """Return True if this global rank is listed in the encoder relative position embedding ranks."""
    return torch.distributed.get_rank() in _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS

def is_rank_in_decoder_relative_position_embedding_group():
    """Return True if this global rank is listed in the decoder relative position embedding ranks."""
    return torch.distributed.get_rank() in _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
422

423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
def is_pipeline_stage_before_split(rank=None):
    """Return True if the given (default: current) pipeline stage runs the encoder.

    Always True for a single-stage pipeline or when no split rank is configured;
    otherwise True for stages strictly below the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or stage < split_rank


def is_pipeline_stage_after_split(rank=None):
    """Return True if the given (default: current) pipeline stage runs the decoder.

    Always True for a single-stage pipeline or when no split rank is configured;
    otherwise True for stages at or above the split rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or stage >= split_rank


def is_pipeline_stage_at_split():
    """Return True if the current pipeline stage executes the encoder block and
    the next stage executes the decoder block, for a model with both encoder
    and decoder (i.e. the current stage is the last encoder stage).

    Note: when the pipeline has a single stage or no split rank is configured,
    both checks are unconditionally True, so this also returns True.
    """
    rank = get_pipeline_model_parallel_rank()
    # Current stage is before the split (encoder) and the next one is at/after it (decoder).
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(
        rank + 1
    )
461
462


Masaki Kozuki's avatar
Masaki Kozuki committed
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
def set_tensor_model_parallel_world_size(world_size):
    """Override the value reported by ``get_tensor_model_parallel_world_size``.

    Passing ``None`` drops the override so the size is read from the process group again.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size


def set_pipeline_model_parallel_world_size(world_size):
    """Override the value reported by ``get_pipeline_model_parallel_world_size``.

    Passing ``None`` drops the override so the size is read from the process group again.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_tensor_model_parallel_world_size():
    """Return the tensor model parallel world size.

    A value installed with ``set_tensor_model_parallel_world_size`` wins;
    otherwise the size of the tensor model parallel process group is used.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
    return override


def get_pipeline_model_parallel_world_size():
    """Return the pipeline model parallel world size.

    A value installed with ``set_pipeline_model_parallel_world_size`` wins;
    otherwise the size of the pipeline model parallel process group is used.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    if override is None:
        return torch.distributed.get_world_size(group=get_pipeline_model_parallel_group())
    return override


def set_tensor_model_parallel_rank(rank):
    """Override the value reported by ``get_tensor_model_parallel_rank`` (``None`` clears it)."""
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank


def set_pipeline_model_parallel_rank(rank):
    """Override the value reported by ``get_pipeline_model_parallel_rank`` (``None`` clears it)."""
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_tensor_model_parallel_rank():
    """Return this process's rank inside its tensor model parallel group.

    A value installed with ``set_tensor_model_parallel_rank`` takes precedence.
    """
    override = _MPU_TENSOR_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
    return override


def get_pipeline_model_parallel_rank():
    """Return this process's rank inside its pipeline model parallel group.

    A value installed with ``set_pipeline_model_parallel_rank`` takes precedence.
    """
    override = _MPU_PIPELINE_MODEL_PARALLEL_RANK
    if override is None:
        return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
    return override


519
520
521
# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?


522
523
524
525
526
527
def get_pipeline_model_parallel_split_rank():
    """Return the pipeline rank where an encoder-decoder model splits, or None if unset."""
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK


528
529
530
531
532
533
def set_pipeline_model_parallel_split_rank(pipeline_model_parallel_split_rank: int):
    """Record the pipeline rank where an encoder-decoder model splits."""
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank


Masaki Kozuki's avatar
Masaki Kozuki committed
534
535
536
537
538
539
540
541
542
543
544
545
546
547
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if this rank is on the first pipeline stage.

    Unless ``ignore_virtual`` is set, a rank using an interleaved (virtual)
    schedule only counts when its virtual pipeline rank is 0.
    """
    if not ignore_virtual:
        virtual_world_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_world_size is not None and get_virtual_pipeline_model_parallel_rank() != 0:
            return False
    return get_pipeline_model_parallel_rank() == 0


def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if this rank is on the last pipeline stage.

    Unless ``ignore_virtual`` is set, a rank using an interleaved (virtual)
    schedule only counts when its virtual pipeline rank is the last one.
    """
    if not ignore_virtual:
        virtual_world_size = get_virtual_pipeline_model_parallel_world_size()
        if virtual_world_size is not None:
            last_virtual_rank = virtual_world_size - 1
            if get_virtual_pipeline_model_parallel_rank() != last_virtual_rank:
                return False
    current_stage = get_pipeline_model_parallel_rank()
    return current_stage == get_pipeline_model_parallel_world_size() - 1
Masaki Kozuki's avatar
Masaki Kozuki committed
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585


def get_virtual_pipeline_model_parallel_rank():
    """Return the current virtual (interleaved) pipeline rank, or None if unset."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Store ``rank`` as the current virtual (interleaved) pipeline rank."""
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


def get_virtual_pipeline_model_parallel_world_size():
    """Return the number of virtual pipeline stages, or None without interleaving."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


def get_tensor_model_parallel_src_rank():
    """Return the global rank of local rank 0 in this rank's tensor model parallel group.

    Tensor model parallel groups are blocks of consecutive global ranks, so the
    source is the global rank rounded down to a multiple of the group size.
    """
    global_rank = torch.distributed.get_rank()
    group_size = get_tensor_model_parallel_world_size()
    return global_rank - global_rank % group_size


586
587
588
589
590
591
592
593
def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank in the data parallel group.

    ``initialize_model_parallel`` gives every pipeline stage a contiguous block of
    ``tensor_model_parallel_size * data_parallel_size`` global ranks. Inside that block,
    a data parallel group is every ``tensor_model_parallel_size``-th rank starting at the
    rank's tensor parallel offset. The first member is therefore the start of the
    block plus that offset.

    Example (16 GPUs, tensor=2, pipeline=4, data=2): rank 6 belongs to data parallel
    group ``[4, 6]``, so this returns 4. The previous formula
    ``rank % (world_size // data_parallel_size)`` returned 6 here.
    """
    global_rank = torch.distributed.get_rank()
    data_parallel_size: int = get_data_parallel_world_size()
    # Use the real process-group size (not the optional on-the-fly override),
    # since the answer depends on how the groups were physically laid out.
    tensor_model_parallel_size: int = torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group()
    )
    ranks_per_pipeline_stage = tensor_model_parallel_size * data_parallel_size
    stage_start = (global_rank // ranks_per_pipeline_stage) * ranks_per_pipeline_stage
    return stage_start + global_rank % tensor_model_parallel_size


Masaki Kozuki's avatar
Masaki Kozuki committed
594
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage of this rank's pipeline group."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    return pipeline_ranks[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage of this rank's pipeline group."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    final_stage = get_pipeline_model_parallel_world_size() - 1
    return pipeline_ranks[final_stage]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the following pipeline stage (wrapping to the first)."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    return pipeline_ranks[(stage + 1) % num_stages]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the preceding pipeline stage (wrapping to the last)."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, "Pipeline parallel group is not initialized"
    stage = get_pipeline_model_parallel_rank()
    num_stages = get_pipeline_model_parallel_world_size()
    return pipeline_ranks[(stage - 1) % num_stages]


def get_data_parallel_world_size():
    """Return how many ranks are in this rank's data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_world_size(group=data_parallel_group)


def get_data_parallel_rank():
    """Return this process's rank inside its data parallel group."""
    data_parallel_group = get_data_parallel_group()
    return torch.distributed.get_rank(group=data_parallel_group)


637
638
639
# note (mkozuki): `destroy_model_parallel` resets more global variables than Megatron-LM does.
# Otherwise the pipeline parallel forward_backward function tests hang, possibly because
# the original clean-up is NOT sufficient.
Masaki Kozuki's avatar
Masaki Kozuki committed
640
641
642
643
644
645
646
647
648
649
650
651
def destroy_model_parallel():
    """Reset every piece of module-level parallel state to ``None``.

    This covers the process groups, the virtual pipeline settings, the on-the-fly
    size/rank overrides, the per-group global-rank lists and the pipeline split rank.
    Clearing the split rank matters because ``initialize_model_parallel`` only
    writes it when a split rank is passed. Without this reset, a re-initialization
    without a split rank would silently keep the previous one.
    """
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None
    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None
    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None
    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP
    _ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
    global _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP
    _DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
    # Global-rank bookkeeping filled in by `initialize_model_parallel`.
    global _EMBEDDING_GLOBAL_RANKS
    _EMBEDDING_GLOBAL_RANKS = None
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    _POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
    _ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS
    _DECODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = None
    global _PIPELINE_GLOBAL_RANKS
    _PIPELINE_GLOBAL_RANKS = None
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
670
671
672
673
674
675
676
677
678
679
680
681
682


# Warning category emitted when the experimental UCC backend is requested.
class ExperimentalWarning(Warning):
    """Category for warnings about experimental features such as the UCC backend."""


def check_torch_ucc_availability() -> None:
    """Ensure the ``torch_ucc`` package required by the "ucc" backend is importable.

    Raises:
        ImportError: if ``torch_ucc`` cannot be imported. The original import
            failure is chained as ``__cause__`` so its details are not lost.
    """
    try:
        import torch_ucc  # NOQA
    except ImportError as err:
        raise ImportError(
            "UCC backend requires [torch_ucc](https://github.com/facebookresearch/torch_ucc) but not found"
        ) from err