# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
13
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
14
15
16
17
18
19
20
21
22
23
24
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import Any
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
44

45
import vllm.envs as envs
46
from vllm.distributed.device_communicators.base_device_communicator import (
47
48
    DeviceCommunicatorBase,
)
49
from vllm.distributed.utils import StatelessProcessGroup
50
from vllm.logger import init_logger
51
from vllm.utils.import_utils import resolve_obj_by_qualname
52
from vllm.utils.network_utils import get_distributed_init_method
53
from vllm.utils.system_utils import suppress_stdout
54
55
56
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)
57
58


59
60
61
@dataclass
class GraphCaptureContext:
    """State handed to code running under `GroupCoordinator.graph_capture`.

    Attributes:
        stream: the CUDA stream made current while the capture context is
            active.
    """

    stream: torch.cuda.Stream
62

63

64
# Placeholder sent instead of a tensor by `_split_tensor_dict`: the device
# *type* (e.g. "cuda"), the dtype, and the size of the original tensor.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
65

66
67

def _split_tensor_dict(
    tensor_dict: dict[str, torch.Tensor | Any],
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Separate a dictionary into metadata and raw tensors.

    Returns:
        A pair ``(metadata_list, tensor_list)``:

        1. ``metadata_list`` has one ``(key, value)`` entry per item of
           ``tensor_dict``, in iteration order. Tensor values are replaced by
           a ``TensorMetadata``; every other value is kept as-is.
        2. ``tensor_list`` has the tensors themselves, in the order they
           appear in ``tensor_dict``.
    """
    metadata: list[tuple[str, Any]] = []
    tensors: list[torch.Tensor] = []
    for name, item in tensor_dict.items():
        if not isinstance(item, torch.Tensor):
            metadata.append((name, item))
            continue
        # Record only the device *type* (e.g. "cuda"), not the full device
        # (e.g. "cuda:0"); the receiving side decides the device index.
        tensors.append(item)
        metadata.append(
            (name, TensorMetadata(item.device.type, item.dtype, item.size()))
        )
    return metadata, tensors


93
_group_name_counter: dict[str, int] = {}
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


109
_groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}
110
111
112


def _register_group(group: "GroupCoordinator") -> None:
113
    _groups[group.unique_name] = weakref.ref(group)
114
115


116
117
118
119
120
121
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Custom-op body: all-reduce ``tensor`` on the group named ``group_name``.

    Raises:
        AssertionError: no group was registered under ``group_name``.
        ValueError: the group was registered but its coordinator is gone.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")
122
123


124
125
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation: an uninitialized tensor like ``tensor``."""
    result = torch.empty_like(tensor)
    return result
126

127

128
129
130
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Custom-op body: reduce-scatter ``tensor`` along ``dim`` on a named group.

    ``world_size`` is part of the op signature (the fake implementation uses
    it for the output shape); it is not read here.

    Raises:
        AssertionError: no group was registered under ``group_name``.
        ValueError: the group was registered but its coordinator is gone.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._reduce_scatter_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
136
137


138
139
140
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation: an empty tensor shaped like one rank's shard.

    The size of ``dim`` is floor-divided by ``world_size``. Dtype and device
    are taken from ``tensor``.
    """
    shape = list(tensor.shape)
    shape[dim] //= world_size
    return torch.empty(shape, dtype=tensor.dtype, device=tensor.device)


146
147
148
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Custom-op body: all-gather ``tensor`` along ``dim`` on a named group.

    ``world_size`` is part of the op signature (the fake implementation uses
    it for the output shape); it is not read here.

    Raises:
        AssertionError: no group was registered under ``group_name``.
        ValueError: the group was registered but its coordinator is gone.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_gather_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
154
155


156
157
158
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation: an empty tensor shaped like the gathered result.

    The size of ``dim`` is multiplied by ``world_size``. Dtype and device are
    taken from ``tensor``.
    """
    shape = list(tensor.shape)
    shape[dim] *= world_size
    return torch.empty(shape, dtype=tensor.dtype, device=tensor.device)


164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Fake implementation for ``patched_fused_scaled_matmul_reduce_scatter``.

    Builds the result from unfused ops:
    1. validate and flatten ``A_scale``;
    2. run ``torch._scaled_mm`` on ``A`` flattened to 2-D;
    3. view the product as ``(*output_shape[:-1], B.shape[1])``;
    4. reduce-scatter it along ``orig_scatter_dim`` with functional
       collectives and wait for the result.

    ``scatter_dim_after_maybe_reshape`` is accepted so the signature matches
    the real op, but it is not used here.

    Raises:
        ValueError: if a multi-element ``A_scale`` does not share ``A``'s
            leading dims (row-wise scaling), or if ``A_scale`` is empty.
    """
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        # Row-wise scaling: flatten the scale's leading dims the same way A
        # is flattened for the matmul below.
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        # Only reachable when A_scale is empty (numel() == 0).
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    # Restore A's leading dims (taken from output_shape) on the 2-D product.
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


248
249
250
251
252
# Expose the collectives as vLLM custom ops. Dynamo cannot pass a
# GroupCoordinator object into a custom op, so each op takes the group's
# unique name and resolves the coordinator through `_groups`. The paired
# fake implementations produce outputs with the right shape and dtype.
direct_register_custom_op(
    op_name="all_reduce",
    op_func=all_reduce,
    fake_impl=all_reduce_fake,
)

direct_register_custom_op(
    op_name="reduce_scatter",
    op_func=reduce_scatter,
    fake_impl=reduce_scatter_fake,
)

direct_register_custom_op(
    op_name="all_gather",
    op_func=all_gather,
    fake_impl=all_gather_fake,
)

# TODO: Remove this once the pytorch fix
# (https://github.com/pytorch/pytorch/pull/165086) gets released,
# in either 2.9.1 or 2.10
direct_register_custom_op(
    op_name="patched_fused_scaled_matmul_reduce_scatter",
    op_func=patched_fused_scaled_matmul_reduce_scatter,
    fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
)
274

275

276
277
278
279
280
281
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
282
283
        the processes in the group. It manages both CPU and device
        communication.
284
285
286
287
    """

    # available attributes:
    rank: int  # global rank
288
    ranks: list[int]  # global ranks in the group
289
290
291
292
293
294
295
296
297
298
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
299
300
301
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
302
303
    device_communicator: DeviceCommunicatorBase | None
    mq_broadcaster: Any | None  # shared memory broadcaster
304
305
306

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create the process groups for ``group_ranks`` and set up communicators.

        Args:
            group_ranks: the partition of global ranks into groups.
                ``torch.distributed.new_group`` is called for every entry on
                every rank, so all ranks are expected to pass the same list.
                Only the entry that contains this process's rank is kept.
            local_rank: local rank used to choose the device index.
            torch_distributed_backend: backend for the device process group.
            use_device_communicator: create the platform device communicator
                when the group has more than one member.
            use_message_queue_broadcaster: create a shared-memory
                ``MessageQueue`` broadcaster when the group has more than one
                member.
            group_name: prefix for the unique group name; defaults to
                ``"anonymous"``.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None

        # Every rank creates every group, in the same order, even groups it
        # does not belong to.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            with suppress_stdout():
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self_device_group = device_group
                self_cpu_group = cpu_group

        assert self_cpu_group is not None
        assert self_device_group is not None

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        # A single-member group never gets a device communicator.
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): `1 << 22` and `6` are passed positionally;
            # presumably max chunk bytes (4 MiB) and max chunk count — confirm
            # against MessageQueue.create_from_process_group.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # On these platforms the user-facing collectives dispatch through the
        # registered `torch.ops.vllm.*` custom ops.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        # CPU builds that ship the shm-manager op route `send_tensor_dict`
        # through the device communicator instead of torch.distributed.
        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
