parallel_state.py 12.6 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
# Copyright 2023 The vLLM team.
2
3
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
4
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
5
"""Tensor and pipeline parallel groups."""
6
from typing import List, Optional
Zhuohan Li's avatar
Zhuohan Li committed
7
8

import torch
9
from torch.distributed import ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
10

11
import vllm.envs as envs
12
13
14
15
from vllm.logger import init_logger

# Module-level logger obtained from vLLM's `init_logger`, keyed by this
# module's `__name__`.
logger = init_logger(__name__)

16
# Tensor model parallel group that the current rank belongs to.
17
18
19
# Process group for this rank's tensor-parallel group, created in
# `initialize_model_parallel` with the `backend` chosen there.
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
# `gloo` process group created over the same ranks as `_TP_DEVICE_GROUP`.
_TP_CPU_GROUP: Optional[ProcessGroup] = None
# PyNcclCommunicator constructed from `_TP_CPU_GROUP` and `_LOCAL_RANK` in
# `initialize_model_parallel`; reset to None by `destroy_model_parallel`.
_TP_PYNCCL_COMMUNICATOR = None
20
# Pipeline model parallel group that the current rank belongs to.
21
# Assigned in `initialize_model_parallel`; destroyed and reset to None by
# `destroy_model_parallel`.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
Zhuohan Li's avatar
Zhuohan Li committed
22

23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
# The default process group: when callers invoke
# `torch.distributed.all_reduce` etc. without passing a group, this is the
# group that gets used. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use: if users call
# `init_distributed_environment` first and then destroy the process group
# themselves, this variable keeps a reference to the destroyed process
# group, which is not useful.
_DEVICE_WORLD_GROUP = None

# During `init_distributed_environment`, we also initialize a
# group with the `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP = None

# In summary, after `init_distributed_environment` has performed the
# initialization itself (i.e. torch.distributed was not already initialized),
# we have two groups: a device-specific one (which is the default)
# and one for CPU. All processes are part of both groups.

43
44
# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
45
# Assigned in `initialize_model_parallel` to the rank list of the pipeline
# group containing this rank; reset to None by `destroy_model_parallel`.
_PP_GLOBAL_RANKS: Optional[List[int]] = None
Zhuohan Li's avatar
Zhuohan Li committed
46

47
48
49
# Local rank recorded by `init_distributed_environment`; stays at -1 until
# that function performs the initialization. Later passed as `device` to the
# PyNcclCommunicator in `initialize_model_parallel`.
_LOCAL_RANK = -1


50
51
52
53
54
def get_tp_pynccl_communicator():
    """Return the module-level tensor-parallel PyNccl communicator.

    No initialization check is performed: the current value of
    `_TP_PYNCCL_COMMUNICATOR` is handed back as-is, so it may be `None`.
    """
    # Reading a module global does not require a `global` declaration.
    return _TP_PYNCCL_COMMUNICATOR


55
56
57
58
def get_local_rank():
    """Return the local rank stored in the module-level `_LOCAL_RANK`.

    The value is returned unchecked, whatever it currently is.
    """
    # Reading a module global does not require a `global` declaration.
    return _LOCAL_RANK

Zhuohan Li's avatar
Zhuohan Li committed
59

