parallel_state.py 83.8 KB
Newer Older
Jared Casper's avatar
Jared Casper committed
1
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
2
3
4

"""Model and data parallel groups."""

liangjing's avatar
v1  
liangjing committed
5
import os
xingjinliang's avatar
xingjinliang committed
6
7
8
9
10
import warnings
from datetime import timedelta
from functools import partial
from itertools import cycle
from typing import Callable, List, Optional
11

liangjing's avatar
v1  
liangjing committed
12
13
import torch

wangxj's avatar
wangxj committed
14
from .utils import GlobalMemoryBuffer, is_torch_min_version
15

16
# ---------------------------------------------------------------------------
# Module-level parallel state.
#
# Every name below is a process-local global that starts out unset (None, or
# an empty container). They are meant to be filled in by the initialization
# code of this module and read through accessor functions.
# ---------------------------------------------------------------------------

# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Inter-layer model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra- and pipeline) that the current rank belongs to.
_MODEL_PARALLEL_GROUP = None
# Model parallel group (both intra-, pipeline, and expert) that the current rank belongs to.
# NOTE(review): the comment above has no matching variable in this part of the
# file -- confirm whether a declaration is missing or the comment is stale.
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
_DATA_PARALLEL_GROUP_GLOO = None
# tensor model parallel group and data parallel group combined
# used for fp8 and moe training
_TENSOR_AND_DATA_PARALLEL_GROUP = None

### Expert-related parallel states
# Naming convention:
# _EXPERT prefix in group name means it's used for expert layer in MoE models.
# _EXPERT_MODEL denotes expert parallelism which splits number of experts across the group.
# _EXPERT_TENSOR denotes tensor parallelism of expert which splits tensor across the group.
# _EXPERT_DATA denotes data parallelism of expert which replicates weight across the group.

# Expert model parallel group that current rank belongs to.
_EXPERT_MODEL_PARALLEL_GROUP = None
# Expert tensor parallel group that current rank belongs to.
_EXPERT_TENSOR_PARALLEL_GROUP = None
# Expert tensor and model combined parallel group
_EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = None
# Expert tensor, model, pipeline combined parallel group
_EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = None
# Expert data parallel group
_EXPERT_DATA_PARALLEL_GROUP = None
_EXPERT_DATA_PARALLEL_GROUP_GLOO = None
# Parallel state values changed on the fly
_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_EXPERT_MODEL_PARALLEL_RANK = None
_MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = None
_MPU_EXPERT_TENSOR_PARALLEL_RANK = None
### End of expert related parallel states

# Virtual (interleaved) pipeline state and the deprecated encoder/decoder split rank.
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_PIPELINE_MODEL_PARALLEL_SPLIT_RANK = None

# First pipeline stage of the decoder; only set when an encoder pipeline size > 0 is used.
_PIPELINE_MODEL_PARALLEL_DECODER_START = None

# These values enable us to change the mpu sizes on the fly.
_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None
_MPU_DATA_PARALLEL_WORLD_SIZE = None
_MPU_DATA_PARALLEL_RANK = None
_MPU_TENSOR_MODEL_PARALLEL_RANK = None
_MPU_PIPELINE_MODEL_PARALLEL_RANK = None

# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# A list of global ranks for each data parallel group to ease calculation of the source
# rank when broadcasting weights from src to all other data parallel ranks
_DATA_PARALLEL_GLOBAL_RANKS = None

# A list of global ranks for each tensor model parallel group to ease calculation of
# the first local rank in the tensor model parallel group
_TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = None

# A list of global ranks for each model parallel group to ease calculation of
# the first local rank in the model parallel group
_MODEL_PARALLEL_GLOBAL_RANKS = None

# Context parallel group that the current rank belongs to
_CONTEXT_PARALLEL_GROUP = None
# A list of global ranks for each context parallel group to ease calculation of the
# destination rank when exchanging KV/dKV between context parallel_ranks
_CONTEXT_PARALLEL_GLOBAL_RANKS = None
# Hierarchical context parallel groups
_HIERARCHICAL_CONTEXT_PARALLEL_GROUPS = []

# Data parallel group information with context parallel combined.
_DATA_PARALLEL_GROUP_WITH_CP = None
_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = None

# Partial Data parallel group information with context parallel combined.
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = None
_INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = None
_INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = None

# combined parallel group of TP and CP
_TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

# combined parallel group of TP, DP, and CP used for fp8
_TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

# Memory buffers to avoid dynamic memory allocation
_GLOBAL_MEMORY_BUFFER = None

# MOE logging
_MOE_LAYER_WISE_LOGGING_TRACKER = {}


def get_nccl_options(pg_name, nccl_comm_cfgs):
    """Build the NCCL process group options for a named process group.

    Args:
        pg_name (str): process group name, used as the key into ``nccl_comm_cfgs``.
        nccl_comm_cfgs (dict): nccl communicator configurations, keyed by process
            group name. May be empty or ``None`` (e.g. the result of loading an
            empty YAML file), in which case no options are produced.

    Returns:
        A ``torch.distributed.ProcessGroupNCCL.Options`` instance configured from
        ``nccl_comm_cfgs[pg_name]``, or ``None`` when there is no entry for
        ``pg_name``.

    Raises:
        RuntimeError: if ``net_name`` is configured and is not ``'IB'`` or
            ``'socket'`` (compared case-insensitively).

    When ``cga_cluster_size``, ``max_ctas`` or ``min_ctas`` is not found in the
    config, the fallback values hard-coded below (4, 32 and 1) are used.
    """
    # `not nccl_comm_cfgs` covers both {} and None; the original `in` test
    # raised TypeError on None.
    if not nccl_comm_cfgs or pg_name not in nccl_comm_cfgs:
        return None

    pg_cfg = nccl_comm_cfgs[pg_name]

    # Validate net_name before touching the NCCL backend so a bad config is
    # reported as such regardless of backend availability.
    has_net_name = 'net_name' in pg_cfg
    if has_net_name:
        net_name = pg_cfg['net_name']
        if net_name.lower() not in ['ib', 'socket']:
            raise RuntimeError(
                f"net_name ({net_name}) is not supported. "
                f"Accepted values: 'IB' or 'socket'."
            )

    nccl_options = torch.distributed.ProcessGroupNCCL.Options()
    nccl_options.config.cga_cluster_size = pg_cfg.get('cga_cluster_size', 4)
    nccl_options.config.max_ctas = pg_cfg.get('max_ctas', 32)
    nccl_options.config.min_ctas = pg_cfg.get('min_ctas', 1)
    if has_net_name:
        nccl_options.config.net_name = net_name
    return nccl_options


wangxj's avatar
wangxj committed
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
def create_group(
    ranks=None,
    timeout=None,
    backend=None,
    pg_options=None,
    use_local_synchronization=False,
    group_desc=None,
):
    """Create a ProcessGroup via ``torch.distributed.new_group``.

    All arguments are forwarded as keyword arguments. Which of them are
    forwarded depends on the installed torch version:

    * torch >= 2.4.0: every argument is forwarded as given.
    * older torch: ``group_desc`` is never forwarded, and ``timeout`` is only
      forwarded when the caller supplied a value.

    Returns:
        Whatever ``torch.distributed.new_group`` returns.
    """
    # Arguments forwarded regardless of the torch version.
    new_group_kwargs = dict(
        ranks=ranks,
        backend=backend,
        pg_options=pg_options,
        use_local_synchronization=use_local_synchronization,
    )

    if is_torch_min_version('2.4.0'):
        new_group_kwargs['timeout'] = timeout
        new_group_kwargs['group_desc'] = group_desc
    elif timeout is not None:
        # Old versions (e.g. v2.1.2) use default_pg_timeout as the default value
        # of `timeout` in the function signature and then type-check the timeout
        # value, so an explicit None would be a type error. New versions default
        # to None and pick a backend-specific value before the type check.
        # Hence, on old versions, only pass `timeout` when the caller set it.
        new_group_kwargs['timeout'] = timeout

    return torch.distributed.new_group(**new_group_kwargs)


