parallel_state.py 9.51 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
# Copyright 2023 The vLLM team.
2
3
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
4
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
5
"""Tensor and pipeline parallel groups."""
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import contextlib
Zhuohan Li's avatar
Zhuohan Li committed
7
8
9

import torch

Woosuk Kwon's avatar
Woosuk Kwon committed
10
11
from vllm.model_executor.parallel_utils import cupy_utils

12
# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS

    # The default (world) process group must already be bootstrapped.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups: each group is a contiguous
    # span of `tensor_model_parallel_size` consecutive global ranks.
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
        "tensor model parallel group is already initialized")
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    for group_id in range(num_tensor_model_parallel_groups):
        start = group_id * tensor_model_parallel_size
        group_ranks = range(start, start + tensor_model_parallel_size)
        # new_group is collective: every rank must create every group.
        new_group = torch.distributed.new_group(group_ranks)
        if rank in group_ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = new_group

    # Build the pipeline model-parallel groups: members are strided by the
    # number of pipeline groups, so pipeline stages span tensor groups.
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
        "pipeline model parallel group is already initialized")
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    for group_id in range(num_pipeline_model_parallel_groups):
        group_ranks = range(group_id, world_size,
                            num_pipeline_model_parallel_groups)
        new_group = torch.distributed.new_group(group_ranks)
        if rank in group_ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = new_group
            # Remember the global ranks of this pipeline to locate the
            # first/last/next/prev stages later.
            _PIPELINE_GLOBAL_RANKS = group_ranks

def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.
    """
    if model_parallel_is_initialized():
        # Groups already exist: just verify they have the expected sizes.
        assert (
            get_tensor_model_parallel_world_size() ==
            tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        assert (get_pipeline_model_parallel_world_size(
        ) == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{get_pipeline_model_parallel_world_size()=} vs. "
            f"{pipeline_model_parallel_size=}")
        return

    initialize_model_parallel(tensor_model_parallel_size,
                              pipeline_model_parallel_size)

def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    tp_ready = _TENSOR_MODEL_PARALLEL_GROUP is not None
    pp_ready = _PIPELINE_MODEL_PARALLEL_GROUP is not None
    return tp_ready and pp_ready

def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to."""
    group = _TENSOR_MODEL_PARALLEL_GROUP
    assert group is not None, (
        "tensor model parallel group is not initialized")
    return group

def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to."""
    group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert group is not None, (
        "pipeline model parallel group is not initialized")
    return group

def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    group = get_tensor_model_parallel_group()
    return torch.distributed.get_world_size(group=group)

def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    group = get_pipeline_model_parallel_group()
    return torch.distributed.get_world_size(group=group)

def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=group)

def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    group = get_pipeline_model_parallel_group()
    return torch.distributed.get_rank(group=group)

def get_tensor_model_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the tensor model parallel group."""
    global_rank = torch.distributed.get_rank()
    tp_world_size = get_tensor_model_parallel_world_size()
    # Round the global rank down to the start of its tensor-parallel group.
    return global_rank - (global_rank % tp_world_size)

def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group"""
    ranks = _PIPELINE_GLOBAL_RANKS
    assert ranks is not None, (
        "Pipeline parallel group is not initialized")
    return ranks[0]

def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group"""
    ranks = _PIPELINE_GLOBAL_RANKS
    assert ranks is not None, (
        "Pipeline parallel group is not initialized")
    # The last stage sits at index (pipeline world size - 1).
    return ranks[get_pipeline_model_parallel_world_size() - 1]

def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    # Wrap around so the last stage points back at the first.
    next_pos = ((get_pipeline_model_parallel_rank() + 1) %
                get_pipeline_model_parallel_world_size())
    return _PIPELINE_GLOBAL_RANKS[next_pos]

def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline"""
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    # Wrap around so the first stage points back at the last.
    prev_pos = ((get_pipeline_model_parallel_rank() - 1) %
                get_pipeline_model_parallel_world_size())
    return _PIPELINE_GLOBAL_RANKS[prev_pos]

def destroy_model_parallel():
    """Set the groups to none and destroy them.

    Tears down the tensor/pipeline process groups created by
    ``initialize_model_parallel`` and resets the module-level state so that
    the groups can be re-initialized later (e.g. between runs or tests).
    """
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS

    # Compare explicitly against None instead of relying on the truthiness
    # of a ProcessGroup object: "if group:" would silently skip destruction
    # for any group object that happens to evaluate falsy.
    if _TENSOR_MODEL_PARALLEL_GROUP is not None:
        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None
    if _PIPELINE_MODEL_PARALLEL_GROUP is not None:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _PIPELINE_GLOBAL_RANKS = None

    # Destroy the cupy states if any.
    cupy_utils.destroy_process_group()


# Whether to use cupy for nccl all reduce.
# We use cupy for all reduce when using CUDA graph, because torch.distributed
# is not well supported by CUDA graph.
# Toggled by with_cupy_nccl_for_all_reduce() and read back by
# is_cupy_nccl_enabled_for_all_reduce().
_ENABLE_CUPY_FOR_ALL_REDUCE = False


@contextlib.contextmanager
def with_cupy_nccl_for_all_reduce():
    """use CuPy nccl instead of torch.distributed for all reduce.

    Within this context, ``_ENABLE_CUPY_FOR_ALL_REDUCE`` is set to True so
    that all-reduce callers switch to the CuPy NCCL path; the previous value
    is restored on exit.
    """
    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize CuPy when tp_size is 1.
        yield
    else:
        global _ENABLE_CUPY_FOR_ALL_REDUCE
        old = _ENABLE_CUPY_FOR_ALL_REDUCE
        _ENABLE_CUPY_FOR_ALL_REDUCE = True

        stream = torch.cuda.current_stream()
        # BUG FIX: restore the flag in a `finally` block. Previously, an
        # exception raised inside the caller's `with` body propagated through
        # the yield and skipped the restoration, leaving the CuPy path
        # enabled for all subsequent all-reduces.
        try:
            with cupy_utils.set_cupy_stream(stream):
                yield
        finally:
            _ENABLE_CUPY_FOR_ALL_REDUCE = old



def is_cupy_nccl_enabled_for_all_reduce():
    """check if CuPy nccl is enabled for all reduce"""
    # Reading a module global needs no `global` declaration.
    return _ENABLE_CUPY_FOR_ALL_REDUCE