60
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed and the module's world-level groups.

    If torch.distributed is already initialized, only the debug log line is
    emitted and nothing else happens. Otherwise the default process group is
    created, a `gloo` group spanning every rank is added, the local rank is
    recorded in `_LOCAL_RANK`, and a one-element all-reduce is run as warmup.

    Arguments:
        world_size: forwarded to `torch.distributed.init_process_group`.
        rank: forwarded to `torch.distributed.init_process_group`.
        distributed_init_method: forwarded as `init_method`; must not be None.
        local_rank: local rank of this process. When left at -1 it is taken
            from `envs.LOCAL_RANK` if `distributed_init_method` is "env://",
            and from `rank` otherwise.
        backend: backend of the default (WORLD) process group.
    """
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP, _LOCAL_RANK

    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    # Everything below is one-time setup; skip it if a default group exists.
    if torch.distributed.is_initialized():
        return

    assert distributed_init_method is not None, (
        "distributed_init_method must be provided when initializing "
        "distributed environment")

    # The backend given here is the one used by the WORLD group.
    torch.distributed.init_process_group(backend=backend,
                                         init_method=distributed_init_method,
                                         world_size=world_size,
                                         rank=rank)
    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD

    # Companion group over all ranks for CPU-side coordination.
    all_ranks = list(range(torch.distributed.get_world_size()))
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=all_ranks,
                                                   backend="gloo")

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # Not provided by the caller: this usually happens in a single-node
        # setting, where the global rank can serve as the local rank.
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)
    _LOCAL_RANK = local_rank

    # A small all_reduce for warmup.
    warmup = torch.zeros(1)
    if torch.cuda.is_available():
        warmup = warmup.to(device=f"cuda:{local_rank}")
    torch.distributed.all_reduce(warmup)
103
104


Zhuohan Li's avatar
Zhuohan Li committed
105
106
107
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Create the tensor- and pipeline-model-parallel process groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: backend for the device groups; when falsy, the value of
            `torch.distributed.get_backend()` is used.

    Example with 8 GPUs g0 ... g7, tensor parallel size 2 and pipeline
    parallel size 4. This function creates:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size differs from
            tensor_model_parallel_size x pipeline_model_parallel_size.
        AssertionError: if torch.distributed is not initialized, or if the
            tensor/pipeline groups were already created.
    """
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
    global _PP_DEVICE_GROUP, _PP_GLOBAL_RANKS

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    # Fall back to the backend of the default group.
    backend = backend or torch.distributed.get_backend()

    expected_world_size = (tensor_model_parallel_size *
                           pipeline_model_parallel_size)
    if world_size != expected_world_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    tp_group_count: int = world_size // tensor_model_parallel_size
    pp_group_count: int = world_size // pipeline_model_parallel_size
    my_rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups: consecutive blocks of ranks.
    assert _TP_DEVICE_GROUP is None, (
        "tensor model parallel group is already initialized")
    for group_idx in range(tp_group_count):
        first = group_idx * tensor_model_parallel_size
        members = list(range(first, first + tensor_model_parallel_size))
        # Every rank calls new_group for every group, in the same order.
        device_group = torch.distributed.new_group(members, backend=backend)
        gloo_group = torch.distributed.new_group(members, backend="gloo")
        if my_rank in members:
            _TP_DEVICE_GROUP = device_group
            _TP_CPU_GROUP = gloo_group

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
        group=_TP_CPU_GROUP,
        device=_LOCAL_RANK,
    )

    # Build the pipeline model-parallel groups: ranks strided by the number
    # of pipeline groups.
    assert _PP_DEVICE_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for offset in range(pp_group_count):
        members = list(range(offset, world_size, pp_group_count))
        device_group = torch.distributed.new_group(members, backend=backend)
        if my_rank in members:
            _PP_DEVICE_GROUP = device_group
            _PP_GLOBAL_RANKS = members
Zhuohan Li's avatar
Zhuohan Li committed
182
183


184
185
186
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    When the groups do not exist yet they are created through
    `initialize_model_parallel`. When they already exist, their world sizes
    are asserted to match the requested tensor-parallel and
    pipeline-parallel sizes.
    """
    # Fall back to the backend of the default group.
    backend = backend or torch.distributed.get_backend()

    if model_parallel_is_initialized():
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        assert (get_pipeline_model_parallel_world_size(
        ) == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{get_pipeline_model_parallel_world_size()=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)


Zhuohan Li's avatar
Zhuohan Li committed
212
def model_parallel_is_initialized():
    """Return True only when both the tensor-parallel and the
    pipeline-parallel device groups have been set."""
    tp_missing = _TP_DEVICE_GROUP is None
    pp_missing = _PP_DEVICE_GROUP is None
    return not (tp_missing or pp_missing)
Zhuohan Li's avatar
Zhuohan Li committed
215
216


217
218
219
220
221
222
def get_cpu_world_group():
    """Return the world-level CPU group.

    Raises AssertionError when `_CPU_WORLD_GROUP` has not been set.
    """
    cpu_world = _CPU_WORLD_GROUP
    assert cpu_world is not None, "CPU world group is not initialized"
    return cpu_world


Zhuohan Li's avatar
Zhuohan Li committed
223
224
def get_tensor_model_parallel_group():
    """Return the tensor model parallel device group of the calling rank.

    Raises AssertionError when the group has not been set.
    """
    tp_group = _TP_DEVICE_GROUP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


def get_tensor_model_parallel_cpu_group():
    """Return the tensor model parallel CPU group of the calling rank.

    Raises AssertionError when the group has not been set.
    """
    tp_cpu_group = _TP_CPU_GROUP
    assert tp_cpu_group is not None, (
        "tensor model parallel cpu group is not initialized")
    return tp_cpu_group
Zhuohan Li's avatar
Zhuohan Li committed
235
236
237
238


def get_pipeline_model_parallel_group():
    """Return the pipeline model parallel device group of the calling rank.

    Raises AssertionError when the group has not been set.
    """
    pp_group = _PP_DEVICE_GROUP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group
Zhuohan Li's avatar
Zhuohan Li committed
242
243
244
245


def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_world_size(group=tp_group)
Zhuohan Li's avatar
Zhuohan Li committed
248
249
250
251


def get_pipeline_model_parallel_world_size():
    """Return the number of ranks in the pipeline model parallel group."""
    pp_group = get_pipeline_model_parallel_group()
    return torch.distributed.get_world_size(group=pp_group)
Zhuohan Li's avatar
Zhuohan Li committed
254
255
256
257
258
259
260
261
262


def get_tensor_model_parallel_rank():
    """Return the caller's rank within the tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=tp_group)