xingjinliang's avatar
xingjinliang committed
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
def generate_masked_orthogonal_rank_groups(
    world_size: int, parallel_size: List[int], mask: List[bool]
) -> List[List[int]]:
    r"""Enumerate the rank groups spanned by a subset of orthogonal parallel dimensions.

    Arguments:
        world_size (int): total number of ranks.

        parallel_size (List[int]):
            Size of every parallel dimension, listed in mapping order (fastest
            varying first). E.g. with tensor parallel size 2, pipeline parallel
            size 3, data parallel size 4 and order tp-pp-dp, this is [2, 3, 4].

        mask (List[bool]):
            Selects the dimensions that vary *inside* a group. mask[i] == True
            means the i-th dimension is part of the generated groups. With
            parallel_size = [tp_size, pp_size, dp_size], mask = [True, False, True]
            produces the `tp-dp` groups and mask = [False, True, False] produces
            the `pp` groups.

    Returns:
        A list of groups, each a list of global ranks. Groups are ordered by the
        flattened index over the unmasked dimensions; ranks inside a group are
        ordered by the flattened index over the masked dimensions.

    How it works:
        Every global rank is a mixed-radix number over the parallel dimensions,
        e.g. for order tp-dp-pp:
            global_rank = tp_rank + dp_rank * tp_size + pp_rank * tp_size * dp_size
        The stride of a dimension is the product of the sizes of all dimensions
        before it. A group is identified by fixing the coordinates of the
        unmasked dimensions; its members are obtained by sweeping the
        coordinates of the masked dimensions.

    Example:
        parallel_size = [tp_size, dp_size, pp_size] = [2, 3, 4] and
        mask = [False, True, False] (dp groups) on 24 ranks gives 8 groups:
            group 0 (tp_rank 0, pp_rank 0) = [0, 2, 4]
            group 1 (tp_rank 1, pp_rank 0) = [1, 3, 5]
            ...
            group 7 (tp_rank 1, pp_rank 3) = [19, 21, 23]
    """

    def cumulative_products(sizes: List[int]) -> List[int]:
        # [1, s0, s0*s1, ...]; has one more entry than `sizes`.
        products = [1]
        for size in sizes:
            products.append(products[-1] * size)
        return products

    def to_coordinates(flat_index: int, shape: List[int]) -> List[int]:
        # Split a flat index into per-dimension coordinates for `shape`,
        # i.e. solve flat_index = sum(coords[i] * strides[i]).
        strides = cumulative_products(shape)
        coords = [(flat_index // stride) % extent for extent, stride in zip(shape, strides)]
        # zip stops at len(coords), so the trailing total-size stride is unused.
        recomposed = sum(coord * stride for coord, stride in zip(coords, strides))
        assert (
            recomposed == flat_index
        ), "idx {} with shape {} mismatch the return idx {}".format(flat_index, shape, coords)
        return coords

    def rank_offset(coords: List[int], strides: List[int]) -> int:
        return sum(coord * stride for coord, stride in zip(coords, strides))

    # Partition the dimensions (and their global strides) by the mask.
    within_shape, within_strides = [], []
    across_shape, across_strides = [], []
    for extent, stride, selected in zip(parallel_size, cumulative_products(parallel_size), mask):
        if selected:
            within_shape.append(extent)
            within_strides.append(stride)
        else:
            across_shape.append(extent)
            across_strides.append(stride)

    group_size = cumulative_products(within_shape)[-1]
    group_count = world_size // group_size

    # Offsets of the members relative to the group's base rank are the same
    # for every group, so compute them once.
    member_offsets = [
        rank_offset(to_coordinates(position, within_shape), within_strides)
        for position in range(group_size)
    ]

    groups = []
    for group_index in range(group_count):
        # Base rank of this group: coordinates over the unmasked dimensions.
        base = rank_offset(to_coordinates(group_index, across_shape), across_strides)
        groups.append([member + base for member in member_offsets])
    return groups


def create_hierarchical_parallel_groups(
    rank, ranks, group_size, hierarchical_group_sizes, pg_options
):
    """Create hierarchical groups for one parallelism.

    Taking a group size of 16 as example, so we have a total of 16 GPUs denoted by g0 ... g15.
    If the hierarchical group sizes are [2,2,4], we use 2 GPUs in the first and second level
    of sub-groups, and 4 GPUs in the last level of sub groups. The present function will
    create 8 level-1 sub-groups, 8 level-2 sub-groups and 4 level-3 sub-groups as:
        8 level-1 sub-groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
        8 level-2 sub-groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
        4 level-3 sub-groups:
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]

    Args:
        rank: global rank of the calling process; used to pick the sub-groups
            that are returned.
        ranks: global ranks of the parent group, indexed by position in the group.
        group_size: number of entries of ``ranks`` to partition.
        hierarchical_group_sizes: sub-group size for each level, innermost first.
        pg_options: forwarded unchanged to ``create_group`` for every sub-group.

    Returns:
        The sub-groups that contain ``rank``, in level order (one per level when
        ``rank`` appears exactly once in ``ranks``).
    """

    hierarchical_groups = []
    # Product of the level sizes up to and including the current level.
    accumulated_group_sizes = 1
    # Product of the level sizes strictly before the current level; this is the
    # stride between neighbouring members of a sub-group at the current level.
    processed_group_sizes = 1
    for level, hierarchical_group_size in enumerate(hierarchical_group_sizes):
        accumulated_group_sizes *= hierarchical_group_size
        # k selects a block of `accumulated_group_sizes` consecutive positions,
        # j selects the starting offset inside that block.
        for k in range(group_size // accumulated_group_sizes):
            for j in range(processed_group_sizes):
                global_sub_ranks = [
                    ranks[j + i * processed_group_sizes + k * accumulated_group_sizes]
                    for i in range(hierarchical_group_size)
                ]
                # Every sub-group is created on every caller, not only on its
                # members; only the handles containing `rank` are kept below.
                sub_group = create_group(
                    global_sub_ranks,
                    pg_options=pg_options,
                    group_desc=f'HIERARCHICAL_CONTEXT_PARALLEL_GROUP_L{level}',
                )
                if rank in global_sub_ranks:
                    hierarchical_groups.append(sub_group)
        processed_group_sizes *= hierarchical_group_size
    return hierarchical_groups


class RankGenerator(object):
    """A class for generating rank groups for different modes of parallelism.

    The generator is configured with the size of each parallel dimension
    (tp, ep, dp, pp, cp) and an ``order`` string such as ``'tp-dp-pp'`` that
    lists the dimensions from fastest to slowest varying. Dimensions of size 1
    that are missing from ``order`` are appended automatically.

    Attributes set by ``__init__``:
        tp, ep, dp, pp, cp: the sizes passed in.
        rank_offset: value added to every generated rank when it is > 0.
        world_size: product of the five sizes.
        name_to_size: mapping from dimension name to its size.
        order: lower-cased order string, completed with missing size-1 dimensions.
        ordered_size: the sizes listed in the order given by ``order``.
    """

    def __init__(
        self, tp: int, ep: int, dp: int, pp: int, cp: int, order: str, rank_offset: int = 0
    ) -> None:
        """Validate the sizes and normalize ``order``.

        Raises:
            AssertionError: if both ``ep`` and ``cp`` are greater than 1.
            RuntimeError: if a dimension with size != 1 is not named in ``order``.
            KeyError: if ``order`` contains a token that is not a known dimension.
        """
        assert ep == 1 or cp == 1, (
            "Both EP and CP > 1 is not allowed in one rank generator. "
            "CP is only included in default RankGenerator, and EP only in expert RankGenerator."
        )

        self.tp = tp
        self.ep = ep
        self.dp = dp
        self.pp = pp
        self.cp = cp
        self.rank_offset = rank_offset
        self.world_size = tp * dp * pp * cp * ep

        self.name_to_size = {
            "tp": self.tp,
            "pp": self.pp,
            "dp": self.dp,
            "ep": self.ep,
            "cp": self.cp,
        }

        normalized_order = order.lower()
        for name, size in self.name_to_size.items():
            if name in normalized_order:
                continue
            if size != 1:
                # Report the order exactly as the caller supplied it.
                raise RuntimeError(
                    f"The size of ({name}) is ({size}), but you haven't "
                    f"specified the order ({order})."
                )
            # A size-1 dimension does not change the rank layout, so it can be
            # appended at the end without affecting the generated groups.
            normalized_order = normalized_order + '-' + name

        self.order = normalized_order
        self.ordered_size = [self.name_to_size[token] for token in normalized_order.split('-')]

    def get_mask(self, order: str, token: str) -> List[bool]:
        """Create a mask for the specified tokens based on the given order.

        Args:
            order (str): The order of parallelism types (e.g., 'tp-dp-pp').
            token (str): The specific parallelism types to include in the mask,
                         separated by hyphens (e.g., 'tp-dp').

        Returns:
            A list with one bool per entry of ``order``; True where the entry
            is named in ``token``.

        Raises:
            ValueError: if ``token`` names a type that is not part of ``order``.
        """
        ordered_token = order.split('-')
        mask = [False] * len(ordered_token)
        for t in token.split('-'):
            mask[ordered_token.index(t)] = True
        return mask

    def get_ranks(self, token: str) -> List[List[int]]:
        """Get rank group by input token.

        Args:
            token (str):
                Specify the ranks type that want to get. If we want
                to obtain multiple parallel types, we can use a hyphen
                '-' to separate them. For example, if we want to obtain
                the TP_DP group, the token should be 'tp-dp'.

        Returns:
            A list of rank groups (each a list of global ranks), shifted by
            ``rank_offset`` when it is positive.
        """
        mask = self.get_mask(self.order, token)
        ranks = generate_masked_orthogonal_rank_groups(self.world_size, self.ordered_size, mask)
        if self.rank_offset > 0:
            ranks = [[r + self.rank_offset for r in rank_group] for rank_group in ranks]
        return ranks


def default_embedding_ranks(pp_ranks, split_rank=None):
    """Pick the pipeline ranks that hold the word embeddings by default.

    Args:
        pp_ranks: global ranks of one pipeline group, in stage order.
        split_rank: deprecated; index into ``pp_ranks`` of an extra stage that
            should also hold the embeddings. Kept for backwards compatibility.

    Returns:
        ``[first]`` for a single-stage pipeline; otherwise ``[first, last]``,
        with the split stage inserted in between when ``split_rank`` is given
        and refers to neither the first nor the last stage.
    """
    if len(pp_ranks) == 1:
        return [pp_ranks[0]]

    first_stage, last_stage = pp_ranks[0], pp_ranks[-1]
    embedding_ranks = [first_stage, last_stage]
    if split_rank is not None:
        split_stage = pp_ranks[split_rank]
        if split_stage != first_stage and split_stage != last_stage:
            embedding_ranks.insert(1, split_stage)
    return embedding_ranks


def default_position_embedding_ranks(pp_ranks, split_rank=None):
    """Pick the pipeline ranks that hold the position embeddings by default.

    Args:
        pp_ranks: global ranks of one pipeline group, in stage order.
        split_rank: deprecated; index into ``pp_ranks`` of an extra stage that
            should also hold the position embeddings. Kept for backwards
            compatibility.

    Returns:
        ``[first]``, followed by the split stage when ``split_rank`` is given
        and refers to a stage other than the first one.
    """
    position_embedding_ranks = [pp_ranks[0]]
    if split_rank is not None:
        split_stage = pp_ranks[split_rank]
        if pp_ranks[0] != split_stage:
            position_embedding_ranks.append(split_stage)
    return position_embedding_ranks

431

432
433
434
435
436
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    virtual_pipeline_model_parallel_size: Optional[int] = None,
    pipeline_model_parallel_split_rank: Optional[int] = None,
wangxj's avatar
wangxj committed
437
    pipeline_model_parallel_comm_backend: Optional[str] = None,
liangjing's avatar
v1  
liangjing committed
438
    use_sharp: bool = False,
xingjinliang's avatar
xingjinliang committed
439
440
441
442
443
444
445
446
447
448
449
450
    context_parallel_size: int = 1,
    hierarchical_context_parallel_sizes: Optional[List[int]] = None,
    expert_model_parallel_size: int = 1,
    num_distributed_optimizer_instances: int = 1,
    expert_tensor_parallel_size: Optional[int] = None,
    nccl_communicator_config_path: Optional[str] = None,
    distributed_timeout_minutes: int = 30,
    order: str = "tp-cp-ep-dp-pp",
    encoder_tensor_model_parallel_size: int = 0,
    encoder_pipeline_model_parallel_size: Optional[int] = 0,
    get_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
    get_position_embedding_ranks: Optional[Callable[[List[int], Optional[int]], List[int]]] = None,
wangxj's avatar
wangxj committed
451
    create_gloo_process_groups: bool = True,
452
) -> None:
xingjinliang's avatar
xingjinliang committed
453
    # pylint: disable=line-too-long
liangjing's avatar
v1  
liangjing committed
454
    """Initialize model data parallel groups.
455

xingjinliang's avatar
xingjinliang committed
456
    Args:
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
        tensor_model_parallel_size (int, default = 1):
            The number of GPUs to split individual tensors across.

        pipeline_model_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            Transformer layers across. For example, if
            tensor_model_parallel_size is 4 and
            pipeline_model_parallel_size is 2, the model will be split
            into 2 groups of 4 GPUs.

        virtual_pipeline_model_parallel_size (int, optional):
            The number of stages that each pipeline group will have,
            interleaving as necessary. If None, no interleaving is
            performed. For example, if tensor_model_parallel_size is 1,
            pipeline_model_parallel_size is 4,
            virtual_pipeline_model_parallel_size is 2, and there are
            16 transformer layers in the model, the model will be
            split into 8 stages with two layers each and each GPU
            would get 2 stages as such (layer number starting with 1):

            GPU 0: [1, 2] [9, 10]
            GPU 1: [3, 4] [11, 12]
            GPU 2: [5, 6] [13, 14]
            GPU 3: [7, 8] [15, 16]

        pipeline_model_parallel_split_rank (int, optional):
xingjinliang's avatar
xingjinliang committed
483
            DEPRECATED. For models with both an encoder and decoder, the rank in
484
485
486
487
488
489
490
            pipeline to switch between encoder and decoder (i.e. the
            first rank of the decoder). This allows the user to set
            the pipeline parallel size of the encoder and decoder
            independently. For example, if
            pipeline_model_parallel_size is 8 and
            pipeline_model_parallel_split_rank is 3, then ranks 0-2
            will be the encoder and ranks 3-7 will be the decoder.
491

wangxj's avatar
wangxj committed
492
493
494
495
        pipeline_model_parallel_comm_backend (str, optional):
            The backend to use for pipeline parallel communication.
            If None, the default backend will be used.

liangjing's avatar
v1  
liangjing committed
496
497
498
499
500
501
        use_sharp (bool, default = False):
            Set the use of SHARP for the collective communications of
            data-parallel process groups. When `True`, run barrier
            within each data-parallel process group, which specifies
            the SHARP application target groups.

xingjinliang's avatar
xingjinliang committed
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
        context_parallel_size (int, default = 1):
            The number of tensor parallel GPU groups to split the
            network input sequence length across. Compute of attention
            module requires tokens of full sequence length, so GPUs
            in a context parallel group need to communicate with each
            other to exchange information of other sequence chunks.
            Each GPU and its counterparts in other tensor parallel
            groups compose a context parallel group.

            For example, assume we have 8 GPUs, if tensor model parallel
            size is 4 and context parallel size is 2, the network input
            will be split into two sequence chunks, which are processed
            by 2 different groups of 4 GPUs. One chunk is processed by
            GPU0-3, the other chunk is processed by GPU4-7. Four groups
            are built to do context parallel communications: [GPU0, GPU4],
            [GPU1, GPU5], [GPU2, GPU6], and [GPU3, GPU7].

            Context parallelism partitions sequence length, so it has no
            impact on weights, which means weights are duplicated among
            GPUs in a context parallel group. Hence, weight gradients
            all-reduce is required in backward. For simplicity, we piggyback
            GPUs of context parallelism on data parallel group for
            weight gradient all-reduce.

        expert_model_parallel_size (int, default = 1):
            The number of Mixture of Experts parallel GPUs in each expert
            parallel group.

        num_distributed_optimizer_instances (int, default = 1):
            The number of distributed optimizer replicas across the data-
            parallel domain.

        expert_tensor_parallel_size (int, default = tp_size):
            The number of GPUs to split individual tensors of expert.

        nccl_communicator_config_path (str, default = None):
            Path to the yaml file of NCCL communicator configurations.
            `min_ctas`, `max_ctas`, and `cga_cluster_size` can be set
            for each communicator.

        distributed_timeout_minutes (int, default = 30): Timeout, in
            minutes, for operations executed against distributed
            process groups. See PyTorch documentation at
            https://pytorch.org/docs/stable/distributed.html for
            caveats.

        order (str, default=tp-cp-ep-dp-pp):
            The rank initialization order of parallelism. Now we support
            tp-dp-pp and tp-pp-dp orders.

        encoder_tensor_model_parallel_size (int, default = 0):
            The number of GPUs to split individual tensors across in the encoder. If 0,
            then we use the default, decoder's tensor model parallel size.

        encoder_pipeline_model_parallel_size (int, default = 0):
            The number of tensor parallel GPU groups to allocate to the encoder. As an example,
            if pipeline_model_parallel_size is 4 and encoder_pipeline_model_parallel_size is 2,
            then the encoder will use the first two pipeline stages for its layers, and the total
            amount of pipelining is 6.

        get_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group and returns
            those ranks that should have embeddings.

        get_position_embedding_ranks (Callable[[List[int], Optional[int]], List[int]], optional, default=None):
            A function that takes in a list of ranks for a pipeline group, and returns
            those ranks that should have position embeddings.

wangxj's avatar
wangxj committed
570
571
572
573
        create_gloo_process_groups (bool, default = True):
            Create Gloo process groups if set to True. If set to False, Gloo process groups are
            not created and calls to get Gloo process groups will result in assertion errors.

574
    Let's say we have a total of 16 GPUs denoted by g0 ... g15 and we
575
576
577
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 8 tensor model-parallel groups, 4 pipeline model-parallel groups
578
579
580
    and 8 data-parallel groups as:
        8 data_parallel groups:
            [g0, g2], [g1, g3], [g4, g6], [g5, g7], [g8, g10], [g9, g11], [g12, g14], [g13, g15]
581
        8 tensor model-parallel groups:
582
            [g0, g1], [g2, g3], [g4, g5], [g6, g7], [g8, g9], [g10, g11], [g12, g13], [g14, g15]
583
        4 pipeline model-parallel groups:
584
            [g0, g4, g8, g12], [g1, g5, g9, g13], [g2, g6, g10, g14], [g3, g7, g11, g15]
585
586
587
588
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
liangjing's avatar
v1  
liangjing committed
589

590
    """
xingjinliang's avatar
xingjinliang committed
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
    if encoder_pipeline_model_parallel_size is None:
        encoder_pipeline_model_parallel_size = 0

    if encoder_tensor_model_parallel_size == 0 and encoder_pipeline_model_parallel_size > 0:
        encoder_tensor_model_parallel_size = tensor_model_parallel_size

    if get_embedding_ranks is None:
        get_embedding_ranks = partial(
            default_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if get_position_embedding_ranks is None:
        get_position_embedding_ranks = partial(
            default_position_embedding_ranks, split_rank=pipeline_model_parallel_split_rank
        )

    if encoder_pipeline_model_parallel_size > 0:
        global _PIPELINE_MODEL_PARALLEL_DECODER_START
        _PIPELINE_MODEL_PARALLEL_DECODER_START = encoder_pipeline_model_parallel_size

611
612
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
613
614
    world_size: int = torch.distributed.get_world_size()

xingjinliang's avatar
xingjinliang committed
615
616
617
618
    if encoder_tensor_model_parallel_size > 0:
        assert (
            encoder_tensor_model_parallel_size <= tensor_model_parallel_size
        ), "We do not support encoders with more TP than the decoder."
619

xingjinliang's avatar
xingjinliang committed
620
621
622
623
    encoder_model_size = (
        encoder_tensor_model_parallel_size
        * encoder_pipeline_model_parallel_size
        * context_parallel_size
liangjing's avatar
v1  
liangjing committed
624
    )
xingjinliang's avatar
xingjinliang committed
625
626
627
628
629
630
631
632
633
    decoder_model_size = (
        tensor_model_parallel_size * pipeline_model_parallel_size * context_parallel_size
    )
    total_model_size = encoder_model_size + decoder_model_size

    if world_size % total_model_size != 0:
        raise RuntimeError(f"world_size ({world_size}) is not divisible by {total_model_size}")

    data_parallel_size: int = world_size // total_model_size
634

xingjinliang's avatar
xingjinliang committed
635
636
637
638
639
640
    encoder_world_size = encoder_model_size * data_parallel_size
    decoder_world_size = decoder_model_size * data_parallel_size

    assert (
        encoder_world_size + decoder_world_size == world_size
    ), f"{encoder_world_size=} + {decoder_world_size=} != {world_size=}"
641
642

    if virtual_pipeline_model_parallel_size is not None:
xingjinliang's avatar
xingjinliang committed
643
        if not pipeline_model_parallel_size > 1:
liangjing's avatar
v1  
liangjing committed
644
            raise RuntimeError(
xingjinliang's avatar
xingjinliang committed
645
                "pipeline-model-parallel size should be greater than 1 with interleaved schedule"
liangjing's avatar
v1  
liangjing committed
646
            )
647
648
649
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
        global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
650
        _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size
651

652
    if pipeline_model_parallel_split_rank is not None:
653
        global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
654
        _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = pipeline_model_parallel_split_rank
655

656
657
    rank = torch.distributed.get_rank()

xingjinliang's avatar
xingjinliang committed
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
    nccl_comm_cfgs = {}
    if nccl_communicator_config_path is not None:
        try:
            import yaml
        except ImportError:
            raise RuntimeError(
                "Cannot import `yaml`. Setting custom nccl communicator configs "
                "requires the yaml package."
            )

        with open(nccl_communicator_config_path, "r") as stream:
            nccl_comm_cfgs = yaml.safe_load(stream)

    if encoder_world_size > 0:
        encoder_rank_generator = RankGenerator(
            tp=encoder_tensor_model_parallel_size,
            ep=1,
            dp=data_parallel_size,
            pp=encoder_pipeline_model_parallel_size,
            cp=context_parallel_size,
            order=order,
            rank_offset=0,
        )
    else:
        encoder_rank_generator = None

    decoder_rank_generator = RankGenerator(
        tp=tensor_model_parallel_size,
        ep=1,
        dp=data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=context_parallel_size,
        order=order,
        rank_offset=encoder_world_size,
    )

    # Build expert rank generator
    if expert_tensor_parallel_size is None:
        expert_tensor_parallel_size = tensor_model_parallel_size
    expert_tensor_model_pipeline_parallel_size = (
        expert_tensor_parallel_size * expert_model_parallel_size * pipeline_model_parallel_size
    )
    expert_data_parallel_size = decoder_world_size // expert_tensor_model_pipeline_parallel_size
    if decoder_world_size % expert_tensor_model_pipeline_parallel_size != 0:
        raise RuntimeError(
            f"decoder world_size ({decoder_world_size}) is not divisible by expert_tensor_model_pipeline_parallel size ({expert_tensor_model_pipeline_parallel_size})"
        )

    # TODO: support expert specific ordering
    expert_decoder_rank_generator = RankGenerator(
        tp=expert_tensor_parallel_size,
        ep=expert_model_parallel_size,
        dp=expert_data_parallel_size,
        pp=pipeline_model_parallel_size,
        cp=1,
        order=order,
        rank_offset=encoder_world_size,
    )

wangxj's avatar
wangxj committed
717
718
719
720
721
722
    assert (
        order.endswith("pp")
        or pipeline_model_parallel_size == 1
        or expert_data_parallel_size == data_parallel_size
    ), "When not using pp-last rank ordering, the data parallel size of the attention and moe layers must be the same"

xingjinliang's avatar
xingjinliang committed
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
    assert decoder_rank_generator.get_ranks("pp") == expert_decoder_rank_generator.get_ranks(
        "pp"
    ), f"Pipeline parallel groups are expected to be the same for Non-Expert and Expert part, \
    but got {decoder_rank_generator.get_ranks('pp')} and {expert_decoder_rank_generator.get_ranks('pp')}"

    def generator_wrapper(group_type, is_expert=False, **kwargs):
        """Yield, one group at a time, the ranks of every `group_type` group.

        A `RankGenerator` produces a hyper-rectangle for a given set of tensor,
        pipeline, data, expert, and context parallelism. When an encoder is present
        in addition to the default decoder, two `RankGenerator` instances describe
        each module separately and their outputs are stitched together here. For
        now, only 'pp' and 'tp-pp' groups are merged across encoder and decoder;
        every other group type yields the encoder groups followed by the decoder
        groups, unmerged.

        Args:
            group_type: Group token string forwarded to `get_ranks` (e.g. 'dp', 'pp').
            is_expert: Use the expert decoder rank generator instead of the regular
                decoder rank generator.
            **kwargs: Forwarded unchanged to `get_ranks`.
        """
        decoder_source = expert_decoder_rank_generator if is_expert else decoder_rank_generator
        d_ranks = decoder_source.get_ranks(group_type, **kwargs)

        # Without an encoder, the decoder groups are the complete answer.
        if encoder_rank_generator is None:
            yield from d_ranks
            return

        e_ranks = encoder_rank_generator.get_ranks(group_type, **kwargs)
        if group_type == 'pp':
            # The encoder and decoder may produce a different number of groups, so
            # the encoder groups are cycled: one encoder group can be joined to
            # several decoder groups.
            yield from (enc + dec for enc, dec in zip(cycle(e_ranks), d_ranks))
        elif group_type == 'tp-pp':
            # Both sides must produce the same number of groups here, so they are
            # joined pairwise.
            assert len(e_ranks) == len(d_ranks)
            yield from (enc + dec for enc, dec in zip(e_ranks, d_ranks))
        else:
            yield from e_ranks
            yield from d_ranks

    timeout = timedelta(minutes=distributed_timeout_minutes)

763
    # Build the data-parallel groups.
764
    global _DATA_PARALLEL_GROUP
liangjing's avatar
v1  
liangjing committed
765
    global _DATA_PARALLEL_GROUP_GLOO
766
    global _DATA_PARALLEL_GLOBAL_RANKS
xingjinliang's avatar
xingjinliang committed
767
768
769
770
771
772
    global _DATA_PARALLEL_GROUP_WITH_CP
    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    global _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP
    global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
    global _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO
    global _INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
773
    assert _DATA_PARALLEL_GROUP is None, 'data parallel group is already initialized'
xingjinliang's avatar
xingjinliang committed
774
775

    for ranks in generator_wrapper('dp'):
wangxj's avatar
wangxj committed
776
777
778
779
780
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('dp', nccl_comm_cfgs),
            group_desc='DATA_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
781
        )
wangxj's avatar
wangxj committed
782
783
784
785
786
787
        if create_gloo_process_groups:
            group_gloo = create_group(
                ranks, timeout=timeout, backend="gloo", group_desc='DATA_PARALLEL_GROUP_GLOO'
            )
        else:
            group_gloo = None
xingjinliang's avatar
xingjinliang committed
788
789
790
791
792
793
        if rank in ranks:
            _DATA_PARALLEL_GROUP = group
            _DATA_PARALLEL_GROUP_GLOO = group_gloo
            _DATA_PARALLEL_GLOBAL_RANKS = ranks

    assert (
wangxj's avatar
wangxj committed
794
795
796
797
798
799
800
        data_parallel_size * context_parallel_size
    ) % num_distributed_optimizer_instances == 0, (
        'Data parallel size should be divisible by partial DistOpt shard factor'
    )
    intra_partial_data_parallel_size = (
        data_parallel_size * context_parallel_size
    ) // num_distributed_optimizer_instances
xingjinliang's avatar
xingjinliang committed
801
802

    for ranks_with_cp in generator_wrapper('dp-cp'):
wangxj's avatar
wangxj committed
803
804
805
806
807
        group_with_cp = create_group(
            ranks_with_cp,
            timeout=timeout,
            pg_options=get_nccl_options('dp_cp', nccl_comm_cfgs),
            group_desc='DATA_PARALLEL_GROUP_WITH_CP',
xingjinliang's avatar
xingjinliang committed
808
        )
wangxj's avatar
wangxj committed
809
810
811
812
813
814
815
816
817
        if create_gloo_process_groups:
            group_with_cp_gloo = create_group(
                ranks_with_cp,
                timeout=timeout,
                backend="gloo",
                group_desc='DATA_PARALLEL_GROUP_WITH_CP_GLOO',
            )
        else:
            group_with_cp_gloo = None
xingjinliang's avatar
xingjinliang committed
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
        if rank in ranks_with_cp:
            _DATA_PARALLEL_GROUP_WITH_CP = group_with_cp
            _DATA_PARALLEL_GROUP_WITH_CP_GLOO = group_with_cp_gloo
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP = ranks_with_cp

        if num_distributed_optimizer_instances > 1:
            # Create groups for Partial DistOpt, one for intra-partial DP domain
            # Another for inter-partial DP domain
            for i in range(num_distributed_optimizer_instances):
                intra_partial_data_parallel_ranks_with_cp = ranks_with_cp[
                    (i * intra_partial_data_parallel_size) : (
                        (i + 1) * intra_partial_data_parallel_size
                    )
                ]

wangxj's avatar
wangxj committed
833
                intra_partial_data_parallel_group_with_cp = create_group(
xingjinliang's avatar
xingjinliang committed
834
835
                    intra_partial_data_parallel_ranks_with_cp,
                    timeout=timeout,
wangxj's avatar
wangxj committed
836
837
                    pg_options=get_nccl_options('intra_dp_cp', nccl_comm_cfgs),
                    group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
xingjinliang's avatar
xingjinliang committed
838
                )
wangxj's avatar
wangxj committed
839
840
841
842
843
844
845
846
847
                if create_gloo_process_groups:
                    intra_partial_data_parallel_group_with_cp_gloo = create_group(
                        intra_partial_data_parallel_ranks_with_cp,
                        timeout=timeout,
                        backend="gloo",
                        group_desc='INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO',
                    )
                else:
                    intra_partial_data_parallel_group_with_cp_gloo = None
xingjinliang's avatar
xingjinliang committed
848
849
850
851
852
853
854
855
856
857
858
859
860
861

                if rank in intra_partial_data_parallel_ranks_with_cp:
                    _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
                        intra_partial_data_parallel_group_with_cp
                    )
                    _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = (
                        intra_partial_data_parallel_group_with_cp_gloo
                    )

            for i in range(intra_partial_data_parallel_size):
                inter_partial_data_parallel_ranks_with_cp = ranks_with_cp[
                    i::intra_partial_data_parallel_size
                ]

wangxj's avatar
wangxj committed
862
                inter_partial_data_parallel_group_with_cp = create_group(
xingjinliang's avatar
xingjinliang committed
863
864
                    inter_partial_data_parallel_ranks_with_cp,
                    timeout=timeout,
wangxj's avatar
wangxj committed
865
866
                    pg_options=get_nccl_options('inter_dp_cp', nccl_comm_cfgs),
                    group_desc='INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP',
xingjinliang's avatar
xingjinliang committed
867
868
869
870
871
872
873
874
875
                )

                if rank in inter_partial_data_parallel_ranks_with_cp:
                    _INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = (
                        inter_partial_data_parallel_group_with_cp
                    )
        else:
            _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP = _DATA_PARALLEL_GROUP_WITH_CP
            _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO = _DATA_PARALLEL_GROUP_WITH_CP_GLOO
876

liangjing's avatar
v1  
liangjing committed
877
878
879
880
881
882
883
884
885
886
887
888
889
890
    # Apply SHARP to DP process groups
    if use_sharp:
        if rank == 0:
            print(
                "The number of process groups to use SHARP with depends on the type "
                "of the network switch. Nvidia QM1 switch supports SAHRP up to 8 "
                "process groups and QM2 supports up to 256 process groups. We apply "
                "SHARP to the communications of the data-parallel domain. If the "
                "number of data-parallel process groups is larger than the max "
                "process groups that the network switch supports, the communication "
                "will fall back to non-SHARP operators. To enable SHARP, "
                "`#SBATCH_NETWORK=sharp` should be set in the sbatch script."
            )
        torch.distributed.barrier(
xingjinliang's avatar
xingjinliang committed
891
892
893
894
895
896
897
898
899
900
901
            group=get_data_parallel_group(with_context_parallel=True),
            device_ids=[torch.cuda.current_device()],
        )
        # Set `NCCL_COLLNET_ENABLE=0` to restrict SHARP application to DP process groups
        os.environ["NCCL_COLLNET_ENABLE"] = "0"

    # Build the context-parallel groups.
    global _CONTEXT_PARALLEL_GROUP
    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    assert _CONTEXT_PARALLEL_GROUP is None, 'context parallel group is already initialized'
    for ranks in generator_wrapper('cp'):
wangxj's avatar
wangxj committed
902
903
904
905
906
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('cp', nccl_comm_cfgs),
            group_desc='CONTEXT_PARALLEL_GROUP',
liangjing's avatar
v1  
liangjing committed
907
        )
