parallel_state.py 13 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
# Copyright 2023 The vLLM team.
2
3
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
4
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
5
"""Tensor and pipeline parallel groups."""
Woosuk Kwon's avatar
Woosuk Kwon committed
6
import contextlib
7
from typing import Optional
Zhuohan Li's avatar
Zhuohan Li committed
8
9
10

import torch

11
import vllm.envs as envs
12
13
14
15
from vllm.logger import init_logger

logger = init_logger(__name__)

16
# Tensor model parallel group that the current rank belongs to.
17
18
_TP_DEVICE_GROUP = None
_TP_CPU_GROUP = None
19
# Pipeline model parallel group that the current rank belongs to.
Zhuohan Li's avatar
Zhuohan Li committed
20
21
_PIPELINE_MODEL_PARALLEL_GROUP = None

22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
# when people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP = None

# During `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None

# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one for device-specific (and is the default)
# and one for CPU. All processes will be part of both groups.

42
43
44
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
Zhuohan Li's avatar
Zhuohan Li committed
45

46
47
48
49
50
51
52
# Local rank recorded by `init_distributed_environment`; remains -1 until
# that function stores a value here.
_LOCAL_RANK = -1


def get_local_rank():
    """Return the local rank stored in the module-level ``_LOCAL_RANK``."""
    # Reading a module global needs no `global` declaration.
    return _LOCAL_RANK

Zhuohan Li's avatar
Zhuohan Li committed
53

54
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize the default process group, a gloo world group and the
    local rank.

    When ``torch.distributed`` is already initialized, this function only
    emits the debug log and leaves every module-level variable untouched.

    Arguments:
        world_size: forwarded to ``torch.distributed.init_process_group``.
        rank: forwarded to ``torch.distributed.init_process_group``.
        distributed_init_method: forwarded as ``init_method``; must not be
            ``None``.
        local_rank: value stored as the module's local rank. When it is -1
            and ``distributed_init_method`` is ``"env://"``, the value of
            ``envs.LOCAL_RANK`` is stored instead.
        backend: backend of the default (WORLD) process group.
    """
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP, _LOCAL_RANK

    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    if torch.distributed.is_initialized():
        return

    assert distributed_init_method is not None, (
        "distributed_init_method must be provided when initializing "
        "distributed environment")

    # This backend is the one used for the WORLD group.
    torch.distributed.init_process_group(
        backend=backend,
        init_method=distributed_init_method,
        world_size=world_size,
        rank=rank,
    )
    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD

    # A second group over all ranks, with the gloo backend.
    all_ranks = list(range(torch.distributed.get_world_size()))
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=all_ranks,
                                                   backend="gloo")

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    read_from_env = (local_rank == -1
                     and distributed_init_method == "env://")
    _LOCAL_RANK = envs.LOCAL_RANK if read_from_env else local_rank
87
88


Zhuohan Li's avatar
Zhuohan Li committed
89
90
91
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: backend for the device groups; when falsy, the value of
            ``torch.distributed.get_backend()`` is used.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size differs from
            ``tensor_model_parallel_size * pipeline_model_parallel_size``.
    """
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP, _PIPELINE_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend()

    expected_world_size = (tensor_model_parallel_size *
                           pipeline_model_parallel_size)
    if world_size != expected_world_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    tp_group_count: int = world_size // tensor_model_parallel_size
    pp_group_count: int = world_size // pipeline_model_parallel_size
    my_rank = torch.distributed.get_rank()

    # Tensor model-parallel groups: consecutive blocks of ranks. Every
    # group is created by the loop; only the one containing this rank is
    # stored.
    assert _TP_DEVICE_GROUP is None, (
        "tensor model parallel group is already initialized")
    for group_idx in range(tp_group_count):
        start = group_idx * tensor_model_parallel_size
        tp_ranks = range(start, start + tensor_model_parallel_size)
        device_group = torch.distributed.new_group(tp_ranks, backend=backend)
        gloo_group = torch.distributed.new_group(tp_ranks, backend="gloo")
        if my_rank in tp_ranks:
            _TP_DEVICE_GROUP = device_group
            _TP_CPU_GROUP = gloo_group

    # Pipeline model-parallel groups: ranks strided by the number of
    # pipeline groups.
    assert _PIPELINE_MODEL_PARALLEL_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for group_idx in range(pp_group_count):
        pp_ranks = range(group_idx, world_size, pp_group_count)
        pp_group = torch.distributed.new_group(pp_ranks, backend=backend)
        if my_rank in pp_ranks:
            _PIPELINE_MODEL_PARALLEL_GROUP = pp_group
            _PIPELINE_GLOBAL_RANKS = pp_ranks