386

387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
    def create_mq_broadcaster(
        self, writer_rank=0, external_writer_handle=None, blocking=True
    ):
        """Build a new shared-memory ``MessageQueue`` over ``self.cpu_group``.

        The keyword arguments go to
        ``MessageQueue.create_from_process_group``, along with the same
        positional buffer settings (``1 << 22``, ``6``) used in ``__init__``.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        options = dict(
            writer_rank=writer_rank,
            external_writer_handle=external_writer_handle,
            blocking=blocking,
        )
        return MessageQueue.create_from_process_group(
            self.cpu_group, 1 << 22, 6, **options
        )

    def create_single_reader_mq_broadcasters(
        self, reader_rank_in_group=0, blocking=False
    ):
        """Build single-reader ``MessageQueue`` broadcasters over ``self.cpu_group``.

        ``reader_rank_in_group`` is an index into ``self.ranks``. It is
        converted to a global rank before the call to
        ``MessageQueue.create_from_process_group_single_reader``.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        reader_global_rank = self.ranks[reader_rank_in_group]
        return MessageQueue.create_from_process_group_single_reader(
            self.cpu_group,
            1 << 22,
            6,
            reader_rank=reader_global_rank,
            blocking=blocking,
        )

414
415
416
417
418
419
420
421
422
423
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