xingjinliang's avatar
xingjinliang committed
908
909
910
911
912
913
914
915
916
917
        if rank in ranks:
            _CONTEXT_PARALLEL_GROUP = group
            _CONTEXT_PARALLEL_GLOBAL_RANKS = ranks
        if hierarchical_context_parallel_sizes:
            global _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS
            _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS += create_hierarchical_parallel_groups(
                rank,
                ranks,
                context_parallel_size,
                hierarchical_context_parallel_sizes,
wangxj's avatar
wangxj committed
918
                get_nccl_options('hcp', nccl_comm_cfgs),
xingjinliang's avatar
xingjinliang committed
919
            )
liangjing's avatar
v1  
liangjing committed
920

921
    # Build the model-parallel groups.
922
    global _MODEL_PARALLEL_GROUP
xingjinliang's avatar
xingjinliang committed
923
    global _MODEL_PARALLEL_GLOBAL_RANKS
924
    assert _MODEL_PARALLEL_GROUP is None, 'model parallel group is already initialized'
xingjinliang's avatar
xingjinliang committed
925
    for ranks in generator_wrapper('tp-pp'):
wangxj's avatar
wangxj committed
926
927
928
929
930
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('mp', nccl_comm_cfgs),
            group_desc='MODEL_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
931
        )