161
162
163
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups when they do not exist yet;
    otherwise assert that the existing tensor-parallel and pipeline-parallel
    world sizes equal the requested ones.
    """
    # When no backend is given, use the one reported by
    # torch.distributed.get_backend().
    backend = backend or torch.distributed.get_backend()

    if model_parallel_is_initialized():
        assert (
            get_tensor_model_parallel_world_size() ==
            tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        assert (
            get_pipeline_model_parallel_world_size() ==
            pipeline_model_parallel_size
        ), ("pipeline parallel group already initialized, but of unexpected size: "
            f"{get_pipeline_model_parallel_world_size()=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)


Zhuohan Li's avatar
Zhuohan Li committed
189
def model_parallel_is_initialized():
    """Return True when both the tensor parallel device group and the
    pipeline parallel group have been set."""
    any_missing = (_TP_DEVICE_GROUP is None
                   or _PIPELINE_MODEL_PARALLEL_GROUP is None)
    return not any_missing
Zhuohan Li's avatar
Zhuohan Li committed
193
194


195
196
197
198
199
200
def get_cpu_world_group():
    """Return the module-level CPU world group; asserts that it is set."""
    cpu_world = _CPU_WORLD_GROUP
    assert cpu_world is not None, ("CPU world group is not initialized")
    return cpu_world


Zhuohan Li's avatar
Zhuohan Li committed
201
202
def get_tensor_model_parallel_group():
    """Return the tensor model parallel device group of the calling rank;
    asserts that it has been initialized."""
    tp_group = _TP_DEVICE_GROUP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


def get_tensor_model_parallel_cpu_group():
    """Return the tensor model parallel CPU group of the calling rank;
    asserts that it has been initialized."""
    tp_cpu_group = _TP_CPU_GROUP
    assert tp_cpu_group is not None, (
        "tensor model parallel cpu group is not initialized")
    return tp_cpu_group
Zhuohan Li's avatar
Zhuohan Li committed
213
214
215
216


def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel group of the calling rank;
    asserts that it has been initialized."""
    pp_group = _PIPELINE_MODEL_PARALLEL_GROUP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_world_size(group=tp_group)
Zhuohan Li's avatar
Zhuohan Li committed
226
227
228
229


def get_pipeline_model_parallel_world_size():
    """Return the number of ranks in the pipeline model parallel group."""
    pp_group = get_pipeline_model_parallel_group()
    return torch.distributed.get_world_size(group=pp_group)
Zhuohan Li's avatar
Zhuohan Li committed
232
233
234
235
236
237
238
239
240


def get_tensor_model_parallel_rank():
    """Return the caller's rank within the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=tp_group)


def get_pipeline_model_parallel_rank():
    """Return the caller's rank within the pipeline model parallel group."""
    pp_group = get_pipeline_model_parallel_group()
    return torch.distributed.get_rank(group=pp_group)
Zhuohan Li's avatar
Zhuohan Li committed
243
244
245
246
247
248
249
250
251
252
253
254
255


def get_tensor_model_parallel_src_rank():
    """Return the global rank of the first member of the caller's tensor
    model parallel group.

    The global rank is rounded down to the nearest multiple of the tensor
    parallel world size.
    """
    my_global_rank = torch.distributed.get_rank()
    tp_size = get_tensor_model_parallel_world_size()
    # r - (r % n) equals (r // n) * n for Python integers.
    return my_global_rank - (my_global_rank % tp_size)