def get_pipeline_model_parallel_rank():
    """Return the caller's rank within the pipeline model parallel group."""
    pp_group = get_pipeline_model_parallel_group()
    return torch.distributed.get_rank(group=pp_group)
Zhuohan Li's avatar
Zhuohan Li committed
265
266
267
268
269
270
271
272
273
274
275
276
277


def get_tensor_model_parallel_src_rank():
    """Return the global rank of the first member of the caller's tensor
    model parallel group.

    The global rank is rounded down to the nearest multiple of the
    tensor-parallel world size.
    """
    tp_size = get_tensor_model_parallel_world_size()
    my_global_rank = torch.distributed.get_rank()
    # Subtracting the remainder is the same as floor-dividing and
    # multiplying back.
    return my_global_rank - my_global_rank % tp_size


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the caller's
    pipeline group.

    Raises AssertionError when `_PP_GLOBAL_RANKS` has not been set.
    """
    pipeline_ranks = _PP_GLOBAL_RANKS
    assert pipeline_ranks is not None, (
        "Pipeline parallel group is not initialized")
    return pipeline_ranks[0]
Zhuohan Li's avatar
Zhuohan Li committed
281
282
283
284
285


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the caller's
    pipeline group.

    Raises AssertionError when `_PP_GLOBAL_RANKS` has not been set.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    # `_PP_GLOBAL_RANKS` is exactly the rank list the pipeline group was
    # created from, so its final entry is the last stage. Indexing with -1
    # mirrors `get_pipeline_model_parallel_first_rank` and avoids querying
    # torch.distributed for the group size.
    return _PP_GLOBAL_RANKS[-1]
Zhuohan Li's avatar
Zhuohan Li committed
290

Zhuohan Li's avatar
Zhuohan Li committed
291

Zhuohan Li's avatar
Zhuohan Li committed
292
293
def get_pipeline_model_parallel_next_rank():
    """Return the global rank of the stage after the caller in the pipeline.

    The position wraps around via modulo, so the last stage's successor is
    the first stage. Raises AssertionError when `_PP_GLOBAL_RANKS` has not
    been set.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    my_position = get_pipeline_model_parallel_rank()
    pipeline_size = get_pipeline_model_parallel_world_size()
    successor = (my_position + 1) % pipeline_size
    return _PP_GLOBAL_RANKS[successor]
Zhuohan Li's avatar
Zhuohan Li committed
299
300
301


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank of the stage before the caller in the pipeline.

    The position wraps around via modulo, so the first stage's predecessor
    is the last stage. Raises AssertionError when `_PP_GLOBAL_RANKS` has not
    been set.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    my_position = get_pipeline_model_parallel_rank()
    pipeline_size = get_pipeline_model_parallel_world_size()
    predecessor = (my_position - 1) % pipeline_size
    return _PP_GLOBAL_RANKS[predecessor]
Zhuohan Li's avatar
Zhuohan Li committed
308
309
310


def destroy_model_parallel():
    """Destroy the model parallel process groups and reset module state.

    Each of the tensor-parallel device group, tensor-parallel CPU group and
    pipeline-parallel device group is passed to
    `torch.distributed.destroy_process_group` if it is set. Afterwards all
    model-parallel globals (including the PyNccl communicator and the
    pipeline rank list) are None. Calling this when nothing is initialized
    is a no-op.
    """
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP, _TP_PYNCCL_COMMUNICATOR
    global _PP_DEVICE_GROUP, _PP_GLOBAL_RANKS

    # Compare against None explicitly: these globals are Optional, and
    # "is this group set" should not depend on the object's truth value.
    if _TP_DEVICE_GROUP is not None:
        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
    _TP_DEVICE_GROUP = None

    if _TP_CPU_GROUP is not None:
        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
    _TP_CPU_GROUP = None

    # Drop the reference to the communicator.
    _TP_PYNCCL_COMMUNICATOR = None

    if _PP_DEVICE_GROUP is not None:
        torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
    _PP_DEVICE_GROUP = None
    _PP_GLOBAL_RANKS = None