932
        if rank in ranks:
933
            _MODEL_PARALLEL_GROUP = group
xingjinliang's avatar
xingjinliang committed
934
            _MODEL_PARALLEL_GLOBAL_RANKS = ranks
935

936
937
    # Build the tensor model-parallel groups.
    global _TENSOR_MODEL_PARALLEL_GROUP
xingjinliang's avatar
xingjinliang committed
938
    global _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
939
940
941
    assert (
        _TENSOR_MODEL_PARALLEL_GROUP is None
    ), 'tensor model parallel group is already initialized'
xingjinliang's avatar
xingjinliang committed
942
    for ranks in generator_wrapper('tp'):
wangxj's avatar
wangxj committed
943
944
945
946
947
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp', nccl_comm_cfgs),
            group_desc='TENSOR_MODEL_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
948
        )
949
        if rank in ranks:
950
            _TENSOR_MODEL_PARALLEL_GROUP = group
xingjinliang's avatar
xingjinliang committed
951
            _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS = ranks
952

953
954
955
    # Build the pipeline model-parallel groups and embedding groups
    # (first and last rank in each pipeline model-parallel group).
    global _PIPELINE_MODEL_PARALLEL_GROUP
956
    global _PIPELINE_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
957
958
959
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is None
    ), 'pipeline model parallel group is already initialized'
960
    global _EMBEDDING_GROUP
961
    global _EMBEDDING_GLOBAL_RANKS
962
    assert _EMBEDDING_GROUP is None, 'embedding group is already initialized'
Vijay Korthikanti's avatar
Vijay Korthikanti committed
963
964
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
liangjing's avatar
v1  
liangjing committed
965
    assert _POSITION_EMBEDDING_GROUP is None, 'position embedding group is already initialized'
wangxj's avatar
wangxj committed
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
    if pipeline_model_parallel_comm_backend == 'ucc':
        # The UCC backend provides two key benefits:
        # 1) Achieves better bandwidth utilization than NCCL when using InfiniBand links.
        # 2) Does not use GPU SM resources (Zero-SM), mitigating performance interference
        #    with overlapping compute kernels.

        # The UCC backend is recommended in the following cases:
        # 1) When the exposed pipeline-parallel (PP) communications are significant.
        #    - E.g., Pipeline parallelism with very few gradient accumulation steps.
        #    - It may provide better performance due to improved bandwidth utilization.
        # 2) When the critical-path pipeline stage has substantial PP-communication overlap.
        #    - E.g., Uneven pipeline parallelism.
        #    - It may provide better performance due to zero SM resource usage.
        if 'CUDA_DEVICE_MAX_CONNECTIONS' in os.environ:
            # UCC backend requires CUDA_DEVICE_MAX_CONNECTIONS variable to be larger than 1,
            # to guarantee the overlapped UCC communications. If this environment variable is set to 1,
            # all the UCC communication will be serialized.
            assert (
                os.environ['CUDA_DEVICE_MAX_CONNECTIONS'] != '1'
            ), "UCC-backend requires CUDA_DEVICE_MAX_CONNECTIONS > 1"

        # Setting up required environment variables for ucc backend
        #
        # "TORCH_UCC_BLOCKING_WAIT=none" allows non-blocking waits of the communication handle
        # "UCC_EC_CUDA_STREAM_TASK_MODE" controls how CUDA execution engines (EC)
        # schedule tasks on CUDA streams.
        # "UCX_TLS" controls transport layer selection
        # "NSYS_UCP_COMM_PARAMS=1" enables capturing ucx tracing in nsys profiling
        # "UCX_RNDV_THRESH" controls the threshold for switching between
        # eager and rendezvous (RNDV) communication protocols.
        # "UCX_NET_DEVICES" selects which network interfaces UCX should use.
        # "UCC_CL_BASIC_TLS" controls which Transport Layers are used by
        # the Basic Collective library

        os.environ['TORCH_UCC_BLOCKING_WAIT'] = (
            os.environ['TORCH_UCC_BLOCKING_WAIT']
            if "TORCH_UCC_BLOCKING_WAIT" in os.environ
            else 'none'
        )
        os.environ['UCC_EC_CUDA_STREAM_TASK_MODE'] = (
            os.environ['UCC_EC_CUDA_STREAM_TASK_MODE']
            if "UCC_EC_CUDA_STREAM_TASK_MODE" in os.environ
            else 'driver'
        )
        os.environ['UCX_TLS'] = (
            os.environ['UCX_TLS'] if "UCX_TLS" in os.environ else 'ib,cuda_copy'
        )  # cuda_ipc (i.e., NVLink-enablement) will be later supported
        os.environ['NSYS_UCP_COMM_PARAMS'] = '1'
        os.environ['UCX_RNDV_THRESH'] = '0'
        os.environ['UCX_NET_DEVICES'] = 'all'
        os.environ['UCC_CL_BASIC_TLS'] = '^sharp,nccl'