424
425
426
427
428
429
430
431
432
433
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Context manager that prepares this group for CUDA graph capture.

        Uses the stream from ``graph_capture_context`` if one is given;
        otherwise it creates a new stream and context. The capture stream
        first waits on the current stream, so earlier initialization work
        finishes before capture starts. It then makes the capture stream
        current. If the device communicator has a custom all-reduce
        communicator (``ca_comm``), its ``capture()`` context is entered too.

        Yields:
            The ``GraphCaptureContext`` in use.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        # only cuda uses this function,
        # so we don't abstract it into the base class
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        ca_capture_ctx = nullcontext()
        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            ca_comm = communicator.ca_comm
            if ca_comm is not None:
                ca_capture_ctx = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), ca_capture_ctx:
            yield graph_capture_context
477
478
479

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
480
481
482
483
484
485
486
487
488
489
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
490
491
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
492
493
494
495
496
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

497
        if self.use_custom_op_call:
498
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
499
500
        else:
            return self._all_reduce_out_place(input_)
501

502
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
503
504
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
505
        return self.device_communicator.all_reduce(input_)
506
507
508
509
510
511
512

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
513
514
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
515

516
        if self.use_custom_op_call:
517
518
519
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
520
521
522
        else:
            return self._all_gather_out_place(input_, dim)

523
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
524
525
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
526
        return self.device_communicator.all_gather(input_, dim)
527

528
529
    def all_gatherv(
        self,
530
        input_: torch.Tensor | list[torch.Tensor],
531
        dim: int = 0,
532
        sizes: list[int] | None = None,
533
    ):
534
535
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
536
537
        return self.device_communicator.all_gatherv(input_, dim, sizes)

538
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
539
540
541
542
543
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
544
545
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
546

547
        if self.use_custom_op_call:
548
549
550
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
551
552
553
        else:
            return self._reduce_scatter_out_place(input_, dim)

554
    def reduce_scatterv(
555
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
556
    ) -> torch.Tensor:
557
558
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
559
560
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

561
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
562
563
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
564
565
        return self.device_communicator.reduce_scatter(input_, dim)

566
567
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
568
    ) -> torch.Tensor | None:
569
570
571
572
573
574
575
576
577
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
578
579
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
580
        return self.device_communicator.gather(input_, dst, dim)
581
582
583
584
585
586
587
588
589
590
591

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
592
593
594
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
595
596
        return input_

597
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
598
599
600
601
602
603
604
605
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
606
607
608
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
609
        if self.rank_in_group == src:
610
611
612
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
613
614
615
            return obj
        else:
            recv = [None]
616
617
618
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
619
620
            return recv[0]

621
    def broadcast_object_list(
622
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
623
    ):
624
625
626
627
628
629
630
631
632
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
633
634
635
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
636
637
        return obj_list

638
639
640
641
642
643
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

644
        assert dst != self.rank_in_group, (
645
            "Invalid destination rank. Destination rank is the same "
646
647
            "as the current rank."
        )
648
649
650
651

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

652
653
654
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
655
656
657

        # Send object size

658
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
659
660

        # Send object
661
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
662
663
664
665
666
667
668
669
670

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

671
        assert src != self.rank_in_group, (
672
673
674
675
676
677
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
678
679
680
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
681
682
683
684
685

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
686
687
            device="cpu",
        )
688

689
690
691
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
692
693

        assert rank_object == rank_size, (
694
695
            "Received object sender rank does not match the size sender rank."
        )
696
697
698
699
700

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

701
702
    def broadcast_tensor_dict(
        self,
703
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
704
        src: int = 0,
705
706
707
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
708
709
710
711
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
712
        if not torch.distributed.is_initialized() or self.world_size == 1:
713
714
715
716
717
718
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

719
720
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
721
            metadata_list: list[tuple[Any, Any]] = []
722
723
724
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
725
726
727
728
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
729
            self.broadcast_object(metadata_list, src=src)
730
731
732
733
734
735
736
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
737
738
739
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
740
741
                else:
                    # use group for GPU tensors
742
743
744
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
745
746
747
748
749
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
750
            metadata_list = self.broadcast_object(None, src=src)
751
752
            tensor_dict = {}
            async_handles = []
753
            for key, value in metadata_list:
754
                if isinstance(value, TensorMetadata):
755
756
757
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
758
759
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
760
                        tensor_dict[key] = tensor
761
762
763
764
765
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
766
                            src=self.ranks[src],
767
                            group=metadata_group,
768
769
                            async_op=True,
                        )
770
771
                    else:
                        # use group for GPU tensors
