# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Tensor and pipeline parallel groups."""
from typing import List, Optional

import torch
from torch.distributed import ProcessGroup

import vllm.envs as envs
from vllm.logger import init_logger

logger = init_logger(__name__)

# Whether `initialize_model_parallel` builds a custom all-reduce communicator
# for the tensor parallel group; toggled via `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True

# Tensor model parallel group that the current rank belongs to.
_TP_DEVICE_GROUP: Optional[ProcessGroup] = None
# `gloo` (CPU) counterpart of the tensor model parallel group.
_TP_CPU_GROUP: Optional[ProcessGroup] = None
# PyNccl communicator over the tensor parallel group; `initialize_model_parallel`
# only creates it when tensor_model_parallel_size > 1.
_TP_PYNCCL_COMMUNICATOR = None
# Custom all-reduce communicator over the tensor parallel group; only created
# when `_ENABLE_CUSTOM_ALL_REDUCE` is true.
_TP_CA_COMMUNICATOR = None

# Pipeline model parallel group that the current rank belongs to.
_PP_DEVICE_GROUP: Optional[ProcessGroup] = None
# `gloo` (CPU) counterpart of the pipeline model parallel group.
_PP_CPU_GROUP: Optional[ProcessGroup] = None
# PyNccl communicator over the pipeline parallel group; only created when
# pipeline_model_parallel_size > 1.
_PP_PYNCCL_COMMUNICATOR = None

# When people blindly call `torch.distributed.all_reduce` etc,
# it will use this group. It is initialized with the `backend`
# parameter of `init_distributed_environment` below.
# Essentially, this is `torch.distributed.group.WORLD`.
# We leave a line here to note that this is device-specific.
# Note that this variable is not safe to use, because when users
# call `init_distributed_environment` first, and then destroy
# the process group themselves, this variable will keep a reference to the
# destroyed process group, which is not useful.
_DEVICE_WORLD_GROUP: Optional[ProcessGroup] = None

# During `init_distributed_environment`, we will also initialize a
# group with `gloo` backend, to allow direct coordination between
# processes through the CPU.
_CPU_WORLD_GROUP: Optional[ProcessGroup] = None

# In summary, after calling `init_distributed_environment`, we will
# always have two groups: one for device-specific (and is the default)
# and one for CPU. All processes will be part of both groups.

# A list of global ranks for each pipeline group to ease calculation of the
# source rank when broadcasting from the first or last pipeline stage.
_PP_GLOBAL_RANKS: Optional[List[int]] = None

# Rank of this process on its node, recorded by `init_distributed_environment`;
# it stays -1 until that function sets it.
_LOCAL_RANK: int = -1


def set_custom_all_reduce(enable: bool):
    """Choose whether later `initialize_model_parallel` calls build a custom
    all-reduce communicator for the tensor parallel group.

    Args:
        enable: new value for the module-level `_ENABLE_CUSTOM_ALL_REDUCE`.
    """
    # Rebinding a module-level name requires the `global` declaration.
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def get_pp_pynccl_communicator():
    """Return the PyNccl communicator of the pipeline parallel group.

    It is None until `initialize_model_parallel` creates it, which only
    happens when the pipeline parallel size is greater than 1.
    """
    # Reading a module-level name needs no `global` declaration.
    return _PP_PYNCCL_COMMUNICATOR


def get_tp_pynccl_communicator():
    """Return the PyNccl communicator of the tensor parallel group.

    It is None until `initialize_model_parallel` creates it, which only
    happens when the tensor parallel size is greater than 1.
    """
    # Reading a module-level name needs no `global` declaration.
    return _TP_PYNCCL_COMMUNICATOR


def get_tp_ca_communicator():
    """Return the custom all-reduce communicator of the tensor parallel group.

    It is None unless `initialize_model_parallel` ran while custom
    all-reduce was enabled.
    """
    # Reading a module-level name needs no `global` declaration.
    return _TP_CA_COMMUNICATOR


def get_local_rank():
    """Return the node-local rank recorded by `init_distributed_environment`
    (-1 if it has not been recorded)."""
    # Reading a module-level name needs no `global` declaration.
    return _LOCAL_RANK