xingjinliang's avatar
xingjinliang committed
1018
    for ranks in generator_wrapper('pp'):
wangxj's avatar
wangxj committed
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
        group = create_group(
            ranks,
            timeout=timeout,
            backend=pipeline_model_parallel_comm_backend,
            pg_options=(
                None
                if pipeline_model_parallel_comm_backend == 'ucc'
                else get_nccl_options('pp', nccl_comm_cfgs)
            ),
            group_desc='PIPELINE_MODEL_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1029
        )
wangxj's avatar
wangxj committed
1030
1031
1032
1033
1034
1035
        assert (
            pipeline_model_parallel_comm_backend == None
            or pipeline_model_parallel_comm_backend == 'nccl'
            or pipeline_model_parallel_comm_backend == 'ucc'
        ), f'"{pipeline_model_parallel_comm_backend}" backend for PP communication is currently not supported'

1036
        if rank in ranks:
xingjinliang's avatar
xingjinliang committed
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
            if _PIPELINE_MODEL_PARALLEL_GROUP is None:
                _PIPELINE_MODEL_PARALLEL_GROUP = group
                _PIPELINE_GLOBAL_RANKS = ranks
            elif isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
                _PIPELINE_MODEL_PARALLEL_GROUP.append(group)
                _PIPELINE_GLOBAL_RANKS.append(ranks)
            else:
                _PIPELINE_MODEL_PARALLEL_GROUP = [_PIPELINE_MODEL_PARALLEL_GROUP, group]
                _PIPELINE_GLOBAL_RANKS = [_PIPELINE_GLOBAL_RANKS, ranks]

        embedding_ranks = get_embedding_ranks(ranks)
wangxj's avatar
wangxj committed
1048
1049
1050
1051
1052
        group = create_group(
            embedding_ranks,
            timeout=timeout,
            pg_options=get_nccl_options('embd', nccl_comm_cfgs),
            group_desc='EMBEDDING_GROUP',
xingjinliang's avatar
xingjinliang committed
1053
        )
1054
1055
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
1056
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks
1057

xingjinliang's avatar
xingjinliang committed
1058
        position_embedding_ranks = get_position_embedding_ranks(ranks)
wangxj's avatar
wangxj committed
1059
        group = create_group(
xingjinliang's avatar
xingjinliang committed
1060
1061
            position_embedding_ranks,
            timeout=timeout,
wangxj's avatar
wangxj committed
1062
1063
            pg_options=get_nccl_options('pos_embd', nccl_comm_cfgs),
            group_desc='POSITION_EMBEDDING_GROUP',
xingjinliang's avatar
xingjinliang committed
1064
        )
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1065
1066
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1067
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1068

xingjinliang's avatar
xingjinliang committed
1069
1070
1071
1072
1073
1074
1075
    # Build the tensor + data parallel groups.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    assert (
        _TENSOR_AND_DATA_PARALLEL_GROUP is None
    ), 'Tensor + data parallel group is already initialized'
    for ranks in generator_wrapper('tp-dp-cp'):
wangxj's avatar
wangxj committed
1076
1077
1078
1079
1080
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_dp_cp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP',
xingjinliang's avatar
xingjinliang committed
1081
1082
1083
1084
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = group
    for ranks in generator_wrapper('tp-dp'):
wangxj's avatar
wangxj committed
1085
1086
1087
1088
1089
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_dp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_DATA_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1090
1091
1092
1093
1094
1095
1096
1097
1098
        )
        if rank in ranks:
            _TENSOR_AND_DATA_PARALLEL_GROUP = group

    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    assert (
        _TENSOR_AND_CONTEXT_PARALLEL_GROUP is None
    ), 'Tensor + context parallel group is already initialized'
    for ranks in generator_wrapper('tp-cp'):
wangxj's avatar
wangxj committed
1099
1100
1101
1102
1103
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_cp', nccl_comm_cfgs),
            group_desc='TENSOR_AND_CONTEXT_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1104
1105
1106
1107
1108
1109
1110
1111
1112
        )
        if rank in ranks:
            _TENSOR_AND_CONTEXT_PARALLEL_GROUP = group

    ### Expert-related parallel groups initialization
    # Build the expert model parallel group
    global _EXPERT_MODEL_PARALLEL_GROUP
    assert _EXPERT_MODEL_PARALLEL_GROUP is None, 'Expert parallel group is already initialized'
    for ranks in generator_wrapper('ep', is_expert=True):
wangxj's avatar
wangxj committed
1113
1114
1115
1116
        group = create_group(
            ranks,
            pg_options=get_nccl_options('ep', nccl_comm_cfgs),
            group_desc='EXPERT_MODEL_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
        )
        if rank in ranks:
            _EXPERT_MODEL_PARALLEL_GROUP = group

    # Build the expert tensor parallel group
    global _EXPERT_TENSOR_PARALLEL_GROUP
    assert (
        _EXPERT_TENSOR_PARALLEL_GROUP is None
    ), 'Expert tensor model parallel group is already initialized'
    for ranks in generator_wrapper('tp', is_expert=True):
wangxj's avatar
wangxj committed
1127
1128
1129
1130
1131
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('ep_tp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
        )
        if rank in ranks:
            _EXPERT_TENSOR_PARALLEL_GROUP = group

    # Build the tensor + expert parallel groups
    global _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP
    assert (
        _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP is None
    ), 'Expert tensor + model parallel group is already initialized'
    for ranks in generator_wrapper('tp-ep', is_expert=True):
wangxj's avatar
wangxj committed
1142
1143
1144
1145
1146
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_ep_mp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
        )
        if rank in ranks:
            _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = group

    # Build the expert+tensor+pipeline parallel groups
    global _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
    assert (
        _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP is None
    ), 'The expert_tensor_model_pipeline parallel group is already initialized'
    for ranks in generator_wrapper('tp-ep-pp', is_expert=True):
wangxj's avatar
wangxj committed
1157
1158
1159
1160
1161
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('tp_ep_pp', nccl_comm_cfgs),
            group_desc='EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
        )
        if rank in ranks:
            _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = group

    # Build the expert data parallel group
    global _EXPERT_DATA_PARALLEL_GROUP
    assert _EXPERT_DATA_PARALLEL_GROUP is None, 'Expert data group is already initialized'
    global _EXPERT_DATA_PARALLEL_GROUP_GLOO
    assert _EXPERT_DATA_PARALLEL_GROUP_GLOO is None, 'Expert data group-gloo is already initialized'

    for ranks in generator_wrapper('dp', is_expert=True):
wangxj's avatar
wangxj committed
1173
1174
1175
1176
1177
        group = create_group(
            ranks,
            timeout=timeout,
            pg_options=get_nccl_options('ep_dp', nccl_comm_cfgs),
            group_desc='EXPERT_DATA_PARALLEL_GROUP',
xingjinliang's avatar
xingjinliang committed
1178
        )
wangxj's avatar
wangxj committed
1179
1180
1181
1182
1183
1184
        if create_gloo_process_groups:
            group_gloo = create_group(
                ranks, backend="gloo", group_desc='EXPERT_DATA_PARALLEL_GROUP_GLOO'
            )
        else:
            group_gloo = None
xingjinliang's avatar
xingjinliang committed
1185
1186
1187
1188
        if rank in ranks:
            _EXPERT_DATA_PARALLEL_GROUP = group
            _EXPERT_DATA_PARALLEL_GROUP_GLOO = group_gloo
    ### End of expert related parallel groups initialization
liangjing's avatar
v1  
liangjing committed
1189

1190
1191
1192
1193
1194
1195
    # Initialize global memory buffer
    # This isn't really "parallel state" but there isn't another good place to
    # put this. If we end up with a more generic initialization of megatron-core
    # we could stick it there
    _set_global_memory_buffer()

1196

xingjinliang's avatar
xingjinliang committed
1197
def is_initialized() -> bool:
    """Useful for code segments that may be accessed with or without mpu initialization.

    Returns:
        True if the module-level data-parallel group has been set (is not None),
        False otherwise.
    """
    return _DATA_PARALLEL_GROUP is not None


def is_unitialized() -> bool:
    """Report whether parallel state is still uninitialized.

    Deprecated. Use is_initialized instead.

    Emits a DeprecationWarning on every call and returns the negation of
    `is_initialized()`.
    """
    deprecation_message = "is_unitialized is deprecated, use is_initialized instead"
    warnings.warn(deprecation_message, DeprecationWarning)
    initialized = is_initialized()
    return not initialized
Abhinav Khattar's avatar
Abhinav Khattar committed
1210
1211


1212
def model_parallel_is_initialized() -> bool:
    """Check if model- and data-parallel groups are initialized.

    Returns:
        True only when the tensor-model-parallel, pipeline-model-parallel and
        data-parallel groups are all set (none of them is None); False otherwise.
    """
    return (
        _TENSOR_MODEL_PARALLEL_GROUP is not None
        and _PIPELINE_MODEL_PARALLEL_GROUP is not None
        and _DATA_PARALLEL_GROUP is not None
    )


def get_model_parallel_group():
    """Get the model-parallel group the caller rank belongs to.

    Returns:
        The module-level model-parallel group.

    Raises:
        AssertionError: If the model-parallel group has not been set yet.
    """
    assert _MODEL_PARALLEL_GROUP is not None, 'model parallel group is not initialized'
    return _MODEL_PARALLEL_GROUP


liangjing's avatar
v1  
liangjing committed
1229
def get_tensor_model_parallel_group(check_initialized=True):
    """Get the tensor-model-parallel group the caller rank belongs to.

    Args:
        check_initialized: When truthy (the default), assert that the group has
            been set before returning it. When falsy, the check is skipped and
            the return value may be None.

    Returns:
        The module-level tensor-model-parallel group.

    Raises:
        AssertionError: If `check_initialized` is truthy and the group is None.
    """
    if check_initialized:
        assert (
            _TENSOR_MODEL_PARALLEL_GROUP is not None
        ), 'tensor model parallel group is not initialized'
    return _TENSOR_MODEL_PARALLEL_GROUP
1236
1237


1238
def get_pipeline_model_parallel_group():
    """Get the pipeline-model-parallel group the caller rank belongs to.

    Raises:
        AssertionError: If the pipeline-model-parallel group has not been set.
    """
    assert (
        _PIPELINE_MODEL_PARALLEL_GROUP is not None
    ), 'pipeline_model parallel group is not initialized'
    return _PIPELINE_MODEL_PARALLEL_GROUP
1244
1245


xingjinliang's avatar
xingjinliang committed
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
def get_data_parallel_group(with_context_parallel=False, partial_data_parallel=False):
    """Get the data-parallel group the caller rank belongs to.

    Args:
        with_context_parallel: Select the variant of the group that is combined
            with context parallelism.
        partial_data_parallel: Select the intra partial data-parallel group.
            Only accepted together with ``with_context_parallel``.

    Raises:
        AssertionError: If the selected group is not initialized, or if
            ``partial_data_parallel`` is requested without context parallel.
    """
    if not with_context_parallel:
        assert _DATA_PARALLEL_GROUP is not None, 'data parallel group is not initialized'
        assert partial_data_parallel == False, 'Partial DP for Optimizer needs to include CP'
        return _DATA_PARALLEL_GROUP

    if partial_data_parallel:
        assert (
            _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP is not None
        ), 'Intra partial data parallel group is not initialized'
        return _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP

    assert (
        _DATA_PARALLEL_GROUP_WITH_CP is not None
    ), 'data parallel group with context parallel combined is not initialized'
    return _DATA_PARALLEL_GROUP_WITH_CP