772
                        handle = torch.distributed.broadcast(
773
774
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
775
                    async_handles.append(handle)
776
                    tensor_dict[key] = tensor
777
                else:
778
                    tensor_dict[key] = value
779
780
781
782
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

783
784
    def send_tensor_dict(
        self,
785
786
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
787
        all_gather_group: "GroupCoordinator | None" = None,
788
789
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
790
791
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
807
808
809
810
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
811
812
813
814
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
815

816
817
818
819
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
820
            dst = (self.rank_in_group + 1) % self.world_size
821
822
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

823
        if self.use_cpu_custom_send_recv:
824
825
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
826
            self.device_communicator.send_tensor_dict(  # type: ignore
827
828
                tensor_dict, dst
            )
829
830
            return None

831
        metadata_list: list[tuple[Any, Any]] = []
832
833
834
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
835
836
837
838
839
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
840

841
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
842
843
844
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
845
846
847
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
848
849

            # send-allgather: send only a slice, then do allgather.
850
851
852
853
854
855
856
857
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            use_all_gather = (
                all_gather_tensors.get(key, use_all_gather)
                if all_gather_tensors
                else use_all_gather
            )
858
            if use_all_gather:
859
860
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

861
862
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
863
864
865
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
866
867
            else:
                # use group for GPU tensors
868
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
869
870
871
872
        return None

    def recv_tensor_dict(
        self,
873
        src: int | None = None,
874
        all_gather_group: "GroupCoordinator | None" = None,
875
876
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
877
878
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
894
895
896
897
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
898
899
900
901
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
902

903
904
905
906
        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
907
            src = (self.rank_in_group - 1) % self.world_size
908
909
        assert src < self.world_size, f"Invalid src rank ({src})"

910
        if self.use_cpu_custom_send_recv:
911
912
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
913
            return self.device_communicator.recv_tensor_dict(  # type: ignore
914
915
                src
            )
916

917
        recv_metadata_list = self.recv_object(src=src)
918
        tensor_dict: dict[str, Any] = {}
919
920
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
921
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
922
923
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
924
                    tensor_dict[key] = tensor
925
                    continue
926
927

                # send-allgather: send only a slice, then do allgather.
928
929
930
931
932
933
934
935
936
                use_all_gather = (
                    all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0
                )
                use_all_gather = (
                    all_gather_tensors.get(key, use_all_gather)
                    if all_gather_tensors
                    else use_all_gather
                )
937
938
939

                if use_all_gather:
                    orig_shape = tensor.shape
940
                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
941

942
943
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
944
945
946
                    torch.distributed.recv(
                        tensor, src=self.ranks[src], group=metadata_group
                    )
947
948
                else:
                    # use group for GPU tensors
949
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
950
951
952
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
953
954
                        tensor, dim=0
                    )
955
956
                    tensor = tensor.reshape(orig_shape)

957
                tensor_dict[key] = tensor
958
            else:
959
                tensor_dict[key] = value
960
961
        return tensor_dict

962
963
964
965
966
967
968
969
970
    def barrier(self):
        """Block until every rank of this group has reached the barrier.

        NOTE: the CPU group is used on purpose instead of `device_group`. A
        barrier in NCCL is internally a broadcast over secretly created GPU
        tensors, which makes it easy to mess up the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

971
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
972
        """Sends a tensor to the destination rank in a blocking way"""
973
        """NOTE: `dst` is the local rank of the destination rank."""
974
975
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
976
        self.device_communicator.send(tensor, dst)
977

978
    def recv(
979
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
980
    ) -> torch.Tensor:
981
982
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
983
984
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
985
        return self.device_communicator.recv(size, dtype, src)
986

987
    def destroy(self):
        """Tear down the process groups and communicators held by this group.

        `device_group` and then `cpu_group` are destroyed and deleted only if
        the attribute is present. The device communicator (if any) is asked to
        destroy itself, and a non-None `mq_broadcaster` reference is dropped.
        """
        for group_attr in ("device_group", "cpu_group"):
            if hasattr(self, group_attr):
                torch.distributed.destroy_process_group(getattr(self, group_attr))
                delattr(self, group_attr)
        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
998

999
1000
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward `model` to the device communicator's buffer preparation.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
1002

1003
    def dispatch_router_logits(
1004
1005
1006
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
1007
        is_sequence_parallel: bool = False,
1008
1009
1010
1011
1012
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
1013
        if self.device_communicator is not None:
1014
            return self.device_communicator.dispatch_router_logits(
1015
1016
1017
1018
                hidden_states,
                router_logits,
                is_sequence_parallel,
                extra_tensors,
1019
            )
1020
1021
        else:
            return hidden_states, router_logits
1022

1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(
                hidden_states,
                topk_weights,
                topk_ids,
                is_sequence_parallel,
                extra_tensors,
            )
        else:
            return hidden_states, topk_weights, topk_ids

1045
1046
1047
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine hidden states through the device communicator, if any.

        Returns `hidden_states` unchanged when there is no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
1052

1053

1054
# Coordinator spanning every rank; assigned by `init_distributed_environment`.
_WORLD: GroupCoordinator | None = None
# Coordinator for the ranks of one data-parallel replica; assigned by
# `init_distributed_environment` when the config reports
# `nnodes_within_dp > 1` (it aliases `_WORLD` when DP size is 1).
_INNER_DP_WORLD: GroupCoordinator | None = None
# Node count recorded by `init_distributed_environment`.
_NODE_COUNT: int | None = None
1057
1058
1059


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator (asserts it was initialized)."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


1064
1065
1066
1067
1068
def get_inner_dp_world_group() -> GroupCoordinator:
    """Return the inner data-parallel world group (asserts it was initialized)."""
    inner = _INNER_DP_WORLD
    assert inner is not None, "inner dp world group is not initialized"
    return inner


1069
1070
1071
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Build the "world" GroupCoordinator with `ranks` as its single group.

    The world group is created without a device communicator.
    """
    world_kwargs = dict(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
    )
    return GroupCoordinator(**world_kwargs)


1081
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    use_device_communicator: bool = True,
) -> GroupCoordinator:
    """Build a GroupCoordinator for one model-parallel dimension.

    `group_ranks` lists every group of that dimension; the remaining arguments
    are forwarded unchanged to the GroupCoordinator constructor.
    """
    coordinator_kwargs = dict(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return GroupCoordinator(**coordinator_kwargs)


1099
# Tensor model parallel group; assigned by `initialize_model_parallel`.
_TP: GroupCoordinator | None = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group (asserts it was initialized)."""
    tp = _TP
    assert tp is not None, "tensor model parallel group is not initialized"
    return tp


1107
# Decode context parallel group; assigned by `initialize_model_parallel`.
_DCP: GroupCoordinator | None = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel group (asserts it was initialized)."""
    dcp = _DCP
    assert dcp is not None, "decode context model parallel group is not initialized"
    return dcp


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

1118
# Pipeline model parallel group; assigned by `initialize_model_parallel`.
_PP: GroupCoordinator | None = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group (asserts it was initialized)."""
    pp = _PP
    assert pp is not None, "pipeline model parallel group is not initialized"
    return pp


1126
# Data parallel group; assigned by `initialize_model_parallel`.
_DP: GroupCoordinator | None = None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group (asserts it was initialized)."""
    dp = _DP
    assert dp is not None, "data parallel group is not initialized"
    return dp

1133

1134
# Expert parallel group; assigned by `initialize_model_parallel`, which may
# leave it None (it skips EP for configs whose model is not MoE).
_EP: GroupCoordinator | None = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel group (asserts it was created)."""
    ep = _EP
    assert ep is not None, (
        "expert parallel group is not initialized. "
        "EP group is only created for MoE models with num_experts > 0. "
        "This function should only be called for MoE models."
    )
    return ep


1146
1147
1148
1149
1150
1151
# Prefill context parallel group; assigned by `initialize_model_parallel`.
_PCP: GroupCoordinator | None = None


def get_pcp_group() -> GroupCoordinator:
    """Return the prefill context parallel group (asserts it was initialized)."""
    pcp = _PCP
    assert pcp is not None, "prefill context parallel group is not initialized"
    return pcp
1152
1153


1154
@contextmanager
def graph_capture(device: torch.device):
    """Context manager to surround the code that captures a CUDA graph.

    It creates a `GraphCaptureContext` around a new `torch.cuda.Stream` on
    `device`, enters the TP group's `graph_capture` and then the PP group's
    `graph_capture` with that context, and yields the context. Capturing on a
    dedicated stream keeps the kernels being captured distinguishable from
    kernels that may be launched in the background on the default stream.
    How the stream is made current and later restored is handled by the
    groups' own `graph_capture` implementations (not visible here).
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    # Nested form keeps the original evaluation order: the PP group is only
    # looked up after the TP group's capture context has been entered.
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1173

1174
# Logger for this module, created via vLLM's `init_logger`.
logger = init_logger(__name__)
1175

1176
# Module-level switch toggled by `set_custom_all_reduce`.
# NOTE(review): presumably read when building device communicators; the
# reader is not visible in this chunk.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1182

Zhuohan Li's avatar
Zhuohan Li committed
1183

1184
1185
1186
1187
1188
1189
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed and the module-level world groups.

    When a vLLM config is active, its executor backend is not
    "external_launcher", and it uses multiple nodes or data parallelism,
    `rank` and `world_size` are rewritten to span all DP ranks and
    `distributed_init_method` is rebuilt from the configured master address
    and port.

    Args:
        world_size: number of ranks before any DP adjustment.
        rank: rank of this process before any DP adjustment.
        distributed_init_method: torch.distributed init method URL.
        local_rank: rank within the node. If -1, it is taken from
            `envs.LOCAL_RANK` for "env://" init, otherwise set to the
            (possibly DP-adjusted) `rank`.
        backend: torch.distributed backend. If torch.distributed is not yet
            initialized and the backend is unavailable, gloo is used instead.
        timeout: timeout passed to `torch.distributed.init_process_group`.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    if (
        config is not None
        and config.parallel_config.distributed_executor_backend != "external_launcher"
        and (
            config.parallel_config.nnodes > 1
            or config.parallel_config.data_parallel_size > 1
        )
    ):
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp

        # Use appropriate IP and port based on configuration
        if parallel_config.nnodes > 1:
            ip = parallel_config.master_addr
            port = parallel_config.master_port
            distributed_init_method = get_distributed_init_method(ip, port)
        else:
            ip = parallel_config.data_parallel_master_ip
            port = parallel_config.get_next_dp_init_port()
            distributed_init_method = get_distributed_init_method(ip, port)
            logger.debug(
                "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
                world_size,
                rank,
                distributed_init_method,
            )
    if not torch.distributed.is_initialized():
        logger.info(
            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
            world_size,
            rank,
            local_rank,
            distributed_init_method,
            backend,
        )
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        if config is not None and config.parallel_config.nnodes > 1:
            _NODE_COUNT = config.parallel_config.nnodes
        else:
            _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
    if config is not None and config.parallel_config.nnodes_within_dp > 1:
        # Rebind from `config`: the earlier `parallel_config` binding happens
        # only inside the DP/multi-node branch above (skipped e.g. for the
        # "external_launcher" backend), so relying on it here could raise
        # UnboundLocalError.
        parallel_config = config.parallel_config
        if parallel_config.data_parallel_size > 1:
            world_size_inner_dp = parallel_config.world_size
            group_ranks = [
                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
                for dp_rank in range(parallel_config.data_parallel_size)
            ]
            _INNER_DP_WORLD = init_model_parallel_group(
                group_ranks,
                get_world_group().local_rank,
                backend,
                use_message_queue_broadcaster=True,
                group_name="inner_dp_world",
                use_device_communicator=False,
            )
        else:
            _INNER_DP_WORLD = _WORLD
1300
1301


Zhuohan Li's avatar
Zhuohan Li committed
1302
1303
1304
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        prefill_context_model_parallel_size: number of GPUs used for prefill
            context parallelism.
        decode_context_model_parallel_size: size of each decode context
            parallel group. DCP reuses the GPUs of the TP group, so it must
            not exceed `tensor_model_parallel_size`. `None` is treated as 1.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    # The signature allows `None`, but `reshape(-1, None)` below would fail;
    # treat it like the default of 1 (every rank is its own DCP group).
    if decode_context_model_parallel_size is None:
        decode_context_model_parallel_size = 1

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x PCP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1,
        data_parallel_size,
        pipeline_model_parallel_size,
        prefill_context_model_parallel_size,
        tensor_model_parallel_size,
    )  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    global _DCP
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    global _PCP
    assert _PCP is None, "prefill context parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(3, 4)
        .reshape(-1, prefill_context_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PCP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
    )

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    global _DP
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp"
    )

    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    # Don't create EP group for dense models.
    if config is None or config.model_config is None or config.model_config.is_moe:
        group_ranks = (
            all_ranks.transpose(1, 2)
            .reshape(
                -1,
                data_parallel_size
                * prefill_context_model_parallel_size
                * tensor_model_parallel_size,
            )
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
        _EP = init_model_parallel_group(
            group_ranks, get_world_group().local_rank, backend, group_name="ep"
        )
    # If no EP group needed, _EP remains None

    logger.info_once(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, PCP rank %s, "
        "TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _PCP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group if _EP is not None else "N/A",
    )
1457

Zhuohan Li's avatar
Zhuohan Li committed
1458

1459
1460
1461
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Create the model parallel groups, or validate the ones that exist.

    When `model_parallel_is_initialized()` is false, every argument is
    forwarded to `initialize_model_parallel`. Otherwise the existing TP, PP
    and PCP group sizes are asserted to equal the requested ones (the DCP
    size is not re-checked here).
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(
            tensor_model_parallel_size=tensor_model_parallel_size,
            pipeline_model_parallel_size=pipeline_model_parallel_size,
            prefill_context_model_parallel_size=prefill_context_model_parallel_size,
            decode_context_model_parallel_size=decode_context_model_parallel_size,
            backend=backend,
        )
        return

    # The groups already exist: make sure they match what the caller expects.
    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
        "tensor parallel group already initialized, but of unexpected size. "
        f"got: {get_tensor_model_parallel_world_size()=} vs. "
        f"wanted: {tensor_model_parallel_size=}"
    )

    pp_world_size = get_pp_group().world_size
    assert pp_world_size == pipeline_model_parallel_size, (
        "pipeline parallel group already initialized, but of unexpected size. "
        f"got: {pp_world_size=} vs. "
        f"wanted: {pipeline_model_parallel_size=}"
    )

    pcp_world_size = get_pcp_group().world_size
    assert pcp_world_size == prefill_context_model_parallel_size, (
        "prefill context parallel group already initialized, but of unexpected size: "
        f"{pcp_world_size=} vs. "
        f"{prefill_context_model_parallel_size=}"
    )