def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize the global process groups and record the local rank.

    If torch.distributed is not initialized yet, the default (WORLD) process
    group is created with `backend`. The following happens both then and
    when the caller already initialized torch.distributed: the module records
    the device world group, a `gloo` CPU world group over all ranks, and the
    local rank, and it runs a small warmup all_reduce. A repeated call after
    that setup has completed does nothing.

    Args:
        world_size: number of processes, forwarded to
            `torch.distributed.init_process_group`.
        rank: global rank of this process, forwarded to
            `torch.distributed.init_process_group`.
        distributed_init_method: `init_method` URL for
            `torch.distributed.init_process_group`.
        local_rank: rank of this process on its node. If -1, it is read from
            `envs.LOCAL_RANK` for "env://" initialization and falls back to
            `rank` otherwise.
        backend: backend of the default (WORLD) process group.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    global _DEVICE_WORLD_GROUP, _CPU_WORLD_GROUP, _LOCAL_RANK
    if torch.distributed.is_initialized():
        if _CPU_WORLD_GROUP is not None:
            # A previous call already completed the setup below.
            return
        # torch.distributed was initialized by the caller. The world groups
        # and the local rank must still be recorded. Otherwise
        # `get_cpu_world_group` fails and communicators are later built
        # with device=-1.
    else:
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    _DEVICE_WORLD_GROUP = torch.distributed.group.WORLD
    ranks = list(range(torch.distributed.get_world_size()))
    _CPU_WORLD_GROUP = torch.distributed.new_group(ranks=ranks,
                                                   backend="gloo")

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    _LOCAL_RANK = local_rank

    # A small all_reduce for warmup.
    data = torch.zeros(1)
    if torch.cuda.is_available():
        data = data.to(device=f"cuda:{local_rank}")
    torch.distributed.all_reduce(data)
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    del data


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: backend of the device groups. Defaults to the backend of the
            default (WORLD) process group. The CPU groups always use `gloo`.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size differs from
            tensor_model_parallel_size * pipeline_model_parallel_size.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    # get the backend of _DEVICE_WORLD_GROUP
    backend = backend or torch.distributed.get_backend()

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    rank = torch.distributed.get_rank()

    # Build the tensor model-parallel groups (blocks of contiguous ranks).
    global _TP_DEVICE_GROUP, _TP_CPU_GROUP
    global _TP_PYNCCL_COMMUNICATOR, _TP_CA_COMMUNICATOR
    assert _TP_DEVICE_GROUP is None, (
        "tensor model parallel group is already initialized")
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        # Every rank calls `new_group` for every group, in the same order,
        # because torch requires all processes to enter it. Each rank then
        # keeps only the group that contains it.
        group = torch.distributed.new_group(ranks, backend=backend)
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _TP_DEVICE_GROUP = group
            _TP_CPU_GROUP = cpu_group

    from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
    if tensor_model_parallel_size > 1:
        _TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
            group=_TP_CPU_GROUP,
            device=_LOCAL_RANK,
        )

    # Initialize a custom fast all-reduce implementation.
    if _ENABLE_CUSTOM_ALL_REDUCE:
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        _TP_CA_COMMUNICATOR = CustomAllreduce(
            group=_TP_CPU_GROUP,
            device=_LOCAL_RANK,
        )

    # Build the pipeline model-parallel groups. Ranks are strided by the
    # number of pipeline groups, so each member comes from a different
    # tensor parallel block.
    global _PP_DEVICE_GROUP, _PP_CPU_GROUP
    global _PP_PYNCCL_COMMUNICATOR
    global _PP_GLOBAL_RANKS
    assert _PP_DEVICE_GROUP is None, (
        "pipeline model parallel group is already initialized")
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        # Same collective requirement as above: all ranks create every group.
        group = torch.distributed.new_group(ranks, backend=backend)
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if rank in ranks:
            _PP_DEVICE_GROUP = group
            _PP_CPU_GROUP = cpu_group
            _PP_GLOBAL_RANKS = ranks

    if pipeline_model_parallel_size > 1:
        _PP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
            group=_PP_CPU_GROUP,
            device=_LOCAL_RANK,
        )


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.

    Args:
        tensor_model_parallel_size: expected tensor parallel world size.
        pipeline_model_parallel_size: expected pipeline parallel world size.
        backend: backend for the device groups if they have to be created.
            Defaults to the backend of the default (WORLD) process group.

    Raises:
        AssertionError: if the groups already exist with different sizes.
    """
    # get the backend of _DEVICE_WORLD_GROUP
    backend = backend or torch.distributed.get_backend()
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    assert (get_pipeline_model_parallel_world_size(
    ) == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{get_pipeline_model_parallel_world_size()=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized.

    Returns:
        True only when both the tensor and the pipeline model parallel device
        groups are set.
    """
    return (_TP_DEVICE_GROUP is not None and _PP_DEVICE_GROUP is not None)


def get_cpu_world_group():
    """Return the `gloo` group spanning every rank.

    Raises:
        AssertionError: if the group has not been set.
    """
    world_cpu_group = _CPU_WORLD_GROUP
    assert world_cpu_group is not None, "CPU world group is not initialized"
    return world_cpu_group


def get_tensor_model_parallel_group():
    """Get the tensor model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the group is not set (not yet initialized, or
            destroyed).
    """
    assert _TP_DEVICE_GROUP is not None, (
        "tensor model parallel group is not initialized")
    return _TP_DEVICE_GROUP