def get_data_parallel_group_gloo(with_context_parallel=False, partial_data_parallel=False):
    """Get the Gloo data-parallel group the caller rank belongs to.

    Args:
        with_context_parallel: Select the variant of the group that is combined
            with context parallelism.
        partial_data_parallel: Select the intra partial data-parallel Gloo group.
            Only accepted together with ``with_context_parallel``.

    Raises:
        AssertionError: If the selected group is not initialized, or if
            ``partial_data_parallel`` is requested without context parallel.
    """
    if not with_context_parallel:
        assert _DATA_PARALLEL_GROUP_GLOO is not None, 'data parallel group-gloo is not initialized'
        assert partial_data_parallel == False, 'Partial DP for Optimizer needs to include CP'
        return _DATA_PARALLEL_GROUP_GLOO

    if partial_data_parallel:
        assert (
            _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
        ), 'Intra partial data parallel group is not initialized'
        return _INTRA_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP_GLOO

    assert (
        _DATA_PARALLEL_GROUP_WITH_CP_GLOO is not None
    ), 'data parallel group-gloo with context parallel combined is not initialized'
    return _DATA_PARALLEL_GROUP_WITH_CP_GLOO

1281

xingjinliang's avatar
xingjinliang committed
1282
1283
1284
1285
1286
1287
def get_inter_partial_data_parallel_group():
    """Get the group spanning the different partial data-parallel groups.

    Raises:
        AssertionError: If the inter partial data-parallel group is not set.
    """
    group = _INTER_PARTIAL_DATA_PARALLEL_GROUP_WITH_CP
    assert group is not None, 'Inter partial data parallel group is not initialized'
    return group
1288

xingjinliang's avatar
xingjinliang committed
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310

def get_context_parallel_group(check_initialized=True):
    """Get the context-parallel group the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the group has been set;
            otherwise the stored value (possibly None) is returned unchecked.
    """
    group = _CONTEXT_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'context parallel group is not initialized'
    return group


def get_context_parallel_global_ranks(check_initialized=True):
    """Get all global ranks of the context-parallel group that the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the ranks have been set;
            otherwise the stored value (possibly None) is returned unchecked.
    """
    ranks = _CONTEXT_PARALLEL_GLOBAL_RANKS
    if check_initialized:
        assert ranks is not None, 'context parallel group is not initialized'
    return ranks


def get_hierarchical_context_parallel_groups(check_initialized=True):
    """Get the hierarchical context-parallel groups the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the groups have been set;
            otherwise the stored value (possibly None) is returned unchecked.

    Raises:
        AssertionError: If ``check_initialized`` is truthy and the groups are None.
    """
    if check_initialized:
        # Carry a message like the sibling getters so a failure is self-explanatory.
        assert (
            _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS is not None
        ), 'hierarchical context parallel groups are not initialized'
    return _HIERARCHICAL_CONTEXT_PARALLEL_GROUPS
liangjing's avatar
v1  
liangjing committed
1311
1312


1313
1314
def get_embedding_group():
    """Get the embedding group the caller rank belongs to.

    Raises:
        AssertionError: If the embedding group has not been set.
    """
    assert _EMBEDDING_GROUP is not None, 'embedding group is not initialized'
    return _EMBEDDING_GROUP


Vijay Korthikanti's avatar
Vijay Korthikanti committed
1319
1320
def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to.

    Raises:
        AssertionError: If the position embedding group has not been set.
    """
    assert _POSITION_EMBEDDING_GROUP is not None, 'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


xingjinliang's avatar
xingjinliang committed
1325
def get_amax_reduction_group(with_context_parallel=False, tp_only_amax_red=False):
    """Get the FP8 amax reduction group the caller rank belongs to.

    The group is chosen from the two flags:

    * context parallel, not TP-only: tensor-and-data-parallel group with CP
    * context parallel, TP-only:     tensor-and-context-parallel group
    * no context parallel, not TP-only: tensor-and-data-parallel group
    * no context parallel, TP-only:     tensor-model-parallel group

    Args:
        with_context_parallel: Select a group that also covers context parallelism.
        tp_only_amax_red: Select the group without the data-parallel dimension.

    Raises:
        AssertionError: If the selected group has not been set.
    """
    if with_context_parallel:
        if tp_only_amax_red:
            group = _TENSOR_AND_CONTEXT_PARALLEL_GROUP
        else:
            group = _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    elif tp_only_amax_red:
        group = _TENSOR_MODEL_PARALLEL_GROUP
    else:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP
    assert group is not None, 'FP8 amax reduction group is not initialized'
    return group


def get_tensor_and_data_parallel_group(with_context_parallel=False):
    """Get the tensor- and data-parallel group the caller rank belongs to.

    Args:
        with_context_parallel: Select the variant of the group that is combined
            with context parallelism.

    Raises:
        AssertionError: If the selected group has not been set.
    """
    if with_context_parallel:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    else:
        group = _TENSOR_AND_DATA_PARALLEL_GROUP
    assert group is not None, 'tensor and data parallel group is not initialized'
    return group


def get_tensor_and_context_parallel_group():
    """Get the tensor- and context-parallel group the caller rank belongs to.

    Raises:
        AssertionError: If the group has not been set.
    """
    group = _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    assert group is not None, 'tensor and context parallel group is not initialized'
    return group
liangjing's avatar
v1  
liangjing committed
1371
1372


1373
def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor-model-parallel size.

    Args:
        world_size: Value stored in ``_MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE``.
            While it is not None, ``get_tensor_model_parallel_world_size``
            returns it instead of querying the process group.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = world_size
1377
1378


1379
def set_pipeline_model_parallel_world_size(world_size):
    """Set the pipeline-model-parallel size.

    Args:
        world_size: Value stored in ``_MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``.
            While it is not None, ``get_pipeline_model_parallel_world_size``
            returns it instead of querying the process group.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
1383
1384


liangjing's avatar
v1  
liangjing committed
1385
def set_virtual_pipeline_model_parallel_world_size(world_size):
    """Set the virtual pipeline-model-parallel size.

    Args:
        world_size: Value stored in ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE``,
            which ``get_virtual_pipeline_model_parallel_world_size`` returns.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = world_size


1391
def get_tensor_model_parallel_world_size():
    """Return world size for the tensor-model-parallel group.

    An override installed through ``set_tensor_model_parallel_world_size``
    takes precedence; otherwise the size is read from the process group.
    """
    if _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    return torch.distributed.get_world_size(group=get_tensor_model_parallel_group())
1397
1398


1399
def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline-model-parallel group.

    An override installed through ``set_pipeline_model_parallel_world_size``
    takes precedence. Otherwise the size is read from the process group; when
    the caller belongs to a list of pipeline groups, all of them must have the
    same length and the size of the first one is returned.

    Raises:
        AssertionError: If the pipeline groups do not all have the same size.
    """
    if _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE

    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # Implicit assumption that each PP group is the same size.
        sizes = [len(group) for group in _PIPELINE_GLOBAL_RANKS]
        assert all(x == sizes[0] for x in sizes)
        return torch.distributed.get_world_size(group=pp_group[0])
    return torch.distributed.get_world_size(group=pp_group)
Mohammad Shoeybi's avatar
Mohammad Shoeybi committed
1415
1416


1417
def set_tensor_model_parallel_rank(rank):
    """Set tensor-model-parallel rank.

    Args:
        rank: Value stored in ``_MPU_TENSOR_MODEL_PARALLEL_RANK``. While it is
            not None, ``get_tensor_model_parallel_rank`` returns it instead of
            querying the process group.
    """
    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = rank
1421
1422


1423
def set_pipeline_model_parallel_rank(rank):
    """Set pipeline-model-parallel rank.

    Args:
        rank: Value stored in ``_MPU_PIPELINE_MODEL_PARALLEL_RANK``. While it is
            not None, ``get_pipeline_model_parallel_rank`` returns it instead of
            querying the process group.
    """
    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = rank
1427
1428


1429
def set_pipeline_model_parallel_split_rank(rank):
    """Set pipeline-model-parallel split rank. DEPRECATED.

    Args:
        rank: Value stored in ``_PIPELINE_MODEL_PARALLEL_SPLIT_RANK``.
    """
    global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    _PIPELINE_MODEL_PARALLEL_SPLIT_RANK = rank
1433
1434


1435
def get_tensor_model_parallel_rank():
    """Return caller's rank for the tensor-model-parallel group.

    An override installed through ``set_tensor_model_parallel_rank`` takes
    precedence; otherwise the rank is read from the process group.
    """
    if _MPU_TENSOR_MODEL_PARALLEL_RANK is not None:
        return _MPU_TENSOR_MODEL_PARALLEL_RANK
    return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
1441
1442


1443
def get_pipeline_model_parallel_rank():
    """Return caller's rank for the pipeline-model-parallel group.

    An override installed through ``set_pipeline_model_parallel_rank`` takes
    precedence. Otherwise the rank is read from the process group; when the
    caller belongs to a list of pipeline groups, its position must be the same
    in every group that contains it and the rank in the first group is returned.

    Raises:
        AssertionError: If the caller sits at different positions in different
            pipeline groups.
    """
    if _MPU_PIPELINE_MODEL_PARALLEL_RANK is not None:
        return _MPU_PIPELINE_MODEL_PARALLEL_RANK

    rank = torch.distributed.get_rank()
    pp_group = get_pipeline_model_parallel_group()
    if isinstance(pp_group, list):
        # Assume that if the caller exists in multiple PP groups, then it has the same index.
        indices = []
        for group in _PIPELINE_GLOBAL_RANKS:
            for i, r in enumerate(group):
                if r == rank:
                    indices.append(i)
        assert all(x == indices[0] for x in indices)
        return torch.distributed.get_rank(group=pp_group[0])
    return torch.distributed.get_rank(group=pp_group)
1461
1462


1463
def get_pipeline_model_parallel_split_rank():
    """Return pipeline-model-parallel split rank.

    This is the value last stored by ``set_pipeline_model_parallel_split_rank``.
    """
    return _PIPELINE_MODEL_PARALLEL_SPLIT_RANK

1468

1469
def is_pipeline_first_stage(ignore_virtual=False):
    """Return True if in the first pipeline model-parallel stage, False otherwise.

    Args:
        ignore_virtual: When falsy and a virtual pipeline world size is set,
            the caller must additionally be at virtual pipeline rank 0.
    """
    if not ignore_virtual:
        if (
            get_virtual_pipeline_model_parallel_world_size() is not None
            and get_virtual_pipeline_model_parallel_rank() != 0
        ):
            return False
    return get_pipeline_model_parallel_rank() == 0
1478
1479


1480
def is_pipeline_last_stage(ignore_virtual=False):
    """Return True if in the last pipeline-model-parallel stage, False otherwise.

    Args:
        ignore_virtual: When falsy and a virtual pipeline world size is set,
            the caller must additionally be at the last virtual pipeline rank.
    """
    if not ignore_virtual:
        virtual_pipeline_model_parallel_world_size = (
            get_virtual_pipeline_model_parallel_world_size()
        )
        if (
            virtual_pipeline_model_parallel_world_size is not None
            and get_virtual_pipeline_model_parallel_rank()
            != (virtual_pipeline_model_parallel_world_size - 1)
        ):
            return False
    return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
1493
1494


1495
1496
1497
1498
def is_rank_in_embedding_group(ignore_virtual=False):
    """Return true if current rank is in embedding group, False otherwise.

    Args:
        ignore_virtual: When truthy, plain membership in the embedding ranks
            decides the result. When falsy, the first listed rank only counts
            if it is the first pipeline stage and the last listed rank only if
            it is the last pipeline stage (virtual stages taken into account);
            any other listed rank counts unconditionally.
    """
    rank = torch.distributed.get_rank()
    if _EMBEDDING_GLOBAL_RANKS is None:
        return False
    if rank not in _EMBEDDING_GLOBAL_RANKS:
        return False
    if ignore_virtual:
        return True
    # The first-rank check wins when the list has a single entry.
    if rank == _EMBEDDING_GLOBAL_RANKS[0]:
        return is_pipeline_first_stage(ignore_virtual=False)
    if rank == _EMBEDDING_GLOBAL_RANKS[-1]:
        return is_pipeline_last_stage(ignore_virtual=False)
    return True