1498
1499


1500
1501
1502
1503
1504
1505
1506
1507
1508
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Let each initialized parallel group prepare communication buffers.

    Classic libraries such as NCCL are largely model agnostic, but newer ones
    like MoE all2all (DeepEP) usually size their buffers from the model shape
    for best performance. Groups are visited in the order TP, PCP, PP, DP, EP;
    groups that are still None are skipped.
    """
    for parallel_group in (_TP, _PCP, _PP, _DP, _EP):
        if parallel_group is None:
            continue
        parallel_group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1519
def model_parallel_is_initialized():
    """Report whether both the tensor- and pipeline-parallel groups exist."""
    if _TP is None:
        return False
    return _PP is not None
1522
1523


1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
# True while patch_tensor_parallel_group() has swapped a replacement group into
# _TP; the context manager asserts it is False on entry, so patches cannot nest.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator installed as the
            global TP group for the duration of the ``with`` block.

    Raises:
        AssertionError: if a patch is already active (patches do not nest).
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Capture the group being replaced *before* marking the patch as active.
    # If this lookup raises, the flag must not be left set; otherwise every
    # later call would fail the assertion above.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state, even if the body raised
        _TP = old_tp_group
        _TP_STATE_PATCHED = False


1552
def get_tensor_model_parallel_world_size() -> int:
    """Number of processes in this process's tensor-parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1555
