parallel_state.py 64.9 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Model and data parallel groups."""

liangjing's avatar
v1  
liangjing committed
5
import os
liangjing's avatar
liangjing committed
6
7
8
9
10
import warnings
from datetime import timedelta
from functools import partial
from itertools import cycle
from typing import Callable, List, Optional
11

liangjing's avatar
v1  
liangjing committed
12
13
import torch

14
15
from .utils import GlobalMemoryBuffer

16
# Intra-layer model parallel group that the current rank belongs to.
17
_TENSOR_MODEL_PARALLEL_GROUP = None
18
# Inter-layer model parallel group that the current rank belongs to.
19
20
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
21
_MODEL_PARALLEL_GROUP = None
liangjing's avatar
liangjing committed
22
23
# Model parallel group (intra-layer, pipeline, and expert) that the current rank belongs to.
_MODEL_AND_EXPERT_PARALLEL_GROUP = None
24
25
# Embedding group.
_EMBEDDING_GROUP = None
Vijay Korthikanti's avatar
Vijay Korthikanti committed
26
# Position embedding group.
Vijay Korthikanti's avatar
Vijay Korthikanti committed
27
_POSITION_EMBEDDING_GROUP = None
28
29
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
liangjing's avatar
v1  
liangjing committed
30
_DATA_PARALLEL_GROUP_GLOO = None
liangjing's avatar
liangjing committed
31
32
33
34
35
36
37
38
39
40
41
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None
# Expert parallel group that the current rank belongs to.
_EXPERT_MODEL_PARALLEL_GROUP = None
_TENSOR_AND_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None
_DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None

42

43
44
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
45
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None
46

liangjing's avatar
liangjing committed
47
48
_PIPELINE_MODEL_PARALLEL_DECODER_START = None

Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
49
# These values enable us to change the mpu sizes on the fly.
50
51
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
liangjing's avatar
liangjing committed
52
53
54
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_DATA_PARALLEL_WORLD_SIZE = None
_MPU_DATA_PARALLEL_RANK = None
55
56
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None
liangjing's avatar
liangjing committed
57
_MPU_EXPERT_MODEL_PARALLEL_RANK = None
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
58

59
60
61
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

Vijay Korthikanti's avatar
Vijay Korthikanti committed
62
63
64
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

65
# A list of global ranks for each pipeline group to ease calculation of the source
66
# rank when broadcasting from the first or last pipeline stage.
67
_PIPELINE_GLOBAL_RANKS = None
68

69
70
71
72
# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

liangjing's avatar
liangjing committed
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# combined parallel group of TP and CP
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

94
95
# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None
96

liangjing's avatar
liangjing committed
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
# MOE logging
_MOE_LAYER_WISE_LOGGING_TRACKER = {}


def get_nccl_options(pg_name, nccl_comm_cfgs):
    """Build the NCCL process group options for the named process group.

    Args:
        pg_name (str): process group name, used as the key into ``nccl_comm_cfgs``.
        nccl_comm_cfgs (dict): nccl communicator configurations, keyed by process
            group name. An empty dict or ``None`` means that no custom
            configuration is available for any process group.

    Returns:
        A ``torch.distributed.ProcessGroupNCCL.Options`` object when ``pg_name``
        has an entry in ``nccl_comm_cfgs``, otherwise ``None``.

    When an option (e.g., max_ctas) is not found in the entry for ``pg_name``,
    the fallback value hard-coded below is used instead.
    """
    # Guard against a missing configuration (e.g. a config file that parsed to
    # None): `pg_name in None` would raise a TypeError instead of simply
    # meaning "no overrides for this group".
    if not nccl_comm_cfgs or pg_name not in nccl_comm_cfgs:
        return None

    pg_cfg = nccl_comm_cfgs[pg_name]
    nccl_options = torch.distributed.ProcessGroupNCCL.Options()
    nccl_options.config.cga_cluster_size = pg_cfg.get('cga_cluster_size', 4)
    nccl_options.config.max_ctas = pg_cfg.get('max_ctas', 32)
    nccl_options.config.min_ctas = pg_cfg.get('min_ctas', 1)
    return nccl_options


def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    r"""Generate orthogonal parallel groups based on the parallel size and mask.

    Arguments:
        world_size (int): world size

        parallel_size (List[int]):
            The parallel size of each orthogonal parallel type. For example, if
            tensor_parallel_size = 2, pipeline_model_parallel_size = 3, data_parallel_size = 4,
            and the parallel mapping order is tp-pp-dp, then the parallel_size = [2, 3, 4].

        mask (List[bool]):
            The mask controls which parallel methods the generated groups represent. If mask[i] is
            True, it means the generated group contains the i-th parallelism method. For example,
            if parallel_size = [tp_size, pp_size, dp_size], and mask = [True, False , True], then
            the generated group is the `tp-dp` group, if the mask = [False, True, False], then the
            generated group is the `pp` group.

    Returns:
        A list with world_size // group_size entries, where group_size is the
        product of the masked sizes. Each entry is the list of global ranks
        (group_size of them) that form one group.

    Algorithm:
        For orthogonal parallelism, such as tp/dp/pp/cp, the global_rank and
        local_rank satisfy the following equation:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size (1)
                tp_rank \in [0, tp_size)
                dp_rank \in [0, dp_size)
                pp_rank \in [0, pp_size)

        If we want to get the `dp_group` (tp_size * pp_size groups of dp_size ranks each.
        For example,  if the gpu size is 8 and order is 'tp-pp-dp', size is '2-2-2', and the
        dp_group here is [[0, 4], [1, 5], [2, 6], [3, 7]].)
        The tp_rank and pp_rank will be combined to form the `dp_group_index`.
            dp_group_index = tp_rank + pp_rank * tp_size (2)

        So, Given that tp_rank and pp_rank satisfy equation (2), and dp_rank in
        range(0, dp_size), the ranks in dp_group[dp_group_index] satisfies the
        equation (1).

        This function solves this math problem.

    For example, if the parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4],
    and the mask = [False, True, False]. Then,
        dp_group_index(0) = tp_rank(0) + pp_rank(0) * 2
        dp_group_index(1) = tp_rank(1) + pp_rank(0) * 2
        ...
        dp_group_index(7) = tp_rank(1) + pp_rank(3) * 2

        dp_group[0] = 0 + range(0, 3) * 2 + 0 = [0, 2, 4]
        dp_group[1] = 1 + range(0, 3) * 2 + 0 = [1, 3, 5]
        ...
        dp_group[7] = 1 + range(0, 3) * 2 + 3 * 2 * 3 = [19, 21, 23]
    """

    def prefix_product(a: List[int], init=1) -> List[int]:
        """Return the running products [init, init*a[0], init*a[0]*a[1], ...]."""
        r = [init]
        for v in a:
            init = init * v
            r.append(init)
        return r

    def inner_product(a: List[int], b: List[int]) -> int:
        """Return sum(a[i] * b[i]) over the paired elements of a and b."""
        return sum(x * y for x, y in zip(a, b))

    def decompose(index, shape, stride=None):
        """
        This function solves the math problem below:
            There is an equation:
                index = sum(idx[i] * stride[i])
            And given the value of index, stride.
            Return the idx.
        This function is used to get the tp/dp/pp_rank
        from group_index and rank_in_group.
        """
        if stride is None:
            stride = prefix_product(shape)
        idx = [(index // d) % s for s, d in zip(shape, stride)]
        # stride is a prefix_product result. And the value of stride[-1]
        # is not used.
        assert (
            sum([x * y for x, y in zip(idx, stride[:-1])]) == index
        ), "idx {} with shape {} mismatch the return idx {}".format(index, shape, idx)
        return idx

    masked_shape = [s for s, m in zip(parallel_size, mask) if m]
    unmasked_shape = [s for s, m in zip(parallel_size, mask) if not m]

    # global_stride[i] is the distance in global rank between neighbours along
    # dimension i; zip() below drops the trailing total-product entry.
    global_stride = prefix_product(parallel_size)
    masked_stride = [d for d, m in zip(global_stride, mask) if m]
    unmasked_stride = [d for d, m in zip(global_stride, mask) if not m]

    group_size = prefix_product(masked_shape)[-1]
    num_of_group = world_size // group_size

    # The contribution of the position inside a group to the global rank does
    # not depend on which group it is, so compute it once for all groups.
    in_group_offsets = [
        inner_product(decompose(rank_in_group, masked_shape), masked_stride)
        for rank_in_group in range(group_size)
    ]

    ranks = []
    for group_index in range(num_of_group):
        # get indices from unmasked for group_index.
        decomposed_group_idx = decompose(group_index, unmasked_shape)
        group_base = inner_product(decomposed_group_idx, unmasked_stride)
        ranks.append([offset + group_base for offset in in_group_offsets])
    return ranks


class RankGenerator(object):
    """A class for generating rank groups for different modes of parallelism.

    The constructor records the size of every parallelism type (tp, ep, dp, pp,
    cp) together with the order in which they are laid out, and pre-computes
    two views of that layout:

    * ``order_w_ep`` / ``ordered_size_w_ep``: ep is its own dimension and the
      dp dimension has size ``dp // ep``.
    * ``order_wo_ep`` / ``ordered_size_wo_ep``: the ep dimension is dropped and
      the dp dimension keeps its full size ``dp``.
    """

    def __init__(
        self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0
    ) -> None:
        """Store the parallel sizes and derive the ordered size lists.

        Args:
            tp (int): tensor parallel size.
            ep (int): expert parallel size.
            dp (int): data parallel size.
            pp (int): pipeline parallel size.
            cp (int): context parallel size.
            order (str): hyphen-separated parallelism names, e.g. 'tp-dp-pp'.
                Names that are left out are appended automatically as long as
                their size is 1.
            rank_offset (int): value added to every rank returned by
                ``get_ranks`` when it is greater than 0.

        Raises:
            RuntimeError: if 'ep' appears in ``order`` without being adjacent
                to 'dp', or if a parallelism with size != 1 is missing from
                ``order``.
        """
        self.tp = tp
        self.ep = ep
        self.dp = dp
        self.pp = pp
        self.cp = cp
        self.rank_offset = rank_offset
        # ep is not a factor here: the ep dimension is carved out of dp
        # (see the `dp // ep` split below).
        self.world_size = tp * dp * pp * cp

        self.name_to_size = {
            "tp": self.tp,
            "pp": self.pp,
            "dp": self.dp,
            "ep": self.ep,
            "cp": self.cp,
        }
        self.order = order
        order = order.lower()

        if 'ep' in order:
            if 'ep-dp' not in order and 'dp-ep' not in order:
                raise RuntimeError(f"The ep and dp must be adjacent in order ({self.order}).")

        for name, size in self.name_to_size.items():
            if name in order:
                continue
            if size != 1:
                raise RuntimeError(
                    f"The size of ({name}) is ({size}), but you haven't "
                    f"specified the order ({self.order})."
                )
            # A size-1 dimension does not change the rank layout, so it can be
            # appended at the end of the order.
            order = order + '-' + name

        self.order_w_ep = order
        self.order_wo_ep = '-'.join([token for token in order.split('-') if token != 'ep'])
        self.ordered_size_wo_ep = []
        self.ordered_size_w_ep = []

        for token in order.split('-'):
            if token == 'dp':
                self.ordered_size_w_ep.append(self.dp // self.ep)
                self.ordered_size_wo_ep.append(self.dp)
            elif token == 'ep':
                self.ordered_size_w_ep.append(self.ep)
            else:
                self.ordered_size_w_ep.append(self.name_to_size[token])
                self.ordered_size_wo_ep.append(self.name_to_size[token])

    def get_mask(self, order: str, token: str):
        """Create a mask for the specified tokens based on the given order.

        Args:
            order (str): The order of parallelism types (e.g., 'tp-dp-pp').
            token (str): The specific parallelism types to include in the mask,
                         separated by hyphens (e.g., 'tp-dp').

        Returns:
            A list of booleans, one per entry of ``order``, that is True at the
            positions of the parallelism types named in ``token``.

        Raises:
            ValueError: if a type named in ``token`` does not appear in ``order``.
        """
        ordered_token = order.split('-')
        token = token.split('-')
        mask = [False] * len(ordered_token)
        for t in token:
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token, independent_ep=False):
        """Get rank group by input token.

        Args:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

            independent_ep (bool, default = False):
                This flag controls whether we treat EP and DP independently.
                EP shares ranks with DP, if we want to get ranks related to
                EP, we should set the flag. For example, get_ranks('dp', True)
                will get DP modulo EP group, and get_ranks('dp', False) will
                get full DP group.

        Returns:
            A list of rank groups (each a list of ints), with ``rank_offset``
            added to every rank when it is greater than 0.
        """
        if independent_ep:
            parallel_size = self.ordered_size_w_ep
            order = self.order_w_ep
        else:
            parallel_size = self.ordered_size_wo_ep
            order = self.order_wo_ep
        mask = self.get_mask(order, token)
        ranks = generate_masked_orthogonal_rank_groups(self.world_size, parallel_size, mask)
        if self.rank_offset > 0:
            ranks = [[r + self.rank_offset for r in rank_group] for rank_group in ranks]
        return ranks


def default_embedding_ranks(pp_ranks, split_rank=None):
    """Pick the pipeline ranks that hold the word embeddings.

    Args:
        pp_ranks: the global ranks of one pipeline group, in stage order.
        split_rank: deprecated, optional index into ``pp_ranks``; kept for
            backwards compatibility.

    Returns:
        ``[first]`` for a single-stage pipeline; otherwise ``[first, last]``,
        with ``pp_ranks[split_rank]`` inserted in between when ``split_rank``
        is given and names a rank other than the first or last one.
    """
    head = pp_ranks[0]
    if len(pp_ranks) == 1:
        return [head]

    tail = pp_ranks[-1]
    if split_rank is None:
        return [head, tail]

    split = pp_ranks[split_rank]
    # A split rank that coincides with either end adds nothing new.
    if split in (head, tail):
        return [head, tail]
    return [head, split, tail]


def default_position_embedding_ranks(pp_ranks, split_rank=None):
    """Pick the pipeline ranks that hold the position embeddings.

    Args:
        pp_ranks: the global ranks of one pipeline group, in stage order.
        split_rank: deprecated, optional index into ``pp_ranks``; kept for
            backwards compatibility.

    Returns:
        ``[first]``, followed by ``pp_ranks[split_rank]`` when ``split_rank``
        is given and names a rank different from the first one.
    """
    stages = [pp_ranks[0]]
    if split_rank is not None:
        split_stage = pp_ranks[split_rank]
        if stages[0] != split_stage:
            stages.append(split_stage)
    return stages

351

352
353
354
355
356
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
liangjing's avatar
v1  
liangjing committed
357
    use_sharp: bool = False,
liangjing's avatar
liangjing committed
358
359
360
361
362
363
364
365
366
    context_parallel_size: int = 1,
    expert_model_parallel_size: int = 1,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
    encoder_tensor_model_parallel_size: Optional[int] = 0,
    encoder_pipeline_model_parallel_size: Optional[int] = 0,
    get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
367
) -> None:
liangjing's avatar
liangjing committed
368
    # pylint: disable=line-too-long
liangjing's avatar
v1  
liangjing committed
369
    """Initialize model data parallel groups.