def get_pipeline_model_parallel_first_rank():
    """Return the first entry of the caller's pipeline global ranks, i.e.
    the global rank of the first process in its pipeline."""
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, (
        "Pipeline parallel group is not initialized")
    return pipeline_ranks[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the caller's pipeline.

    The entry is looked up at index ``pipeline world size - 1``.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, (
        "Pipeline parallel group is not initialized")
    final_index = get_pipeline_model_parallel_world_size() - 1
    return pipeline_ranks[final_index]

Zhuohan Li's avatar
Zhuohan Li committed
269

Zhuohan Li's avatar
Zhuohan Li committed
270
271
def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline.

    The index wraps around, so the last position maps to the first entry.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, (
        "Pipeline parallel group is not initialized")
    position = get_pipeline_model_parallel_rank()
    stage_count = get_pipeline_model_parallel_world_size()
    successor = (position + 1) % stage_count
    return pipeline_ranks[successor]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline.

    The index wraps around, so the first position maps to the last entry.
    """
    pipeline_ranks = _PIPELINE_GLOBAL_RANKS
    assert pipeline_ranks is not None, (
        "Pipeline parallel group is not initialized")
    position = get_pipeline_model_parallel_rank()
    stage_count = get_pipeline_model_parallel_world_size()
    predecessor = (position - 1) % stage_count
    return pipeline_ranks[predecessor]


def destroy_model_parallel():
    """Destroy the model parallel process groups and reset the module state.

    Each stored group that is truthy is passed to
    ``torch.distributed.destroy_process_group`` and its variable is set to
    ``None``; the pipeline global ranks are cleared, and finally the pynccl
    state is destroyed through ``pynccl_utils.destroy_process_group``.
    """
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
    global _PIPELINE_MODEL_PARALLEL_GROUP, _PIPELINE_GLOBAL_RANKS

    # Tensor parallel device group.
    if _TP_DEVICE_GROUP:
        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
    _TP_DEVICE_GROUP = None

    # Tensor parallel CPU group.
    if _TP_CPU_GROUP:
        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
    _TP_CPU_GROUP = None

    # Pipeline parallel group and its list of global ranks.
    if _PIPELINE_MODEL_PARALLEL_GROUP:
        torch.distributed.destroy_process_group(_PIPELINE_MODEL_PARALLEL_GROUP)
    _PIPELINE_MODEL_PARALLEL_GROUP = None
    _PIPELINE_GLOBAL_RANKS = None

    # Imported here, after the torch groups are torn down, rather than at
    # module import time.
    from vllm.distributed.device_communicators import pynccl_utils

    # Destroy the pynccl states if any.
    pynccl_utils.destroy_process_group()
Woosuk Kwon's avatar
Woosuk Kwon committed
308
309


310
311
# Whether to use pynccl for nccl all reduce.
# We use pynccl for all reduce when using CUDA graph, because torch.distributed
Woosuk Kwon's avatar
Woosuk Kwon committed
312
# is not well supported by CUDA graph.
313
_ENABLE_PYNCCL_FOR_ALL_REDUCE = False
Woosuk Kwon's avatar
Woosuk Kwon committed
314
315
316


@contextlib.contextmanager
def with_pynccl_for_all_reduce():
    """Use pynccl instead of torch.distributed for all reduce.

    Context manager. While the managed block runs with a tensor parallel
    world size other than 1, ``_ENABLE_PYNCCL_FOR_ALL_REDUCE`` is set to
    True and the pynccl stream is set to the current CUDA stream. The
    previous flag value is restored on exit, including when the managed
    block raises. With a tensor parallel world size of 1 this is a no-op.
    """
    global _ENABLE_PYNCCL_FOR_ALL_REDUCE

    from vllm.distributed.device_communicators import pynccl_utils

    tp_size = get_tensor_model_parallel_world_size()
    if tp_size == 1:
        # No-op.
        # NOTE(woosuk): We don't initialize pynccl when tp_size is 1.
        yield
        return

    old = _ENABLE_PYNCCL_FOR_ALL_REDUCE
    _ENABLE_PYNCCL_FOR_ALL_REDUCE = True
    try:
        stream = torch.cuda.current_stream()
        with pynccl_utils.set_pynccl_stream(stream):
            yield
    finally:
        # Restore the flag even if the managed block (or stream setup)
        # raised; otherwise pynccl would stay enabled for later callers.
        _ENABLE_PYNCCL_FOR_ALL_REDUCE = old
Woosuk Kwon's avatar
Woosuk Kwon committed
334
335


336
337
338
339
def is_pynccl_enabled_for_all_reduce():
    """Return the current value of ``_ENABLE_PYNCCL_FOR_ALL_REDUCE``."""
    # Reading a module global needs no `global` declaration.
    return _ENABLE_PYNCCL_FOR_ALL_REDUCE