1556


1557
def get_tensor_model_parallel_rank() -> int:
    """This process's rank within its tensor-parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1560
1561


1562
def get_decode_context_model_parallel_world_size() -> int:
    """Number of processes in this process's decode-context-parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


1567
def get_decode_context_model_parallel_rank() -> int:
    """This process's rank within its decode-context-parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1572
def get_node_count() -> int:
    """Total node count recorded for the distributed environment.

    Raises:
        AssertionError: if the node count has not been recorded yet (i.e. it
            is still ``None``).
    """
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1578
def destroy_model_parallel():
    """Destroy every model-parallel group and reset its global to ``None``.

    Groups are torn down in the order TP, DCP, PCP, PP, DP, EP. A group that
    was never created (falsy) is simply reset to ``None``.
    """
    global _TP, _DCP, _PCP, _PP, _DP, _EP

    def _teardown(group) -> None:
        # Only call destroy() on groups that actually exist.
        if group:
            group.destroy()

    _teardown(_TP)
    _TP = None

    _teardown(_DCP)
    _DCP = None

    _teardown(_PCP)
    _PCP = None

    _teardown(_PP)
    _PP = None

    _teardown(_DP)
    _DP = None

    _teardown(_EP)
    _EP = None

1611
1612

def destroy_distributed_environment():
    """Destroy the world group, forget the node count, and shut down torch.distributed.

    The torch default process group is destroyed only if torch.distributed
    reports it as initialized.
    """
    global _WORLD, _NODE_COUNT

    world_group = _WORLD
    if world_group:
        world_group.destroy()
    _WORLD = None
    _NODE_COUNT = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
1620
1621


1622
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM's distributed state and release cached memory.

    In order: disable the environment-variable cache, unfreeze the garbage
    collector, destroy the model-parallel groups and then the distributed
    environment, optionally shut down Ray, run a full garbage collection, and
    finally ask the current platform (and, on non-CPU platforms, torch's host
    allocator) to empty its caches.

    Args:
        shutdown_ray: when True, also call ``ray.shutdown()``.
    """
    # Reset environment variable cache
    envs.disable_envs_cache()
    # Ensure all objects are not frozen before cleanup, so that the
    # gc.collect() below can reach them.
    gc.unfreeze()

    # Model-parallel groups first, then the world group / process group.
    destroy_model_parallel()
    destroy_distributed_environment()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()

    from vllm.platforms import current_platform

    if (platform_empty_cache := current_platform.empty_cache) is not None:
        platform_empty_cache()

    try:
        # _host_emptyCache is missing on older torch builds, which surfaces
        # as an AttributeError handled below.
        skip_host_cache = current_platform.is_cpu()
        if not skip_host_cache:
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
1645
1646