Vijay Korthikanti's avatar
Vijay Korthikanti committed
1513
1514
1515
1516
def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise.

    False is also returned when the position embedding ranks have not been set.
    """
    rank = torch.distributed.get_rank()
    return _POSITION_EMBEDDING_GLOBAL_RANKS is not None and rank in _POSITION_EMBEDDING_GLOBAL_RANKS
Vijay Korthikanti's avatar
Vijay Korthikanti committed
1518
1519


1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder.

    Args:
        rank: Pipeline rank to test; defaults to the caller's pipeline rank.

    True is returned unconditionally when the pipeline world size is 1 or no
    split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None or rank < split_rank:
        return True
    return False


def is_pipeline_stage_after_split(rank=None):
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    Args:
        rank: Pipeline rank to test; defaults to the caller's pipeline rank.

    True is returned unconditionally when the pipeline world size is 1 or no
    split rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    split_rank = _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
    if split_rank is None or rank >= split_rank:
        return True
    return False


xingjinliang's avatar
xingjinliang committed
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
def is_inside_encoder(rank=None) -> bool:
    """Return True if pipeline stage executes encoder block.

    This function implicitly assumes we have a model with both
    encoder and decoder.

    Args:
        rank: Pipeline rank to test; defaults to the caller's pipeline rank.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    decoder_start = _PIPELINE_MODEL_PARALLEL_DECODER_START
    if decoder_start is None:
        # No decoder start recorded: the encoder shares the first pipeline
        # rank with the decoder.
        if rank == 0:
            return True
        return False
    # A decoder start is recorded: the encoder is on its own pipeline ranks
    # before the decoder.
    if rank < decoder_start:
        return True
    return False


def is_inside_decoder(rank=None) -> bool:
    """Return True if pipeline stage executes decoder block for a model
    with both encoder and decoder.

    Args:
        rank: Pipeline rank to test; defaults to the caller's pipeline rank.

    True is returned unconditionally when the pipeline world size is 1 or no
    decoder start rank is set.
    """
    if get_pipeline_model_parallel_world_size() == 1:
        return True
    if rank is None:
        rank = get_pipeline_model_parallel_rank()
    decoder_start = _PIPELINE_MODEL_PARALLEL_DECODER_START
    if decoder_start is None or rank >= decoder_start:
        return True
    return False


def get_pipeline_model_parallel_decoder_start() -> Optional[int]:
    """Return decoder start rank (if encoder pipeline parallelism is set).

    Returns:
        The stored ``_PIPELINE_MODEL_PARALLEL_DECODER_START``; None when no
        decoder start rank has been recorded.
    """
    return _PIPELINE_MODEL_PARALLEL_DECODER_START


1594
1595
1596
1597
1598
def is_pipeline_stage_at_split():
    """Return true if this pipeline stage is before the split (executes
    encoder block) and the next stage is after the split (executes decoder
    block), for a model with both encoder and decoder."""
    rank = get_pipeline_model_parallel_rank()
    return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
1600
1601


1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
def get_virtual_pipeline_model_parallel_rank():
    """Return the virtual pipeline-parallel rank.

    This is the value last stored by ``set_virtual_pipeline_model_parallel_rank``.
    """
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK


def set_virtual_pipeline_model_parallel_rank(rank):
    """Set the virtual pipeline-parallel rank.

    Args:
        rank: Value stored in the module-level
            ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK``; it is what
            ``get_virtual_pipeline_model_parallel_rank`` subsequently returns.
    """
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = rank


1614
1615
1616
1617
1618
1619
def get_virtual_pipeline_model_parallel_world_size():
    """Return the virtual pipeline-parallel world size.

    This is the stored ``_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE`` value.
    """
    return _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE


1620
def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group.

    Raises:
        AssertionError: If the tensor-model-parallel global ranks are not set.
    """
    assert (
        _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS is not None
    ), "Tensor model parallel group is not initialized"
    return _TENSOR_MODEL_PARALLEL_GLOBAL_RANKS[0]
1627

1628

xingjinliang's avatar
xingjinliang committed
1629
1630
1631
1632
1633
1634
1635
1636
def get_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the model parallel group.

    Raises:
        AssertionError: If the model-parallel global ranks are not set.
    """
    ranks = _MODEL_PARALLEL_GLOBAL_RANKS
    assert ranks is not None, "Model parallel group is not initialized"
    return ranks[0]


def get_data_parallel_src_rank(with_context_parallel=False):
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group.

    Args:
        with_context_parallel: Use the rank list of the data-parallel group
            combined with context parallelism.

    Raises:
        AssertionError: If the selected rank list has not been set.
    """
    if with_context_parallel:
        assert (
            _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP is not None
        ), "Data parallel group with context parallel combined is not initialized"
        return _DATA_PARALLEL_GLOBAL_RANKS_WITH_CP[0]
    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, "Data parallel group is not initialized"
    return _DATA_PARALLEL_GLOBAL_RANKS[0]
1647
1648


1649
def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the current rank's pipeline.

    When the pipeline global ranks are a list of rank lists, every list must
    start with the same rank, and that rank is returned.

    Raises:
        AssertionError: If the pipeline ranks are not set, or the rank lists do
            not share the same first rank.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        first_rank = _PIPELINE_GLOBAL_RANKS[0][0]
        # I assume the first rank is the same for all pp groups right now.
        for rank_group in _PIPELINE_GLOBAL_RANKS:
            assert rank_group[0] == first_rank
        return first_rank
    return _PIPELINE_GLOBAL_RANKS[0]
1659

1660

1661
def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the current rank's pipeline.

    Returns:
        A single global rank, or a list with one global rank per pipeline
        group when the pipeline global ranks are a list of rank lists.

    Raises:
        AssertionError: If the pipeline ranks are not set.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    last_rank_local = get_pipeline_model_parallel_world_size() - 1
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        return [group[last_rank_local] for group in _PIPELINE_GLOBAL_RANKS]
    return _PIPELINE_GLOBAL_RANKS[last_rank_local]
1669

liangjing's avatar
v1  
liangjing committed
1670

1671
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.

    If it is just part of one group, an int is returned, otherwise a list of ints.
    The position wraps around, so the last stage is followed by the first one.

    Raises:
        AssertionError: If the pipeline ranks are not set.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    next_index = (rank_in_pipeline + 1) % world_size
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        return [group[next_index] for group in _PIPELINE_GLOBAL_RANKS]
    return _PIPELINE_GLOBAL_RANKS[next_index]
1687

1688

1689
def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline, for each
    pipeline-parallel group that the rank is part of.

    If it is just part of one group, an int is returned, otherwise a list of ints.
    The position wraps around, so the first stage is preceded by the last one.

    Raises:
        AssertionError: If the pipeline ranks are not set.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    prev_index = (rank_in_pipeline - 1) % world_size
    if isinstance(_PIPELINE_GLOBAL_RANKS[0], list):
        return [group[prev_index] for group in _PIPELINE_GLOBAL_RANKS]
    return _PIPELINE_GLOBAL_RANKS[prev_index]
1705

1706

xingjinliang's avatar
xingjinliang committed
1707
def get_data_parallel_world_size(with_context_parallel=False, partial_data_parallel=False):
    """Return world size for the data parallel group.

    An override stored in ``_MPU_DATA_PARALLEL_WORLD_SIZE`` takes precedence.
    Otherwise the size of the group selected by the two flags is returned, or
    0 when ``torch.distributed`` is unavailable or not initialized.

    Args:
        with_context_parallel: Forwarded to ``get_data_parallel_group``.
        partial_data_parallel: Forwarded to ``get_data_parallel_group``.
    """
    if _MPU_DATA_PARALLEL_WORLD_SIZE is not None:
        return _MPU_DATA_PARALLEL_WORLD_SIZE
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(
            group=get_data_parallel_group(
                with_context_parallel=with_context_parallel,
                partial_data_parallel=partial_data_parallel,
            )
        )
    return 0
1721
1722


xingjinliang's avatar
xingjinliang committed
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
def set_data_parallel_rank(rank):
    """Set the data-parallel rank.

    Args:
        rank: Value stored in ``_MPU_DATA_PARALLEL_RANK``. While it is not
            None, ``get_data_parallel_rank`` returns it instead of querying
            the process group.
    """
    global _MPU_DATA_PARALLEL_RANK
    _MPU_DATA_PARALLEL_RANK = rank


def get_data_parallel_rank(with_context_parallel=False, partial_data_parallel=False):
    """Return caller's rank in the data-parallel group.

    An override installed through ``set_data_parallel_rank`` takes precedence.
    Otherwise the rank in the group selected by the two flags is returned, or
    0 when ``torch.distributed`` is unavailable or not initialized.
    """
    override = _MPU_DATA_PARALLEL_RANK
    if override is not None:
        return override
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    dp_group = get_data_parallel_group(
        with_context_parallel=with_context_parallel,
        partial_data_parallel=partial_data_parallel,
    )
    return dist.get_rank(group=dp_group)


def get_context_parallel_world_size():
    """Return world size for the context parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_world_size(group=get_context_parallel_group())
    return 0

1752

xingjinliang's avatar
xingjinliang committed
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
def get_context_parallel_rank():
    """Return caller's rank in the context-parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group=get_context_parallel_group())


def get_tensor_and_context_parallel_world_size():
    """Return world size for the tensor and context-parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_world_size(group=get_tensor_and_context_parallel_group())