def get_tensor_model_parallel_cpu_group():
    """Return the `gloo` (CPU) group for the caller's tensor parallel group.

    Raises:
        AssertionError: if the group has not been set.
    """
    cpu_group = _TP_CPU_GROUP
    assert cpu_group is not None, (
        "tensor model parallel cpu group is not initialized")
    return cpu_group


def get_pipeline_model_parallel_group():
    """Get the pipeline model parallel group the caller rank belongs to.

    Raises:
        AssertionError: if the group is not set (not yet initialized, or
            destroyed).
    """
    assert _PP_DEVICE_GROUP is not None, (
        "pipeline model parallel group is not initialized")
    return _PP_DEVICE_GROUP


def get_pipeline_model_parallel_cpu_group():
    """Return the `gloo` (CPU) group for the caller's pipeline parallel group.

    Raises:
        AssertionError: if the group has not been set.
    """
    cpu_group = _PP_CPU_GROUP
    assert cpu_group is not None, (
        "pipeline model parallel cpu group is not initialized")
    return cpu_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group.

    Raises:
        AssertionError: if the tensor model parallel group is not set.
    """
    return torch.distributed.get_world_size(
        group=get_tensor_model_parallel_group())


def get_pipeline_model_parallel_world_size():
    """Return world size for the pipeline model parallel group.

    Raises:
        AssertionError: if the pipeline model parallel group is not set.
    """
    return torch.distributed.get_world_size(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_rank():
    """Return the caller's rank inside its tensor model parallel group."""
    tp_group = get_tensor_model_parallel_group()
    return torch.distributed.get_rank(group=tp_group)


def get_pipeline_model_parallel_rank():
    """Return my rank for the pipeline model parallel group.

    Raises:
        AssertionError: if the pipeline model parallel group is not set.
    """
    return torch.distributed.get_rank(
        group=get_pipeline_model_parallel_group())


def get_tensor_model_parallel_src_rank():
    """Return the global rank of the first member of the caller's tensor
    model parallel group.

    Tensor parallel groups are built from contiguous global ranks, so the
    first member is the caller's global rank rounded down to a multiple of
    the group size.
    """
    my_global_rank = torch.distributed.get_rank()
    tp_size = get_tensor_model_parallel_world_size()
    # Equivalent to (rank // size) * size under Python's floor semantics.
    return my_global_rank - my_global_rank % tp_size


def get_pipeline_model_parallel_first_rank():
    """Return the global rank of the first process in the pipeline for the
    current tensor parallel group.

    Raises:
        AssertionError: if the pipeline parallel ranks are not recorded.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    return _PP_GLOBAL_RANKS[0]


def get_pipeline_model_parallel_last_rank():
    """Return the global rank of the last process in the pipeline for the
    current tensor parallel group.

    Raises:
        AssertionError: if the pipeline parallel ranks are not recorded.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    # `_PP_GLOBAL_RANKS` holds exactly the members of this rank's pipeline
    # group, so its last element is the last stage. Querying the group size
    # through torch.distributed is unnecessary.
    return _PP_GLOBAL_RANKS[-1]


def get_pipeline_model_parallel_next_rank():
    """Return the global rank that follows the caller in the pipeline.

    The index wraps around, so the last stage's next rank is the first stage.

    Raises:
        AssertionError: if the pipeline parallel ranks are not recorded.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PP_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]


def get_pipeline_model_parallel_prev_rank():
    """Return the global rank that precedes the caller in the pipeline.

    The index wraps around, so the first stage's previous rank is the last
    stage.

    Raises:
        AssertionError: if the pipeline parallel ranks are not recorded.
    """
    assert _PP_GLOBAL_RANKS is not None, (
        "Pipeline parallel group is not initialized")
    rank_in_pipeline = get_pipeline_model_parallel_rank()
    world_size = get_pipeline_model_parallel_world_size()
    return _PP_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]


def destroy_model_parallel():
361
    """Set the groups to none and destroy them."""
362
363
364
365
366
367
368
369
    global _TP_DEVICE_GROUP
    if _TP_DEVICE_GROUP:
        torch.distributed.destroy_process_group(_TP_DEVICE_GROUP)
    _TP_DEVICE_GROUP = None
    global _TP_CPU_GROUP
    if _TP_CPU_GROUP:
        torch.distributed.destroy_process_group(_TP_CPU_GROUP)
    _TP_CPU_GROUP = None
370
371
372
373
374
375
376
377
378
    global _TP_PYNCCL_COMMUNICATOR
    _TP_PYNCCL_COMMUNICATOR = None

    global _PP_DEVICE_GROUP
    if _PP_DEVICE_GROUP:
        torch.distributed.destroy_process_group(_PP_DEVICE_GROUP)
    _PP_DEVICE_GROUP = None
    global _PP_GLOBAL_RANKS
    _PP_GLOBAL_RANKS = None