1647
def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that reports whether each rank is in the
    same node as the source rank. It tests if processes are attached to the
    same memory system (shared access to shared memory).

    The source rank creates a shared-memory segment containing a magic message
    and broadcasts the segment's name. Every other rank tries to attach to that
    segment and checks for the message. The per-rank results are then combined
    across the group.

    Args:
        pg: the group to test. A torch ``ProcessGroup`` must not use the NCCL
            backend.
        source_rank: rank *within the group* that creates the segment.

    Returns:
        A list of length ``world_size`` whose entry ``i`` is True iff rank
        ``i`` could read the source rank's segment. Every rank receives the
        same list. If the source rank cannot create the segment, all entries
        are False.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor(
        [0] * world_size, dtype=torch.int32, device="cpu"
    )

    magic_message = b"magic_message"
    shm = None

    def _broadcast_segment_name(name: str | None) -> str | None:
        # Broadcast `name` from the source rank. Non-source ranks pass None
        # and get the source rank's value back.
        if isinstance(pg, ProcessGroup):
            payload = [name]
            torch.distributed.broadcast_object_list(
                payload, src=ranks[source_rank], group=pg
            )
            return payload[0]
        return pg.broadcast_obj(name, src=source_rank)

    try:
        if rank == source_rank:
            name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                name = shm.name
            # Always take part in the broadcast, sending None if the segment
            # could not be created. The other ranks are blocked in the
            # matching broadcast and would otherwise wait forever.
            _broadcast_segment_name(name)
            if name is not None:
                is_in_the_same_node[rank] = 1
        else:
            name = _broadcast_segment_name(None)
            if name is not None:
                # try to open the shared memory segment
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is not
                    # created by the process. The following patch is a workaround.
                    with patch(
                        "multiprocessing.resource_tracker.register",
                        lambda *args, **kwargs: None,
                    ):
                        shm = shared_memory.SharedMemory(name=name)
                    if shm.buf[: len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Make sure every rank has finished reading before the segment is removed.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    # Each rank set only its own slot, so summing across ranks yields the
    # full per-rank result on every rank.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
1737
1738


1739
1740
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Resolution order: the world group's ``is_first_rank`` if the world group
    exists; otherwise True when torch.distributed is not initialized (a single
    process); otherwise whether torch's global rank is 0. Any exception along
    the way is swallowed and treated as "first rank".

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
    """
    try:
        if _WORLD is not None:
            return _WORLD.is_first_rank
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0
        return True
    except Exception:
        # Fail open: pretend to be the first rank rather than crash.
        return True


1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
def is_local_first_rank() -> bool:
    """
    Check if the current process is the first local rank (rank 0 on its node).

    Resolution order: the world group's ``local_rank`` if the world group
    exists; True if torch.distributed is not initialized; otherwise
    ``envs.LOCAL_RANK`` (set by env:// launchers such as torchrun), falling
    back to torch's global rank if that value cannot be converted to an int.
    Any other exception is treated as "first rank".
    """
    try:
        world = _WORLD
        if world is not None:
            return world.local_rank == 0
        if not torch.distributed.is_initialized():
            return True
        try:
            local_rank = int(envs.LOCAL_RANK)  # type: ignore[arg-type]
        except Exception:
            return torch.distributed.get_rank() == 0
        return local_rank == 0
    except Exception:
        return True


1793
def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are grouped greedily: the lowest rank without a node becomes the
    leader of a new node, and every rank that `in_the_same_node_as` reports as
    co-located with it joins that node. Because `in_the_same_node_as` is a
    collective, all ranks must call this function together.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # node id per rank (1-based); 0 means "not assigned yet". The flags
    # returned by in_the_same_node_as are combined across the group, so
    # every rank builds the same map and issues the same collective calls.
    node_of = [0] * world_size
    num_nodes = 0

    for leader in range(world_size):
        if node_of[leader]:
            continue  # already placed on a node

        num_nodes += 1
        node_of[leader] = num_nodes

        colocated_flags = in_the_same_node_as(pg, leader)
        for peer, colocated in enumerate(colocated_flags):
            if colocated and not node_of[peer]:
                node_of[peer] = num_nodes

    return num_nodes