# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
import contextlib
from typing import Optional

import torch

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)

# Tensor model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# Pipeline model parallel group that the current rank belongs to.
_PIPELINE_MODEL_PARALLEL_GROUP = None

# When people blindly call `torch.distributed.all_reduce` etc.,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None

# During `init_distributed_environment`, we also initialize a
# group with the `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None

# In summary, once `init_distributed_environment` has created the default
# process group, there are two world groups: a device-specific one (the
# default) and a CPU (gloo) one. All processes are part of both groups.

# The global ranks (a `range`) of the pipeline group this rank belongs to,
# kept to ease calculation of the source rank when broadcasting from the
# first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None

# Local rank of this process; -1 until `init_distributed_environment`
# records it.
_LOCAL_RANK: int = -1


def get_local_rank():
    """Return the local rank stored by ``init_distributed_environment``.

    The module default is -1, which is what callers see when no local rank
    has been recorded.
    """
    # Read-only access to the module global, so no ``global`` is needed.
    return _LOCAL_RANK

def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed and this module's world-level state.

    When the default process group does not exist yet, this creates it with
    ``torch.distributed.init_process_group``, records it as the device world
    group, creates a gloo-backed group over all ranks for CPU-side
    coordination, and records the local rank returned by ``get_local_rank``.
    When torch.distributed is already initialized, the call only logs its
    arguments and leaves the module state unchanged.

    Arguments:
        world_size: number of processes; forwarded to
            ``init_process_group``.
        rank: global rank of this process; forwarded to
            ``init_process_group``.
        distributed_init_method: ``init_method`` URL for
            ``init_process_group``.
        local_rank: local rank of this process. If -1, it is read from
            ``envs.LOCAL_RANK`` for ``env://`` initialization and otherwise
            falls back to the global rank.
        backend: backend of the default (device) process group.
    """
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP, _LOCAL_RANK

    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    if torch.distributed.is_initialized():
        # NOTE(review): the world groups and local rank are only populated
        # when this function creates the default group itself; confirm that
        # no caller relies on them after initializing torch.distributed on
        # its own.
        return

    assert distributed_init_method is not None, (
        "distributed_init_method must be provided when initializing "
        "distributed environment")
    # This backend is used for WORLD, i.e. the default process group.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank)
    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
    ranks = list(range(torch.distributed.get_world_size()))
    # A gloo group lets processes coordinate through the CPU.
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                   backend="gloo")

    # Set the local rank.
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            # No launcher-provided local rank is available; assume a
            # single-node launch, where the global rank equals the local
            # rank. NOTE(review): not valid for multi-node launches that do
            # not use env:// -- those callers must pass local_rank.
            local_rank = torch.distributed.get_rank()
    _LOCAL_RANK = local_rank


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the default (world) process group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not the product of the tensor and
            pipeline model parallel sizes.
        AssertionError: if torch.distributed is not initialized, or if a
            model parallel group has already been created.
    """
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    # Default to the backend of the default (world) process group.
    backend = backend or torch.distributed.get_backend()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups from consecutive ranks.
    # torch.distributed.new_group must be entered by every process, so each
    # rank creates every group and keeps only the one it belongs to.
    assert _TENSOR_MODEL_PARALLEL_GROUP is None, (
        "tensor model parallel group is already initialized")
    for i in range(num_tensor_model_parallel_groups):
        ranks = range(i * tensor_model_parallel_size,
                      (i + 1) * tensor_model_parallel_size)
        group = torch.distributed.new_group(ranks, backend=backend)
        if rank in ranks:
            _TENSOR_MODEL_PARALLEL_GROUP = group

    # Build the pipeline model-parallel groups from strided ranks; the
    # stride is the number of pipeline groups.
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
        group = torch.distributed.new_group(ranks, backend=backend)
        if rank in ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = group
            _PIPELINE_GLOBAL_RANKS = ranks


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.

    Arguments:
        tensor_model_parallel_size: expected tensor parallel world size.
        pipeline_model_parallel_size: expected pipeline parallel world size.
        backend: backend passed to ``initialize_model_parallel`` when the
            groups are created; defaults to the backend of the default
            (world) process group.

    Raises:
        AssertionError: if the groups already exist with different sizes.
    """
    # Get the backend of the default (world) process group.
    backend = backend or torch.distributed.get_backend()
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    assert (get_pipeline_model_parallel_world_size() ==
            pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{get_pipeline_model_parallel_world_size()=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized.

    Returns True only when both the tensor and the pipeline model parallel
    groups have been assigned.
    """
    return (_TENSOR_MODEL_PARALLEL_GROUP is not None
            and _PIPELINE_MODEL_PARALLEL_GROUP is not None)