370

liangjing's avatar
liangjing committed
371
    Args:
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
liangjing's avatar
liangjing committed
398
            DEPRECATED. For models with both an encoder and decoder, the rank in
399
400
401
402
403
404
405
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.
406

liangjing's avatar
v1  
liangjing committed
407
408
409
410
411
412
        use_sharp (bool, default = False):
            Set the use of SHARP for the collective communications of
            data-parallel process groups. When `True`, run barrier
            within each data-parallel process group, which specifies
            the SHARP application target groups.

liangjing's avatar
liangjing committed
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
        context_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            network input sequence length across. Compute of attention
            module requires tokens of full sequence length, so GPUs
            in a context parallel group need to communicate with each
            other to exchange information of other sequence chunks.
            Each GPU and its counterparts in other tensor parallel
            groups compose a context parallel group.

            For example, assume we have 8 GPUs, if tensor model parallel
            size is 4 and context parallel size is 2, the network input
            will be split into two sequence chunks, which are processed
            by 2 different groups of 4 GPUs. One chunk is processed by
            GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are build to do context parallel communications: [GPU0, GPU4],
            [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

            Context parallelism partitions sequence length, so it has no
            impact on weights, which means weights are duplicated among
            GPUs in a context parallel group. Hence, weight gradients
            all-reduce is required in backward. For simplicity, we piggyback
            GPUs of context parallelism on data parallel group for
            weight gradient all-reduce.

        expert_model_parallel_size (int, default = 1):
            The number of Mixture of Experts parallel GPUs in each expert
            parallel group.

        nccl_communicator_config_path (str, default = None):
            Path to the yaml file of NCCL communicator configurations.
            `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
            for each communicator.

        distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes,for operations executed against distributed
            process groups. See PyTorch documentation at
            https://pytorch.org/docs/stable/distributed.html for
            caveats.

        order (str, default=tp-dp-pp):
            The rank initialization order of parallelism. Now we support
            tp-dp-pp and tp-pp-dp orders.

        encoder_tensor_model_parallel_size (int, default = 0):
            The number of GPUs to split individual tensors across in the encoder. If 0,
            then we use the default, decoder's tensor model parallel size.

        encoder_pipeline_model_parallel_size (int, default = 0):
            The number of tensor parallel GPU groups to allocate to the encoder. As an example,
            if pipeline_model_parallel_size is 4 and encoder_pipeline_model_parallel_size is 2,
            then the encoder will use the first two pipeline stages for its layers, and the total
            amount of pipelineing is 6.

        get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group and returns
            those ranks that should have embeddings.

        get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group, and returns
            those ranks that should have position embeddings.

474
    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
475
476
477
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
478
479
480
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
481
        8 tensor model-parallel groups:
482
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
483
        4 pipeline model-parallel groups:
484
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
485
486
487
488
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
liangjing's avatar
v1  
liangjing committed
489

490
    """
liangjing's avatar
liangjing committed
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
    if encoder_pipeline_model_parallel_size is None:
        encoder_pipeline_model_parallel_size = 0

    if encoder_tensor_model_parallel_size == 0 and encoder_pipeline_model_parallel_size > 0:
        encoder_tensor_model_parallel_size = tensor_model_parallel_size

    if get_embedding_ranks is None:
        get_embedding_ranks = partial(
            default_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if get_position_embedding_ranks is None:
        get_position_embedding_ranks = partial(
            default_position_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if encoder_pipeline_model_parallel_size > 0:
        global _PIPELINE_MODEL_PARALLEL_DECODER_START
        _PIPELINE_MODEL_PARALLEL_DECODER_START = encoder_pipeline_model_parallel_size

511
512
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
513
514
    world_size: int = torch.distributed.get_world_size()

liangjing's avatar
liangjing committed
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
    if encoder_tensor_model_parallel_size > 0:
        assert encoder_pipeline_model_parallel_size > 0
        assert (
            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
        ), "We do not support encoders with more TP than the decoder."

    encoder_model_size = (
        encoder_tensor_model_parallel_size
        * encoder_pipeline_model_parallel_size
        * context_parallel_size
    )
    decoder_model_size = (
        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    )
    total_model_size = encoder_model_size + decoder_model_size

    if world_size % total_model_size != 0:
        raise RuntimeError(f"world_size ({world_size}) is not divisible by {total_model_size}")

    data_parallel_size: int = world_size // total_model_size

    if data_parallel_size % expert_model_parallel_size != 0:
537
        raise RuntimeError(
liangjing's avatar
liangjing committed
538
539
            f"data_parallel_size ({data_parallel_size}) is not divisible by "
            "expert_model_parallel_size "
540
541
        )

liangjing's avatar
liangjing committed
542
543
    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size
544

liangjing's avatar
liangjing committed
545
546
547
    assert (
        encoder_world_size + decoder_world_size == world_size
    ), f"{encoder_world_size=} + {decoder_world_size=} != {world_size=}"
548
549

    if virtual_pipeline_model_parallel_size is not None:
liangjing's avatar
liangjing committed
550
        if not pipeline_model_parallel_size > 1:
liangjing's avatar
v1  
liangjing committed
551
            raise RuntimeError(
liangjing's avatar
liangjing committed
552
                "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
liangjing's avatar
v1  
liangjing committed
553
            )
554
555
556
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
557
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
558

559
    if pipeline_model_parallel_split_rank is not None:
560
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
561
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
562

563
564
    rank = torch.distributed.get_rank()

liangjing's avatar
liangjing committed
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError:
            raise RuntimeError(
                "Cannot import `yaml`. Setting custom nccl communicator configs "
                "requires the yaml package."
            )

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

    if encoder_world_size > 0:
        encoder_rank_generator = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=context_parallel_size,
            order=order,
            rank_offset=0,
        )
    else:
        encoder_rank_generator = None

    decoder_rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=expert_model_parallel_size,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order=order,
        rank_offset=encoder_world_size,
    )

    def generator_wrapper(group_type, **kwargs):
        """The `RankGenerator` class produces a hyper-rectangle for a given set of
        tensor, pipeline, data, expert, and context parallelism. If we have an encoder,
        in addition to the default decoder, we essentially instantiate two `RankGenerator`
        classes to construct the parallelism for each module separately, and we then have
        to stitch them together for the right groups. For now, this means pp and tp-pp."""
        d_ranks = decoder_rank_generator.get_ranks(group_type, **kwargs)
        if encoder_rank_generator is None:
            for x in d_ranks:
                yield x
            return
        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
        if group_type == 'pp':
            # Map 1 encoder tp rank to several decoder tp ranks, because
            # these won't be the same size.
            for x, y in zip(cycle(e_ranks), d_ranks):
                yield x + y
        elif group_type == 'tp-pp':
            # For this group, we can just return the concatenated
            # groups together, because their sizes are the same.
            assert len(e_ranks) == len(d_ranks)
            for x, y in zip(e_ranks, d_ranks):
                yield x + y
        else:
            for x in e_ranks:
                yield x
            for x in d_ranks:
                yield x

    timeout = timedelta(minutes=distributed_timeout_minutes)

632
    # Build the data-parallel groups.
633
    global _DATA_PARALLEL_GROUP
liangjing's avatar
v1  
liangjing committed
634
    global _DATA_PARALLEL_GROUP_GLOO
635
    global _DATA_PARALLEL_GLOBAL_RANKS
liangjing's avatar
liangjing committed
636
637
638
    global _DATA_PARALLEL_GROUP_WITH_CP
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
639
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
liangjing's avatar
liangjing committed
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661

    for ranks in generator_wrapper('dp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp', nccl_comm_cfgs)
        )
        group_gloo = torch.distributed.new_group(ranks, timeout=timeout, backend="gloo")
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GROUP_GLOO = group_gloo
            _DATA_PARALLEL_GLOBAL_RANKS = ranks

    for ranks_with_cp in generator_wrapper('dp-cp'):
        group_with_cp = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs)
        )
        group_with_cp_gloo = torch.distributed.new_group(
            ranks_with_cp, timeout=timeout, backend="gloo"
        )
        if rank in ranks_with_cp:
            _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
            _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp
662

liangjing's avatar
v1  
liangjing committed
663
664
665
666
667
668
669
670
671
672
673
674
675
676
    # Apply SHARP to DP process groups
    if use_sharp:
        if rank == 0:
            print(
                "The number of process groups to use SHARP with depends on the type "
                "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
                "process groups and QM2 supports up to 256 process groups. We apply "
                "SHARP to the communications of the data-parallel domain. If the "
                "number of data-parallel process groups is larger than the max "
                "process groups that the network switch supports, the communication "
                "will fall back to non-SHARP operators. To enable SHARP, "
                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
            )
        torch.distributed.barrier(
liangjing's avatar
liangjing committed
677
678
            group=get_data_parallel_group(with_context_parallel=True),
            device_ids=[torch.cuda.current_device()],
liangjing's avatar
v1  
liangjing committed
679
        )
liangjing's avatar
liangjing committed
680
681
682
683
684
685
686
687
688
689
690
691
692
693
        # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
        os.environ["NCCL_COLLNET_ENABLE"] = "0"

    # Build the context-parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
    for ranks in generator_wrapper('cp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('cp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
liangjing's avatar
v1  
liangjing committed
694

695
    # Build the model-parallel groups.
696
    global _MODEL_PARALLEL_GROUP
697
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
liangjing's avatar
liangjing committed
698
699
700
701
    for ranks in generator_wrapper('tp-pp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('mp', nccl_comm_cfgs)
        )
702
        if rank in ranks:
703
704
            _MODEL_PARALLEL_GROUP = group

liangjing's avatar
liangjing committed
705
706
707
708
709
710
711
712
713
714
715
716
    # Build the model-parallel groups with expert parallel
    global _MODEL_AND_EXPERT_PARALLEL_GROUP
    assert (
        _MODEL_AND_EXPERT_PARALLEL_GROUP is None
    ), 'model and expert parallel group is already initialized'
    for ranks in generator_wrapper('tp-ep-pp', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('mp_exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _MODEL_AND_EXPERT_PARALLEL_GROUP = group

717
718
    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
liangjing's avatar
liangjing committed
719
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
720
721
722
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), 'tensor model parallel group is already initialized'
liangjing's avatar
liangjing committed
723
724
725
726
    for ranks in generator_wrapper('tp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp', nccl_comm_cfgs)
        )
727
        if rank in ranks:
728
            _TENSOR_MODEL_PARALLEL_GROUP = group
liangjing's avatar
liangjing committed
729
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
730

731
732
733
    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
734
    global _PIPELINE_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
735
736
737
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), 'pipeline model parallel group is already initialized'
738
    global _EMBEDDING_GROUP
739
    global _EMBEDDING_GLOBAL_RANKS
740
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
Vijay Korthikanti's avatar
Vijay Korthikanti committed
741
742
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
743
    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
liangjing's avatar
liangjing committed
744
745
746
747
    for ranks in generator_wrapper('pp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('pp', nccl_comm_cfgs)
        )
748
        if rank in ranks:
liangjing's avatar
liangjing committed
749
750
751
752
753
754
755
756
757
758
759
760
761
762
            if _PIPELINE_MODEL_PARALLEL_GROUP is None:
                _PIPELINE_MODEL_PARALLEL_GROUP = group
                _PIPELINE_GLOBAL_RANKS = ranks
            elif isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
                _PIPELINE_MODEL_PARALLEL_GROUP.append(group)
                _PIPELINE_GLOBAL_RANKS.append(ranks)
            else:
                _PIPELINE_MODEL_PARALLEL_GROUP = [_PIPELINE_MODEL_PARALLEL_GROUP, group]
                _PIPELINE_GLOBAL_RANKS = [_PIPELINE_GLOBAL_RANKS, ranks]

        embedding_ranks = get_embedding_ranks(ranks)
        group = torch.distributed.new_group(
            embedding_ranks, timeout=timeout, pg_options=get_nccl_options('embd', nccl_comm_cfgs)
        )
763
764
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
765
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
766

liangjing's avatar
liangjing committed
767
768
769
770
771
772
        position_embedding_ranks = get_position_embedding_ranks(ranks)
        group = torch.distributed.new_group(
            position_embedding_ranks,
            timeout=timeout,
            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
        )
Vijay Korthikanti's avatar
Vijay Korthikanti committed
773
774
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
Vijay Korthikanti's avatar
Vijay Korthikanti committed
775
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
Vijay Korthikanti's avatar
Vijay Korthikanti committed
776

liangjing's avatar
liangjing committed
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
    # Build the tensor + data parallel groups.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    assert (
        _TENSOR_AND_DATA_PARALLEL_GROUP is None
    ), 'Tensor + data parallel group is already initialized'
    for ranks in generator_wrapper('tp-dp-cp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in generator_wrapper('tp-dp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP = group

    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    assert (
        _TENSOR_AND_CONTEXT_PARALLEL_GROUP is None
    ), 'Tensor + context parallel group is already initialized'
    for ranks in generator_wrapper('tp-cp'):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group

    # Build the tensor + expert parallel groups
    global _EXPERT_MODEL_PARALLEL_GROUP
    assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    assert (
        _TENSOR_AND_EXPERT_PARALLEL_GROUP is None
    ), 'Tensor + expert parallel group is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP is None
    ), 'Data modulo expert group is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is None
    ), 'Data modulo expert group with context parallel is already initialized'
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO

    for ranks in generator_wrapper('tp-ep', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('tp_exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _TENSOR_AND_EXPERT_PARALLEL_GROUP = group

    for ranks in generator_wrapper('ep', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, pg_options=get_nccl_options('exp', nccl_comm_cfgs)
        )
        if rank in ranks:
            _EXPERT_MODEL_PARALLEL_GROUP = group

    for ranks in generator_wrapper('dp', independent_ep=True):
        group = torch.distributed.new_group(
            ranks, timeout=timeout, pg_options=get_nccl_options('dp_modulo_exp', nccl_comm_cfgs)
        )
        group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _DATA_MODULO_EXPERT_PARALLEL_GROUP = group
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = group_gloo

    for ranks in generator_wrapper('dp-cp', independent_ep=True):
        # Lazy initialization of the group
        if get_context_parallel_world_size() > 1:
            group = torch.distributed.new_group(
                ranks,
                timeout=timeout,
                pg_options=get_nccl_options('dp_modulo_exp_cp', nccl_comm_cfgs),
            )
            group_gloo = torch.distributed.new_group(ranks, backend="gloo")
        else:
            group = _DATA_MODULO_EXPERT_PARALLEL_GROUP
            group_gloo = _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
        if rank in ranks:
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = group
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = group_gloo
liangjing's avatar
v1  
liangjing committed
863

864
865
866
867
868
869
    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()

870

liangjing's avatar
liangjing committed
871
def is_initialized():
    """Useful for code segments that may be accessed with or without mpu initialization.

    Returns:
        bool: True when the module-level data-parallel group is set
        (not None), False otherwise.
    """
    return _DATA_PARALLEL_GROUP is not None


def is_unitialized() -> bool:
    """Check if parallel state has NOT been initialized.

    Deprecated. Use is_initialized instead.

    Returns:
        bool: The negation of ``is_initialized()``.
    """
    # stacklevel=2 attributes the warning to the caller's line instead of this
    # wrapper, so users can find the call site they need to migrate.
    warnings.warn(
        "is_unitialized is deprecated, use is_initialized instead",
        DeprecationWarning,
        stacklevel=2,
    )
    return not is_initialized()
Abhinav Khattar's avatar
Abhinav Khattar committed
884
885


886
def model_parallel_is_initialized():
    """Check if model- and data-parallel groups are initialized.

    Returns:
        bool: True only when the tensor-model-parallel, pipeline-model-parallel
        and data-parallel groups are all set (not None).
    """
    required_groups = (
        _TENSOR_MODEL_PARALLEL_GROUP,
        _PIPELINE_MODEL_PARALLEL_GROUP,
        _DATA_PARALLEL_GROUP,
    )
    return all(group is not None for group in required_groups)


liangjing's avatar
liangjing committed
897
898
899
900
901
902
903
def get_model_parallel_group(with_expert_parallel=False):
    """Get the model-parallel group the caller rank belongs to.

    Args:
        with_expert_parallel (bool): When True, return the model-and-expert
            parallel group instead of the plain model-parallel group.

    Returns:
        The requested module-level group object.

    Raises:
        AssertionError: If the requested group is None.
    """
    if with_expert_parallel:
        assert (
            _MODEL_AND_EXPERT_PARALLEL_GROUP is not None
        ), 'model parallel group is not initialized'
        return _MODEL_AND_EXPERT_PARALLEL_GROUP
    assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


liangjing's avatar
v1  
liangjing committed
908
def get_tensor_model_parallel_group(check_initialized=True):
    """Get the tensor-model-parallel group the caller rank belongs to.

    Args:
        check_initialized (bool): When True, assert that the group is set
            before returning it. When False, the stored value is returned
            as-is and may be None.

    Returns:
        The module-level tensor-model-parallel group.

    Raises:
        AssertionError: If ``check_initialized`` is True and the group is None.
    """
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP
915
916


917
def get_pipeline_model_parallel_group():
    """Get the pipeline-model-parallel group the caller rank belongs to.

    Returns:
        The module-level pipeline-model-parallel group value.

    Raises:
        AssertionError: If the group is None.
    """
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is not None
    ), 'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP
923
924


liangjing's avatar
liangjing committed
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
def get_data_parallel_group(with_context_parallel=False):
    """Return the data-parallel group of the calling rank.

    Args:
        with_context_parallel (bool): Select the data-parallel group that is
            combined with context parallelism instead of the plain one.

    Raises:
        AssertionError: If the selected group is None.
    """
    if not with_context_parallel:
        assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
        return _DATA_PARALLEL_GROUP
    assert (
        _DATA_PARALLEL_GROUP_WITH_CP is not None
    ), 'data parallel group with context parallel combined is not initialized'
    return _DATA_PARALLEL_GROUP_WITH_CP


def get_data_parallel_group_gloo(with_context_parallel=False):
    """Return the Gloo data-parallel group of the calling rank.

    Args:
        with_context_parallel (bool): Select the Gloo group that is combined
            with context parallelism instead of the plain one.

    Raises:
        AssertionError: If the selected group is None.
    """
    if not with_context_parallel:
        assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
        return _DATA_PARALLEL_GROUP_GLOO
    assert (
        _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
    ), 'data parallel group-gloo with context parallel combined is not initialized'
    return _DATA_PARALLEL_GROUP_WITH_CP_GLOO
947
948


liangjing's avatar
liangjing committed
949
950
951
952
953
954
955
956
957
958
959
960
961
962
def get_context_parallel_group(check_initialized=True):
    """Return the context-parallel group of the calling rank.

    Args:
        check_initialized (bool): When True, assert the group is set. When
            False, return the stored value even if it is None.
    """
    group = _CONTEXT_PARALLEL_GROUP
    assert (
        not check_initialized or group is not None
    ), 'context parallel group is not initialized'
    return group


def get_context_parallel_global_ranks(check_initialized=True):
    """Return the global ranks making up the caller's context-parallel group.

    Args:
        check_initialized (bool): When True, assert the rank list is set. When
            False, return the stored value even if it is None.
    """
    global_ranks = _CONTEXT_PARALLEL_GLOBAL_RANKS
    assert (
        not check_initialized or global_ranks is not None
    ), 'context parallel group is not initialized'
    return global_ranks
liangjing's avatar
v1  
liangjing committed
963
964


965
966
def get_embedding_group():
    """Get the embedding group the caller rank belongs to.

    Raises:
        AssertionError: If the embedding group is None.
    """
    assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
    return _EMBEDDING_GROUP


Vijay Korthikanti's avatar
Vijay Korthikanti committed
971
972
def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to.

    Raises:
        AssertionError: If the position embedding group is None.
    """
    assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


liangjing's avatar
liangjing committed
977
def get_amax_reduction_group(with_context_parallel=False, tp_only_amax_red=False):
    """Get the FP8 amax reduction group the caller rank belongs to.

    The returned module-level group depends on both flags:

    * context parallel, not TP-only: tensor + data parallel group with CP
    * context parallel, TP-only:     tensor + context parallel group
    * no context parallel, not TP-only: tensor + data parallel group
    * no context parallel, TP-only:     tensor model parallel group

    Args:
        with_context_parallel (bool): Select a group that includes context
            parallelism.
        tp_only_amax_red (bool): Select the variant without the data-parallel
            dimension.

    Raises:
        AssertionError: If the selected group is None.
    """
    if with_context_parallel:
        if not tp_only_amax_red:
            assert (
                _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
            ), 'FP8 amax reduction group is not initialized'
            return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
        assert (
            _TENSOR_AND_CONTEXT_PARALLEL_GROUP is not None
        ), 'FP8 amax reduction group is not initialized'
        return _TENSOR_AND_CONTEXT_PARALLEL_GROUP

    if not tp_only_amax_red:
        assert (
            _TENSOR_AND_DATA_PARALLEL_GROUP is not None
        ), 'FP8 amax reduction group is not initialized'
        return _TENSOR_AND_DATA_PARALLEL_GROUP
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is not None
    ), 'FP8 amax reduction group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_tensor_and_data_parallel_group(with_context_parallel=False):
    """Return the combined tensor- and data-parallel group of the calling rank.

    Args:
        with_context_parallel (bool): Select the variant that also spans
            context parallelism.

    Raises:
        AssertionError: If the selected group is None.
    """
    if not with_context_parallel:
        assert (
            _TENSOR_AND_DATA_PARALLEL_GROUP is not None
        ), 'tensor and data parallel group is not initialized'
        return _TENSOR_AND_DATA_PARALLEL_GROUP
    assert (
        _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP is not None
    ), 'tensor and data parallel group is not initialized'
    return _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP


def get_tensor_and_context_parallel_group():
    """Return the combined tensor- and context-parallel group of the calling rank.

    Raises:
        AssertionError: If the group is None.
    """
    group = _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    assert group is not None, 'tensor and context parallel group is not initialized'
    return group


def get_expert_model_parallel_group():
    """Return the expert-model-parallel group of the calling rank.

    Raises:
        AssertionError: If the group is None.
    """
    group = _EXPERT_MODEL_PARALLEL_GROUP
    assert group is not None, 'expert model parallel group is not initialized'
    return group


def get_tensor_and_expert_parallel_group():
    """Return the combined tensor- and expert-parallel group of the calling rank.

    Raises:
        AssertionError: If the group is None.
    """
    group = _TENSOR_AND_EXPERT_PARALLEL_GROUP
    assert group is not None, 'tensor and expert parallel group is not initialized'
    return group


def get_data_modulo_expert_parallel_group(with_context_parallel=False):
    """Return the data-modulo-expert-parallel group of the calling rank.

    Args:
        with_context_parallel (bool): Select the variant that is combined
            with context parallelism.

    Raises:
        AssertionError: If the selected group is None.
    """
    if not with_context_parallel:
        assert (
            _DATA_MODULO_EXPERT_PARALLEL_GROUP is not None
        ), 'data modulo expert parallel group is not initialized'
        return _DATA_MODULO_EXPERT_PARALLEL_GROUP
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP is not None
    ), 'data modulo expert parallel group with context parallel is not initialized'
    return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP


def get_data_modulo_expert_parallel_group_gloo(with_context_parallel=False):
    """Return the Gloo data-modulo-expert-parallel group of the calling rank.

    Args:
        with_context_parallel (bool): Select the variant that is combined
            with context parallelism.

    Raises:
        AssertionError: If the selected group is None.
    """
    if not with_context_parallel:
        assert (
            _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None
        ), 'data modulo expert parallel group-gloo is not initialized'
        return _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    assert (
        _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO is not None
    ), 'data modulo expert parallel group-gloo with context parallel is not initialized'
    return _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO


def set_expert_model_parallel_world_size(world_size):
    """Store ``world_size`` as the expert-model-parallel world size.

    Args:
        world_size: Value written to the module-level
            ``_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE``.
    """
    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size
liangjing's avatar
v1  
liangjing committed
1073
1074


1075
def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor-model-parallel size.

    Args:
        world_size: Value written to the module-level
            ``_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE``.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
1079
1080


1081
def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline-model-parallel size.

    Args:
        world_size: Value written to the module-level
            ``_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
1085
1086


liangjing's avatar
v1  
liangjing committed
1087
def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline-model-parallel size.

    Args:
        world_size: Value written to the module-level
            ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


1093
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor-model-parallel group.

    If the module-level ``_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE`` override is
    set (not None) it is returned directly; otherwise the size is queried from
    ``torch.distributed`` for the tensor-model-parallel group.
    """
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
1099
1100


1101
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline-model-parallel group.

    If the module-level ``_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE`` override
    is set (not None) it is returned directly. Otherwise the size is queried
    from ``torch.distributed``. When the stored pipeline group is a list of
    groups, all of them are asserted to have the same number of ranks and the
    size of the first one is returned.

    Raises:
        AssertionError: If the pipeline groups in the list differ in size.
    """
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # Implicit assumption that each PP group is the same size.
        sizes = [len(ranks) for ranks in _PIPELINE_GLOBAL_RANKS]
        assert all(size == sizes[0] for size in sizes)
        return torch.distributed.get_world_size(group=pp_group[0])
    return torch.distributed.get_world_size(group=pp_group)


def set_expert_model_parallel_rank(rank):
    """Store ``rank`` as the expert-model-parallel rank.

    Args:
        rank: Value written to the module-level
            ``_MPU_EXPERT_MODEL_PARALLEL_RANK``.
    """
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = rank
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
1123
1124


1125
def set_tensor_model_parallel_rank(rank):
    """Set tensor-model-parallel rank.

    Args:
        rank: Value written to the module-level
            ``_MPU_TENSOR_MODEL_PARALLEL_RANK``.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
1129
1130


1131
def set_pipeline_model_parallel_rank(rank):
    """Set pipeline-model-parallel rank.

    Args:
        rank: Value written to the module-level
            ``_MPU_PIPELINE_MODEL_PARALLEL_RANK``.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
1135
1136


1137
def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline-model-parallel split rank. DEPRECATED.

    Args:
        rank: Value written to the module-level
            ``_PIPELINE_MODEL_PARALLEL_SPLIT_RANK``.
    """
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
1141
1142


1143
def get_tensor_model_parallel_rank():
    """Return caller's rank for the tensor-model-parallel group.

    If the module-level ``_MPU_TENSOR_MODEL_PARALLEL_RANK`` override is set
    (not None) it is returned directly; otherwise the rank is queried from
    ``torch.distributed`` for the tensor-model-parallel group.
    """
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
1149
1150


1151
def get_pipeline_model_parallel_rank():
    """Return caller's rank for the pipeline-model-parallel group.

    If the module-level ``_MPU_PIPELINE_MODEL_PARALLEL_RANK`` override is set
    (not None) it is returned directly. Otherwise the rank is queried from
    ``torch.distributed``. When the stored pipeline group is a list of groups,
    the caller's position is asserted to be the same in every rank list it
    appears in, and its rank within the first group is returned.

    Raises:
        AssertionError: If the caller sits at different indices in the
            pipeline rank lists.
    """
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK

    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # The global rank is only needed to validate the multi-group case.
        rank = torch.distributed.get_rank()
        # Assume that if the caller exist in multiple PP groups, then it has the same index.
        indices = [
            index
            for ranks in _PIPELINE_GLOBAL_RANKS
            for index, member in enumerate(ranks)
            if member == rank
        ]
        assert all(index == indices[0] for index in indices)
        return torch.distributed.get_rank(group=pp_group[0])
    return torch.distributed.get_rank(group=pp_group)
1169
1170


1171
def get_pipeline_model_parallel_split_rank():
    """Return pipeline-model-parallel split rank.

    Returns:
        The current value of the module-level
        ``_PIPELINE_MODEL_PARALLEL_SPLIT_RANK``.
    """
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK

1176

1177
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual (bool): When False and a virtual pipeline world size is
            set, the virtual pipeline rank must also be 0 for this to be True.
    """
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0
1186
1187


1188
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline-model-parallel stage, False otherwise.

    Args:
        ignore_virtual (bool): When False and a virtual pipeline world size is
            set, the virtual pipeline rank must also be the last virtual rank
            (world size - 1) for this to be True.
    """
    if not ignore_virtual:
        virtual_world_size = get_virtual_pipeline_model_parallel_world_size()
        if (
            virtual_world_size is not None
            and get_virtual_pipeline_model_parallel_rank() != (virtual_world_size - 1)
        ):
            return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
1201
1202


1203
1204
1205
1206
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    Args:
        ignore_virtual (bool): When True, only membership of the global rank
            in the embedding rank list is checked. When False, the first
            listed rank additionally has to be on the first pipeline stage and
            the last listed rank on the last pipeline stage (both evaluated
            with virtual pipeline ranks taken into account).
    """
    rank = torch.distributed.get_rank()
    if _EMBEDDING_GLOBAL_RANKS is None:
        return False
    if ignore_virtual:
        return rank in _EMBEDDING_GLOBAL_RANKS
    if rank in _EMBEDDING_GLOBAL_RANKS:
        if rank == _EMBEDDING_GLOBAL_RANKS[0]:
            return is_pipeline_first_stage(ignore_virtual=False)
        elif rank == _EMBEDDING_GLOBAL_RANKS[-1]:
            return is_pipeline_last_stage(ignore_virtual=False)
        else:
            return True
    return False


Vijay Korthikanti's avatar
Vijay Korthikanti committed
1221
1222
1223
1224
def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise.

    Returns False when the position embedding rank list is None.
    """
    rank = torch.distributed.get_rank()
    return _POSITION_EMBEDDING_GLOBAL_RANKS is not None and rank in _POSITION_EMBEDDING_GLOBAL_RANKS
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1226
1227


1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
def is_pipeline_stage_before_split(rank=None):
    """Tell whether a pipeline stage runs the encoder block of an
    encoder-decoder model.

    Args:
        rank: Pipeline rank to test; the caller's pipeline rank is used when
            None.

    Returns:
        bool: True when the pipeline world size is 1, when no split rank is
        set, or when ``rank`` is below the split rank; False otherwise.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or bool(stage < split_rank)


def is_pipeline_stage_after_split(rank=None):
    """Tell whether a pipeline stage runs the decoder block of an
    encoder-decoder model.

    Args:
        rank: Pipeline rank to test; the caller's pipeline rank is used when
            None.

    Returns:
        bool: True when the pipeline world size is 1, when no split rank is
        set, or when ``rank`` is at or above the split rank; False otherwise.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    return split_rank is None or bool(stage >= split_rank)


liangjing's avatar
liangjing committed
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
def is_inside_encoder(rank=None):
    """Tell whether a pipeline stage runs the encoder block of an
    encoder-decoder model.

    Args:
        rank: Pipeline rank to test; the caller's pipeline rank is used when
            None.

    Returns:
        bool: True when the pipeline world size is 1, when no decoder start
        rank is set, or when ``rank`` is below the decoder start rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    decoder_start = _PIPELINE_MODEL_PARALLEL_DECODER_START
    return decoder_start is None or bool(stage < decoder_start)


def is_inside_decoder(rank=None):
    """Tell whether a pipeline stage runs the decoder block of an
    encoder-decoder model.

    Args:
        rank: Pipeline rank to test; the caller's pipeline rank is used when
            None.

    Returns:
        bool: True when the pipeline world size is 1, when no decoder start
        rank is set, or when ``rank`` is at or above the decoder start rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    stage = get_pipeline_model_parallel_rank() if rank is None else rank
    decoder_start = _PIPELINE_MODEL_PARALLEL_DECODER_START
    return decoder_start is None or bool(stage >= decoder_start)


1288
1289
1290
1291
1292
def is_pipeline_stage_at_split():
    """Return true if pipeline stage executes encoder block and next
    stage executes decoder block for a model with both encoder and
    decoder.

    Concretely: the caller's pipeline rank satisfies
    ``is_pipeline_stage_before_split`` and the following rank satisfies
    ``is_pipeline_stage_after_split``.
    """
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
1294
1295


1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
def get_virtual_pipeline_model_parallel_rank():
    """Read back the module-level virtual pipeline-parallel rank."""
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Record ``rank`` as the virtual pipeline-parallel rank.

    Args:
        rank: Value stored in the module-level variable; it is not validated.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


1308
1309
1310
1311
1312
1313
def get_virtual_pipeline_model_parallel_world_size():
    """Report the virtual pipeline-parallel world size currently recorded.

    Returns:
        The module-level value as is; no default is substituted.
    """
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


1314
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    Returns:
        The first entry of the recorded tensor-model-parallel global ranks.

    Raises:
        AssertionError: If the tensor-model-parallel global ranks have not
            been recorded yet.
    """
    assert (
        _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None
    ), "Tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
1321

1322

liangjing's avatar
liangjing committed
1323
def get_data_parallel_src_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Args:
        with_context_parallel: When True, read the rank list recorded for the
            data-parallel group combined with context parallel instead of the
            plain data-parallel rank list.

    Returns:
        The first entry of the selected global-rank list.

    Raises:
        AssertionError: If the selected rank list has not been recorded yet.
    """
    if with_context_parallel:
        assert (
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Data parallel group with context parallel combined is not initialized"
        return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
    return _DATA_PARALLEL_GLOBAL_RANKS[0]
1334
1335


1336
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the current rank's pipeline.

    The recorded pipeline ranks may be either a flat list of global ranks or
    a list of such lists (one per pipeline group).

    Returns:
        The first global rank of the pipeline group(s).

    Raises:
        AssertionError: If the pipeline ranks have not been recorded, or if
            several groups are recorded and they disagree on their first rank.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        # The first rank is assumed to be the same for all pipeline groups
        # right now; verify that assumption before relying on it.
        first_rank = _PIPELINE_GLOBAL_RANKS[0][0]
        for rank_group in _PIPELINE_GLOBAL_RANKS:
            assert (
                rank_group[0] == first_rank
            ), "Pipeline parallel groups do not share the same first rank"
        return first_rank
    return _PIPELINE_GLOBAL_RANKS[0]
1346

1347

1348
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the current rank's pipeline.

    Returns:
        The entry of the recorded pipeline ranks at index
        ``pipeline world size - 1``.

    Raises:
        AssertionError: If the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    # NOTE(review): unlike get_pipeline_model_parallel_first_rank, this does
    # not special-case _PIPELINE_GLOBAL_RANKS being a list of lists -- confirm
    # whether that layout can reach this function.
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
1353

liangjing's avatar
v1  
liangjing committed
1354

1355
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.

    If it is just part of one group, an int is returned, otherwise a list of ints.

    Raises:
        AssertionError: If the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    # The modulo wraps the last stage around to the first one.
    next_index = (rank_in_pipeline + 1) % world_size
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        return [group[next_index] for group in _PIPELINE_GLOBAL_RANKS]
    return _PIPELINE_GLOBAL_RANKS[next_index]
1371

1372

1373
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.

    If it is just part of one group, an int is returned, otherwise a list of ints.

    Raises:
        AssertionError: If the pipeline ranks have not been recorded yet.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    # The modulo wraps the first stage around to the last one.
    prev_index = (rank_in_pipeline - 1) % world_size
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        return [group[prev_index] for group in _PIPELINE_GLOBAL_RANKS]
    return _PIPELINE_GLOBAL_RANKS[prev_index]
1389

1390

liangjing's avatar
liangjing committed
1391
def get_data_parallel_world_size(with_context_parallel=False):
    """Return world size for the data parallel group.

    Args:
        with_context_parallel: Forwarded to get_data_parallel_group to select
            which data-parallel group is queried.

    Returns:
        The explicitly set override when one is recorded; otherwise the size
        of the selected group as reported by torch.distributed; 0 when
        torch.distributed is unavailable or not initialized.
    """
    # An explicitly set world size takes precedence over the process group.
    if _MPU_DATA_PARALLEL_WORLD_SIZE is not None:
        return _MPU_DATA_PARALLEL_WORLD_SIZE
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(
            group=get_data_parallel_group(with_context_parallel=with_context_parallel)
        )
    return 0


def set_data_parallel_rank(rank):
    """Record ``rank`` as an explicit data-parallel rank.

    Args:
        rank: Value stored in the module-level override; it is not validated.
    """
    global _MPU_DATA_PARALLEL_RANK
    _MPU_DATA_PARALLEL_RANK = rank


def get_data_parallel_rank(with_context_parallel=False):
    """Look up the caller's rank within the data-parallel group.

    Args:
        with_context_parallel: Forwarded to get_data_parallel_group to select
            which data-parallel group is queried.

    Returns:
        The explicitly set override when one is recorded; otherwise the rank
        reported by torch.distributed for the selected group; 0 when
        torch.distributed is unavailable or not initialized.
    """
    if _MPU_DATA_PARALLEL_RANK is not None:
        return _MPU_DATA_PARALLEL_RANK
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    dp_group = get_data_parallel_group(with_context_parallel=with_context_parallel)
    return torch.distributed.get_rank(group=dp_group)


def get_context_parallel_world_size():
    """Look up the size of the context-parallel group.

    Returns:
        The group size reported by torch.distributed, or 0 when
        torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    cp_group = get_context_parallel_group()
    return torch.distributed.get_world_size(group=cp_group)


def get_context_parallel_rank():
    """Return caller's rank in the context-parallel group.

    Returns:
        The rank reported by torch.distributed for the context-parallel
        group, or 0 when torch.distributed is unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_context_parallel_group())
    return 0
1437
1438


liangjing's avatar
liangjing committed
1439
1440
def get_tensor_and_context_parallel_world_size():
    """Return world size for the tensor and context-parallel group.

    Returns:
        The group size reported by torch.distributed for the joint tensor and
        context-parallel group, or 0 when torch.distributed is unavailable or
        not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=get_tensor_and_context_parallel_group())
    return 0


def get_tensor_and_context_parallel_rank():
    """Look up the caller's rank in the joint tensor-model-parallel and
    context-parallel group.

    Returns:
        The rank reported by torch.distributed for that group, or 0 when
        torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    tp_cp_group = get_tensor_and_context_parallel_group()
    return torch.distributed.get_rank(group=tp_cp_group)


def get_expert_model_parallel_world_size():
    """Look up the size of the expert-model-parallel group.

    Returns:
        The explicitly set override when one is recorded; otherwise the size
        of the joint tensor-and-expert group floor-divided by the
        tensor-model-parallel world size; 0 when torch.distributed is
        unavailable or not initialized.
    """
    if _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    tp_ep_size = torch.distributed.get_world_size(group=get_tensor_and_expert_parallel_group())
    return tp_ep_size // get_tensor_model_parallel_world_size()


def get_tensor_and_expert_parallel_world_size():
    """Look up the size of the joint tensor- and expert-model-parallel group.

    Currently, each expert will also be distributed across TP group by default.

    Returns:
        The group size reported by torch.distributed, or 0 when
        torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    tp_ep_group = get_tensor_and_expert_parallel_group()
    return torch.distributed.get_world_size(group=tp_ep_group)


def get_expert_model_parallel_rank():
    """Look up the caller's rank in the expert-model-parallel group.

    Returns:
        The explicitly set override when one is recorded; otherwise the
        caller's rank in the joint tensor-and-expert group floor-divided by
        the tensor-model-parallel world size; 0 when torch.distributed is
        unavailable or not initialized.
    """
    if _MPU_EXPERT_MODEL_PARALLEL_RANK is not None:
        return _MPU_EXPERT_MODEL_PARALLEL_RANK
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    tp_ep_rank = torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group())
    return tp_ep_rank // get_tensor_model_parallel_world_size()


def get_data_modulo_expert_parallel_rank(with_context_parallel=False):
    """Look up the caller's rank in the data-modulo-expert-parallel group.

    Args:
        with_context_parallel: Forwarded to
            get_data_modulo_expert_parallel_group to select which group is
            queried.

    Returns:
        The rank reported by torch.distributed for the selected group, or 0
        when torch.distributed is unavailable or not initialized.
    """
    if not (torch.distributed.is_available() and torch.distributed.is_initialized()):
        return 0
    dmep_group = get_data_modulo_expert_parallel_group(
        with_context_parallel=with_context_parallel
    )
    return torch.distributed.get_rank(group=dmep_group)


def get_tensor_and_expert_parallel_rank():
    """Return caller's rank in the joint tensor- and expert-model-parallel group.

    Returns:
        The rank reported by torch.distributed for that group, or 0 when
        torch.distributed is unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank(group=get_tensor_and_expert_parallel_group())
    return 0

1511

1512
def _set_global_memory_buffer():
    """Initialize global buffer.

    Creates a new GlobalMemoryBuffer and stores it in the module-level slot.

    Raises:
        AssertionError: If a global memory buffer is already recorded.
    """
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

liangjing's avatar
v1  
liangjing committed
1518

1519
def get_global_memory_buffer():
    """Return the global GlobalMemoryBuffer object.

    Raises:
        AssertionError: If no global memory buffer is currently recorded.
    """
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER


liangjing's avatar
v1  
liangjing committed
1525
1526
1527
1528
1529
1530
def destroy_global_memory_buffer():
    """Drop the module-level reference to the global memory buffer.

    The slot is reset to None whether or not a buffer was recorded.
    """
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None


liangjing's avatar
liangjing committed
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
def get_all_ranks():
    """Build a string with the caller's rank in each parallel group.

    The ranks appear in this order: tensor-model-parallel, data-parallel,
    context-parallel, pipeline-model-parallel, expert-model-parallel.

    Returns:
        The five ranks joined by underscores; any falsy rank (for example
        None) is rendered as 0.
    """
    per_group_ranks = (
        get_tensor_model_parallel_rank(),
        get_data_parallel_rank(),
        get_context_parallel_rank(),
        get_pipeline_model_parallel_rank(),
        get_expert_model_parallel_rank(),
    )
    return '_'.join(str(rank or 0) for rank in per_group_ranks)


def get_moe_layer_wise_logging_tracker():
    """Hand back the module-level MoE layer-wise logging tracker.

    Returns:
        The tracker object itself, not a copy.
    """
    return _MOE_LAYER_WISE_LOGGING_TRACKER


1550
1551
def destroy_model_parallel():
    """Set the groups to none.

    Resets the module-level parallel state handled here: process-group
    references, the context-parallel global ranks, the virtual pipeline
    rank/world size, the explicit tensor/pipeline/expert rank and world-size
    overrides, and the global memory buffer are set to None. The Gloo
    data-parallel and data-modulo-expert-parallel groups are destroyed via
    torch.distributed when present. The MoE layer-wise logging tracker is
    reset to an empty dict.
    """
    # Process-group references. These are only dropped here; no
    # destroy_process_group call is made for them.
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None

    global _MODEL_AND_EXPERT_PARALLEL_GROUP
    _MODEL_AND_EXPERT_PARALLEL_GROUP = None

    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None

    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None

    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None

    global _DATA_PARALLEL_GROUP_WITH_CP
    _DATA_PARALLEL_GROUP_WITH_CP = None

    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None

    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None

    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None

    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None

    global _TENSOR_AND_DATA_PARALLEL_GROUP
    _TENSOR_AND_DATA_PARALLEL_GROUP = None

    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    _TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

    global _EXPERT_MODEL_PARALLEL_GROUP
    _EXPERT_MODEL_PARALLEL_GROUP = None

    global _TENSOR_AND_EXPERT_PARALLEL_GROUP
    _TENSOR_AND_EXPERT_PARALLEL_GROUP = None

    global _DATA_MODULO_EXPERT_PARALLEL_GROUP
    _DATA_MODULO_EXPERT_PARALLEL_GROUP = None

    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP = None

    # Virtual pipeline state.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None

    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

    # Explicit rank / world-size overrides.
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None

    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None

    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = None

    # Gloo groups: these two are explicitly destroyed when present.
    global _DATA_PARALLEL_GROUP_GLOO
    if _DATA_PARALLEL_GROUP_GLOO is not None:
        torch.distributed.destroy_process_group(_DATA_PARALLEL_GROUP_GLOO)
    _DATA_PARALLEL_GROUP_GLOO = None

    # Only the reference is dropped for the with-CP Gloo group.
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None

    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO
    if _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO is not None:
        torch.distributed.destroy_process_group(_DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO)
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_GLOO = None

    # Only the reference is dropped for the with-CP Gloo group.
    global _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO
    _DATA_MODULO_EXPERT_PARALLEL_GROUP_WITH_CP_GLOO = None

    # The tracker is reset to a fresh empty dict rather than None.
    global _MOE_LAYER_WISE_LOGGING_TRACKER
    _MOE_LAYER_WISE_LOGGING_TRACKER = {}