def get_tensor_and_context_parallel_rank():
    """Return caller's rank in the joint tensor-model-parallel and context-parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group=get_tensor_and_context_parallel_group())


### Expert-related parallel states functions
def get_expert_model_parallel_group(check_initialized=True):
    """Get the expert-model-parallel group the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the group has been set;
            otherwise the stored value (possibly None) is returned unchecked.
    """
    group = _EXPERT_MODEL_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'expert model parallel group is not initialized'
    return group


def get_expert_model_parallel_world_size():
    """Return world size for the expert-model-parallel group.

    An override installed through ``set_expert_model_parallel_world_size``
    takes precedence; otherwise the group size is returned, or 0 when
    ``torch.distributed`` is unavailable or not initialized.
    """
    override = _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    if override is not None:
        return override
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_world_size(group=get_expert_model_parallel_group())


def set_expert_model_parallel_world_size(world_size):
    """Sets the expert-model-parallel world size.

    Args:
        world_size: Value stored in ``_MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE``.
            While it is not None, ``get_expert_model_parallel_world_size``
            returns it instead of querying the process group.
    """
    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = world_size


def get_expert_model_parallel_rank():
    """Return caller's rank in the expert-model-parallel group.

    An override installed through ``set_expert_model_parallel_rank`` takes
    precedence; otherwise the rank in the group is returned, or 0 when
    ``torch.distributed`` is unavailable or not initialized.
    """
    override = _MPU_EXPERT_MODEL_PARALLEL_RANK
    if override is not None:
        return override
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group=get_expert_model_parallel_group())


def set_expert_model_parallel_rank(rank):
    """Set expert-model-parallel rank.

    Args:
        rank: Value stored in ``_MPU_EXPERT_MODEL_PARALLEL_RANK``. While it is
            not None, ``get_expert_model_parallel_rank`` returns it instead of
            querying the process group.
    """
    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = rank


def get_expert_tensor_parallel_group(check_initialized=True):
    """Get the expert-tensor-parallel group the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the group has been set;
            otherwise the stored value (possibly None) is returned unchecked.
    """
    group = _EXPERT_TENSOR_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'Expert tensor parallel group is not initialized'
    return group


def get_expert_tensor_parallel_world_size():
    """Return world size for the expert tensor parallel group.

    Resolution order: the expert-tensor-parallel override, then the size of
    the expert-tensor-parallel group if one is set, then the
    tensor-model-parallel override.
    """
    override = _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE
    if override is not None:
        return override
    if _EXPERT_TENSOR_PARALLEL_GROUP:
        return torch.distributed.get_world_size(group=get_expert_tensor_parallel_group())
    # Use tensor parallel group world size for backward compatibility otherwise.
    # NOTE(review): this returns the tensor-parallel *override* as is, so the
    # result is None when no override was installed - confirm this is intended.
    return _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE


def set_expert_tensor_parallel_world_size(world_size):
    """Set expert tensor model parallel size.

    Args:
        world_size: Value stored in ``_MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE``.
            While it is not None, ``get_expert_tensor_parallel_world_size``
            returns it directly.
    """
    global _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = world_size


def get_expert_tensor_parallel_rank():
    """Return my rank for the expert tensor parallel group.

    Resolution order: the expert-tensor-parallel rank override, then the rank
    in the expert-tensor-parallel group if one is set, then the
    tensor-model-parallel rank override.
    """
    override = _MPU_EXPERT_TENSOR_PARALLEL_RANK
    if override is not None:
        return override
    if _EXPERT_TENSOR_PARALLEL_GROUP:
        return torch.distributed.get_rank(group=get_expert_tensor_parallel_group())
    # Use tensor parallel group rank for backward compatibility otherwise.
    # NOTE(review): this returns the tensor-parallel *override* as is, so the
    # result is None when no override was installed - confirm this is intended.
    return _MPU_TENSOR_MODEL_PARALLEL_RANK


def set_expert_tensor_parallel_rank(rank):
    """Set expert tensor model parallel rank.

    Args:
        rank: Value stored in ``_MPU_EXPERT_TENSOR_PARALLEL_RANK``. While it is
            not None, ``get_expert_tensor_parallel_rank`` returns it directly.
    """
    global _MPU_EXPERT_TENSOR_PARALLEL_RANK
    _MPU_EXPERT_TENSOR_PARALLEL_RANK = rank


def get_expert_tensor_and_model_parallel_group(check_initialized=True):
    """Get the expert-tensor and expert-model group the caller rank belongs to.

    Args:
        check_initialized: When truthy, assert that the group has been set;
            otherwise the stored value (possibly None) is returned unchecked.
    """
    group = _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP
    if check_initialized:
        assert group is not None, 'Expert tensor and model parallel group is not initialized'
    return group


def get_expert_tensor_and_model_parallel_world_size():
    """Return the world size of the joint expert-tensor and expert-model group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_world_size(group=get_expert_tensor_and_model_parallel_group())


def get_expert_tensor_and_model_parallel_rank():
    """Return the caller's rank in the joint expert-tensor and expert-model group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group=get_expert_tensor_and_model_parallel_group())


def get_expert_tensor_model_pipeline_parallel_group():
    """Return the expert tensor-model-pipeline parallel group.

    Raises:
        AssertionError: if the group has not been initialized.
    """
    group = _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
    assert group is not None, 'Expert tensor-model-pipeline parallel group is not initialized'
    return group


def get_expert_data_parallel_group():
    """Return the expert data parallel group.

    Raises:
        AssertionError: if the group has not been initialized.
    """
    group = _EXPERT_DATA_PARALLEL_GROUP
    assert group is not None, 'Expert data parallel group is not initialized'
    return group


def get_data_modulo_expert_parallel_group():
    """[Deprecated] Return the expert data parallel group.

    Alias kept for backward compatibility: it emits a ``DeprecationWarning``
    and then delegates to ``get_expert_data_parallel_group``.
    """
    warnings.warn(
        "get_data_modulo_expert_parallel_group is deprecated, please use "
        "get_expert_data_parallel_group instead.",
        DeprecationWarning,
        # Attribute the warning to the caller, so the reported location is the
        # code that still needs migrating rather than this wrapper.
        stacklevel=2,
    )
    return get_expert_data_parallel_group()


def get_expert_data_parallel_group_gloo():
    """Return the Gloo variant of the expert data parallel group.

    Raises:
        AssertionError: if the group has not been initialized.
    """
    gloo_group = _EXPERT_DATA_PARALLEL_GROUP_GLOO
    assert gloo_group is not None, 'Expert data parallel group-gloo is not initialized'
    return gloo_group


def get_expert_data_parallel_rank():
    """Return the caller's rank in the expert data parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_rank(group=get_expert_data_parallel_group())


wangxj's avatar
wangxj committed
1932
1933
1934
1935
1936
1937
1938
1939
def get_expert_data_parallel_world_size():
    """Return the world size of the expert data parallel group.

    Returns 0 when ``torch.distributed`` is unavailable or not initialized.
    """
    dist = torch.distributed
    if not (dist.is_available() and dist.is_initialized()):
        return 0
    return dist.get_world_size(group=get_expert_data_parallel_group())


xingjinliang's avatar
xingjinliang committed
1940
1941
1942
### End of expert-related functions region


1943
def _set_global_memory_buffer():
    """Create the module-wide ``GlobalMemoryBuffer`` instance.

    Raises:
        AssertionError: if a global memory buffer already exists.
    """
    global _GLOBAL_MEMORY_BUFFER
    assert _GLOBAL_MEMORY_BUFFER is None, 'global memory buffer is already initialized'
    _GLOBAL_MEMORY_BUFFER = GlobalMemoryBuffer()

liangjing's avatar
v1  
liangjing committed
1949

1950
def get_global_memory_buffer():
    """Return the module-wide ``GlobalMemoryBuffer`` object.

    Raises:
        AssertionError: if the global memory buffer has not been initialized.
    """
    assert _GLOBAL_MEMORY_BUFFER is not None, 'global memory buffer is not initialized'
    return _GLOBAL_MEMORY_BUFFER


liangjing's avatar
v1  
liangjing committed
1956
1957
1958
1959
1960
1961
def destroy_global_memory_buffer():
    """Drop the module-wide memory buffer by resetting ``_GLOBAL_MEMORY_BUFFER`` to None."""
    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None


xingjinliang's avatar
xingjinliang committed
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
def get_all_ranks():
    """Return the caller's ranks across five parallel dimensions as one string.

    The ranks are queried in this order: tensor-model-parallel, data-parallel,
    context-parallel, pipeline-model-parallel and expert-model-parallel. They
    are joined with '_', and a falsy rank (such as None) is rendered as '0'.
    """
    tp_rank = get_tensor_model_parallel_rank()
    dp_rank = get_data_parallel_rank()
    cp_rank = get_context_parallel_rank()
    pp_rank = get_pipeline_model_parallel_rank()
    ep_rank = get_expert_model_parallel_rank()
    return '_'.join(str(rank or 0) for rank in (tp_rank, dp_rank, cp_rank, pp_rank, ep_rank))


def get_moe_layer_wise_logging_tracker():
    """Return the module-level MoE layer-wise logging tracker object itself (not a copy)."""
    return _MOE_LAYER_WISE_LOGGING_TRACKER


1981
1982
def _destroy_gloo_group_if_registered(group):
    """Destroy ``group`` when it is set and still listed in torch's process-group map."""
    if group is None:
        return
    # A group that is no longer present in the c10d registry is skipped rather
    # than passed to destroy_process_group.
    if torch.distributed.distributed_c10d._world.pg_map.get(group, None) is not None:
        torch.distributed.destroy_process_group(group)


def destroy_model_parallel():
    """Reset the module-level parallel state.

    Every group handle, cached rank / world-size override and the global memory
    buffer listed below is set back to None, and the MoE layer-wise logging
    tracker is replaced by an empty dict. The three Gloo groups are additionally
    passed to ``torch.distributed.destroy_process_group`` when they are still
    registered with ``torch.distributed``; all other groups are only
    dereferenced here.
    """
    # Model / tensor / pipeline parallel groups.
    global _MODEL_PARALLEL_GROUP
    _MODEL_PARALLEL_GROUP = None

    global _TENSOR_MODEL_PARALLEL_GROUP
    _TENSOR_MODEL_PARALLEL_GROUP = None

    global _PIPELINE_MODEL_PARALLEL_GROUP
    _PIPELINE_MODEL_PARALLEL_GROUP = None

    global _PIPELINE_MODEL_PARALLEL_DECODER_START
    _PIPELINE_MODEL_PARALLEL_DECODER_START = None

    # Data and context parallel groups.
    global _DATA_PARALLEL_GROUP
    _DATA_PARALLEL_GROUP = None

    global _DATA_PARALLEL_GROUP_WITH_CP
    _DATA_PARALLEL_GROUP_WITH_CP = None

    global _CONTEXT_PARALLEL_GROUP
    _CONTEXT_PARALLEL_GROUP = None

    global _CONTEXT_PARALLEL_GLOBAL_RANKS
    _CONTEXT_PARALLEL_GLOBAL_RANKS = None

    # Embedding groups.
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None

    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None

    # Combined groups.
    global _TENSOR_AND_DATA_PARALLEL_GROUP
    _TENSOR_AND_DATA_PARALLEL_GROUP = None

    global _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP
    _TENSOR_AND_DATA_PARALLEL_GROUP_WITH_CP = None

    global _TENSOR_AND_CONTEXT_PARALLEL_GROUP
    _TENSOR_AND_CONTEXT_PARALLEL_GROUP = None

    # Cached rank / world-size values.
    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = None

    global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
    _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
    _MPU_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_TENSOR_MODEL_PARALLEL_RANK
    _MPU_TENSOR_MODEL_PARALLEL_RANK = None

    global _MPU_PIPELINE_MODEL_PARALLEL_RANK
    _MPU_PIPELINE_MODEL_PARALLEL_RANK = None

    global _GLOBAL_MEMORY_BUFFER
    _GLOBAL_MEMORY_BUFFER = None

    # Gloo data parallel groups are explicitly destroyed before being dropped.
    global _DATA_PARALLEL_GROUP_GLOO
    _destroy_gloo_group_if_registered(_DATA_PARALLEL_GROUP_GLOO)
    _DATA_PARALLEL_GROUP_GLOO = None

    global _DATA_PARALLEL_GROUP_WITH_CP_GLOO
    _destroy_gloo_group_if_registered(_DATA_PARALLEL_GROUP_WITH_CP_GLOO)
    _DATA_PARALLEL_GROUP_WITH_CP_GLOO = None

    # Destroy parallel state related to expert parallelism.
    global _EXPERT_MODEL_PARALLEL_GROUP
    _EXPERT_MODEL_PARALLEL_GROUP = None

    global _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_MODEL_PARALLEL_WORLD_SIZE = None

    global _MPU_EXPERT_MODEL_PARALLEL_RANK
    _MPU_EXPERT_MODEL_PARALLEL_RANK = None

    global _EXPERT_TENSOR_PARALLEL_GROUP
    _EXPERT_TENSOR_PARALLEL_GROUP = None

    global _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE
    _MPU_EXPERT_TENSOR_PARALLEL_WORLD_SIZE = None

    global _MPU_EXPERT_TENSOR_PARALLEL_RANK
    _MPU_EXPERT_TENSOR_PARALLEL_RANK = None

    global _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP
    _EXPERT_TENSOR_AND_MODEL_PARALLEL_GROUP = None

    global _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP
    _EXPERT_TENSOR_MODEL_PIPELINE_PARALLEL_GROUP = None

    global _EXPERT_DATA_PARALLEL_GROUP
    _EXPERT_DATA_PARALLEL_GROUP = None

    global _EXPERT_DATA_PARALLEL_GROUP_GLOO
    _destroy_gloo_group_if_registered(_EXPERT_DATA_PARALLEL_GROUP_GLOO)
    _EXPERT_DATA_PARALLEL_GROUP_GLOO = None
    # End of expert parallelism destroy.

    global _MOE_LAYER_WISE_LOGGING_TRACKER
    _MOE_LAYER_WISE_LOGGING_TRACKER = {}