def get_cpu_world_group():
    """Get the CPU (gloo-backed) world group.

    Raises:
        AssertionError: if ``init_distributed_environment`` has not created
            the group.
    """
    assert _CPU_WORLD_GROUP is not None, "CPU world group is not initialized"
    return _CPU_WORLD_GROUP


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the group has not been initialized.
    """
    assert _TENSOR_MODEL_PARALLEL_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TENSOR_MODEL_PARALLEL_GROUP


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the group has not been initialized.
    """
    assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, (
        "pipeline model parallel group is not initialized")
    return _PIPELINE_MODEL_PARALLEL_GROUP


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group."""
    return torch.distributed.get_world_size(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_rank():
    """Return the caller's rank inside its tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=tp_group)


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group."""
    return torch.distributed.get_rank(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_src_rank():
    """Return the global rank of the first member of this tensor parallel group.

    Tensor parallel groups are built from consecutive global ranks, so the
    first member is the caller's global rank rounded down to a multiple of
    the tensor parallel world size.
    """
    my_global_rank = torch.distributed.get_rank()
    tp_world_size = get_tensor_model_parallel_world_size()
    # Python guarantees r == (r // w) * w + r % w, so this equals
    # (r // w) * w.
    return my_global_rank - my_global_rank % tp_world_size


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first stage in the caller's pipeline
    group.

    Raises:
        AssertionError: if the pipeline parallel group is not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PIPELINE_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last stage in the caller's pipeline
    group.

    Raises:
        AssertionError: if the pipeline parallel group is not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    # The stored ranks hold exactly one entry per pipeline stage, so the last
    # stage is the final entry; no need to re-query the process group.
    return _PIPELINE_GLOBAL_RANKS[-1]

def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline.

    The last stage wraps around to the first one.

    Raises:
        AssertionError: if the pipeline parallel group is not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline.

    The first stage wraps around to the last one (Python's ``%`` keeps the
    index non-negative).

    Raises:
        AssertionError: if the pipeline parallel group is not initialized.
    """
    assert _PIPELINE_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def destroy_model_parallel():
    """Destroy the model parallel groups and reset them to None.

    Tears down the tensor and pipeline model parallel groups (if any), clears
    the stored pipeline global ranks, and destroys any pynccl state. The
    device and CPU world groups are not touched here.
    """
    global _TENSOR_MODEL_PARALLEL_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP
    global _PIPELINE_GLOBAL_RANKS

    if _TENSOR_MODEL_PARALLEL_GROUP is not None:
        torch.distributed.destroy_process_group(_TENSOR_MODEL_PARALLEL_GROUP)
    _TENSOR_MODEL_PARALLEL_GROUP = None

    if _PIPELINE_MODEL_PARALLEL_GROUP is not None:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None

    _PIPELINE_GLOBAL_RANKS = None

    # Imported lazily; presumably to avoid an import cycle at module load
    # time -- NOTE(review): confirm.
    from vllm.distributed.device_communicators import pynccl_utils

    # Destroy the pynccl states if any.
    pynccl_utils.destroy_process_group()


# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because
# torch.distributed is not well supported by CUDA graph.
_ENABLE_PYNCCL_FOR_ALL_REDUCE: bool = False


@contextlib.contextmanager
def with_pynccl_for_all_reduce():
    """Use pynccl instead of torch.distributed for all reduce.

    Inside the context, ``is_pynccl_enabled_for_all_reduce()`` returns True
    and ``pynccl_utils.set_pynccl_stream`` is entered with the current CUDA
    stream. The previous flag value is restored on exit, including when the
    body raises. When the tensor parallel world size is 1 the context is a
    no-op.
    """
    from vllm.distributed.device_communicators import pynccl_utils

    global _ENABLE_PYNCCL_FOR_ALL_REDUCE

    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
        yield
        return

    old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
    _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
    try:
        stream = torch.cuda.current_stream()
        with pynccl_utils.set_pynccl_stream(stream):
            yield
    finally:
        # Restore even if the body raised; otherwise the flag would stay
        # True and route later, unrelated all-reduce calls through pynccl.
        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old


def is_pynccl_enabled_for_all_reduce():
    """Check if pynccl is enabled for all reduce.

    The flag is set to True inside ``with_pynccl_for_all_reduce`` (when the
    tensor parallel world size is greater than 1) and is False otherwise.
    """
    # Read-only access to the module global, so no ``global`` is needed.
    return _ENABLE_PYNCCL_FOR_ALL_REDUCE