# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""vLLM distributed state.

It takes over control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.
- run any code that uses the distributed state.
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import TYPE_CHECKING, Any, Protocol
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
44

45
import vllm.envs as envs
46
from vllm.distributed.device_communicators.base_device_communicator import (
47
48
    DeviceCommunicatorBase,
)
49
from vllm.distributed.utils import StatelessProcessGroup
50
from vllm.logger import init_logger
51
from vllm.utils.import_utils import resolve_obj_by_qualname
52
from vllm.utils.network_utils import get_distributed_init_method
53
from vllm.utils.system_utils import suppress_stdout
54
55
56
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)
57

58
59
60
if TYPE_CHECKING:
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

61

62
63
64
@dataclass
class GraphCaptureContext:
    """State carried through a graph-capture scope.

    At present it only holds the CUDA stream on which capture happens
    (see ``GroupCoordinator.graph_capture``).
    """

    # Stream used while capturing.
    stream: torch.cuda.Stream
65

66

67
# Stand-in that is shipped instead of a tensor: its device *type*
# (e.g. "cuda"), its dtype and its size.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
68

69

70
71
72
73
74
75
76
77
class Handle(Protocol):
    """Structural type for the minimal async work handle used by P2P send/recv.

    Any object that provides both methods below satisfies this protocol.
    """

    def is_completed(self) -> bool:
        """Report whether the underlying operation has completed."""
        ...

    def wait(self) -> None:
        """Wait for the underlying operation."""
        ...


78
def _split_tensor_dict(
79
    tensor_dict: dict[str, torch.Tensor | Any],
80
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
81
82
83
84
85
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
86
87
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
88
89
90
91
92
93
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
94
            device = value.device.type
95
            metadata_list.append(
96
97
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
98
99
            tensor_list.append(value)
        else:
100
            metadata_list.append((key, value))
101
102
103
    return metadata_list, tensor_list


104
# Next suffix to hand out for each group name prefix.
_group_name_counter: dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return ``name`` suffixed with a per-name running counter.

    Example:
        _get_unique_name("tp") -> "tp:0"
        _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.get(name, 0)
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


120
# Registry of live coordinators keyed by unique name. Entries are weak
# references, so registering does not keep a coordinator alive.
_groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Record a weak reference to ``group`` under ``group.unique_name``."""
    group_ref = weakref.ref(group)
    _groups[group.unique_name] = group_ref
125
126


127
128
129
130
131
132
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Custom-op body: out-of-place all-reduce on the group named ``group_name``.

    The coordinator is looked up by name because the op cannot be handed the
    ``GroupCoordinator`` object itself.

    Raises:
        AssertionError: if no group was ever registered under that name.
        ValueError: if the registered group has been garbage-collected.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_reduce_out_place(tensor)
133
134


135
136
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation: an uninitialized tensor shaped and typed like ``tensor``.

    ``group_name`` is unused; it only mirrors the real op's signature.
    """
    return torch.empty_like(tensor)
137

138

139
140
141
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Custom-op body: reduce-scatter along ``dim`` on the group ``group_name``.

    ``world_size`` is unused here; the fake implementation needs it to
    compute the output shape.

    Raises:
        AssertionError: if no group was ever registered under that name.
        ValueError: if the registered group has been garbage-collected.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._reduce_scatter_out_place(tensor, dim)
147
148


149
150
151
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation: ``tensor``'s shape with ``dim`` floor-divided by ``world_size``."""
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


157
158
159
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Custom-op body: all-gather along ``dim`` on the group ``group_name``.

    ``world_size`` is unused here; the fake implementation needs it to
    compute the output shape.

    Raises:
        AssertionError: if no group was ever registered under that name.
        ValueError: if the registered group has been garbage-collected.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_gather_out_place(tensor, dim)
165
166


167
168
169
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation: ``tensor``'s shape with ``dim`` multiplied by ``world_size``."""
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Fake implementation for ``patched_fused_scaled_matmul_reduce_scatter``.

    Used as the op's ``fake_impl``. It composes ``torch._scaled_mm`` with a
    functional ``reduce_scatter_tensor`` instead of calling the fused
    symmetric-memory kernel.

    NOTE: ``scatter_dim_after_maybe_reshape`` is not used here; it is kept so
    the signature matches the real op.

    Raises:
        ValueError: if ``A_scale`` has more than one element but its leading
            dims differ from ``A``'s, or if it has no elements at all.
    """
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    # Kept byte-identical to upstream so it can be diffed/dropped easily.
    if A_scale.numel() > 1:
        # Row-wise scale: flatten to 2-D so it lines up with the flattened A
        # passed to `_scaled_mm` below.
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        # Only an empty A_scale reaches here; a single-element scale is
        # accepted as-is.
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    # Restore the leading output dims; the last dim is B's column count.
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


259
260
261
262
263
# Register the group collectives (and the patched fused scaled-matmul +
# reduce-scatter) as vLLM custom ops, each paired with its fake implementation.
for _op_name, _op_func, _fake_impl in (
    ("all_reduce", all_reduce, all_reduce_fake),
    ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
    ("all_gather", all_gather, all_gather_fake),
    # TODO: Remove this once the pytorch fix
    # (https://github.com/pytorch/pytorch/pull/165086) gets released,
    # in either 2.9.1 or 2.10
    (
        "patched_fused_scaled_matmul_reduce_scatter",
        patched_fused_scaled_matmul_reduce_scatter,
        patched_fused_scaled_matmul_reduce_scatter_fake,
    ),
):
    direct_register_custom_op(
        op_name=_op_name,
        op_func=_op_func,
        fake_impl=_fake_impl,
    )
# Do not leave the loop variables behind as module attributes.
del _op_name, _op_func, _fake_impl
285

286

287
288
289
290
291
292
class GroupCoordinator:
    """Coordinator for one group of processes, wrapping PyTorch ProcessGroups.

    A PyTorch ``ProcessGroup`` is tied to a single communication backend
    (e.g. NCCL, Gloo, MPI). ``GroupCoordinator`` owns all communication
    among the processes of its group and manages both the CPU-side and the
    device-side communication.
    """

    # Attributes set by __init__:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # `local_rank` vs `rank_in_group`, for a group of 4 spread over two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
    device_communicator: DeviceCommunicatorBase | None
    mq_broadcaster: Any | None  # shared memory broadcaster
315
316
317

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create the device and CPU process groups for this coordinator.

        Args:
            group_ranks: Every group to create, each given as a list of global
                ranks. A device group (``torch_distributed_backend``) and a
                Gloo CPU group are created for each entry; the entry that
                contains this process's global rank becomes this
                coordinator's group.
            local_rank: Index used to choose this process's device.
            torch_distributed_backend: Backend for the device groups.
            use_device_communicator: Whether to build the platform's device
                communicator (only done when the group has more than one rank).
            use_message_queue_broadcaster: Whether to build a shared-memory
                ``MessageQueue`` broadcaster (only for groups with > 1 rank).
            group_name: Prefix of the unique name registered for custom-op
                lookup; defaults to ``"anonymous"``.
        """
        self.unique_name = _get_unique_name(group_name or "anonymous")
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        my_device_group = None
        my_cpu_group = None
        # `new_group` is called for every entry, not only the one containing
        # this rank: torch.distributed requires all processes to enter group
        # creation for every group.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # A Gloo group lets processes coordinate directly through the CPU.
            with suppress_stdout():
                cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank not in ranks:
                continue
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            my_device_group = device_group
            my_cpu_group = cpu_group

        assert my_cpu_group is not None
        assert my_device_group is not None
        self.cpu_group = my_cpu_group
        self.device_group = my_device_group

        from vllm.platforms import current_platform

        # Pick the device type for this platform; anything unrecognized
        # falls back to a plain (index-less) CPU device.
        if current_platform.is_cuda_alike():
            device_type = "cuda"
        elif current_platform.is_xpu():
            device_type = "xpu"
        elif current_platform.is_out_of_tree():
            device_type = current_platform.device_name
        else:
            device_type = None
        self.device = (
            torch.device("cpu")
            if device_type is None
            else torch.device(f"{device_type}:{local_rank}")
        )

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # `current_platform` is already bound by the import above.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )
        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
397

398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
    def create_mq_broadcaster(
        self, writer_rank=0, external_writer_handle=None, blocking=True
    ):
        """Build a shared-memory ``MessageQueue`` over this group's CPU group.

        The keyword arguments are forwarded unchanged to
        ``MessageQueue.create_from_process_group``.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        # NOTE(review): `1 << 22` and `6` match the values __init__ uses;
        # presumably the chunk byte size and chunk count — confirm against
        # MessageQueue.create_from_process_group.
        forwarded = dict(
            writer_rank=writer_rank,
            external_writer_handle=external_writer_handle,
            blocking=blocking,
        )
        return MessageQueue.create_from_process_group(
            self.cpu_group, 1 << 22, 6, **forwarded
        )

    def create_single_reader_mq_broadcasters(
        self, reader_rank_in_group=0, blocking=False
    ):
        """Build a single-reader ``MessageQueue`` over this group's CPU group.

        ``reader_rank_in_group`` is a position within this group; it is
        translated to a global rank via ``self.ranks`` before being passed on.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        reader_global_rank = self.ranks[reader_rank_in_group]
        return MessageQueue.create_from_process_group_single_reader(
            self.cpu_group,
            1 << 22,
            6,
            reader_rank=reader_global_rank,
            blocking=blocking,
        )

425
426
427
428
429
430
431
432
433
434
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

435
436
437
438
439
440
441
442
443
444
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Context manager that sets up this group for CUDA graph capture.

        If no context is given, a new one is created around a fresh
        ``torch.cuda.Stream``. The capture stream is first made to wait on the
        current stream. The body then runs with the capture stream current
        and, when the CUDA communicator has a ``ca_comm``, inside
        ``ca_comm.capture()``.

        Yields:
            The ``GraphCaptureContext`` in use.
        """
        if graph_capture_context is None:
            graph_capture_context = GraphCaptureContext(torch.cuda.Stream())
        stream = graph_capture_context.stream

        # Only CUDA uses this function, so it is not abstracted into the base
        # device-communicator class.
        ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            ca_comm = communicator.ca_comm
            if ca_comm is not None:
                ca_context = ca_comm.capture()  # type: ignore

        # Let work already queued on the current stream (e.g. initialization)
        # finish before capture starts on the other stream.
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), ca_context:
            yield graph_capture_context
488
489
490

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
491
492
493
494
495
496
497
498
499
500
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
501
502
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
503
504
505
506
507
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

508
        if self.use_custom_op_call:
509
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
510
511
        else:
            return self._all_reduce_out_place(input_)
512

513
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
514
515
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
516
        return self.device_communicator.all_reduce(input_)
517
518
519
520
521
522
523

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
524
525
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
526

527
        if self.use_custom_op_call:
528
529
530
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
531
532
533
        else:
            return self._all_gather_out_place(input_, dim)

534
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
535
536
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
537
        return self.device_communicator.all_gather(input_, dim)
538

539
540
    def all_gatherv(
        self,
541
        input_: torch.Tensor | list[torch.Tensor],
542
        dim: int = 0,
543
        sizes: list[int] | None = None,
544
    ):
545
546
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
547
548
        return self.device_communicator.all_gatherv(input_, dim, sizes)

549
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
550
551
552
553
554
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
555
556
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
557

558
        if self.use_custom_op_call:
559
560
561
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
562
563
564
        else:
            return self._reduce_scatter_out_place(input_, dim)

565
    def reduce_scatterv(
566
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
567
    ) -> torch.Tensor:
568
569
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
570
571
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

572
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
573
574
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
575
576
        return self.device_communicator.reduce_scatter(input_, dim)

577
578
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
579
    ) -> torch.Tensor | None:
580
581
582
583
584
585
586
587
588
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
589
590
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
591
        return self.device_communicator.gather(input_, dst, dim)
592
593
594
595
596
597
598
599
600
601
602

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
603
604
605
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
606
607
        return input_

608
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
609
610
611
612
613
614
615
616
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
617
618
619
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
620
        if self.rank_in_group == src:
621
622
623
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
624
625
626
            return obj
        else:
            recv = [None]
627
628
629
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
630
631
            return recv[0]

632
    def broadcast_object_list(
633
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
634
    ):
635
636
637
638
639
640
641
642
643
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
644
645
646
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
647
648
        return obj_list

649
650
651
652
653
654
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

655
        assert dst != self.rank_in_group, (
656
            "Invalid destination rank. Destination rank is the same "
657
658
            "as the current rank."
        )
659
660
661
662

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

663
664
665
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
666
667
668

        # Send object size

669
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
670
671

        # Send object
672
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
673
674
675
676
677
678
679
680
681

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

682
        assert src != self.rank_in_group, (
683
684
685
686
687
688
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
689
690
691
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
692
693
694
695
696

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
697
698
            device="cpu",
        )
699

700
701
702
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
703
704

        assert rank_object == rank_size, (
705
706
            "Received object sender rank does not match the size sender rank."
        )
707
708
709
710
711

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

712
713
    def broadcast_tensor_dict(
        self,
714
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
715
        src: int = 0,
716
717
718
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
719
720
721
722
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
723
        if not torch.distributed.is_initialized() or self.world_size == 1:
724
725
726
727
728
729
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

730
731
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
732
            metadata_list: list[tuple[Any, Any]] = []
733
734
735
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
736
737
738
739
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
740
            self.broadcast_object(metadata_list, src=src)
741
742
743
744
745
746
747
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
748
749
750
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
751
752
                else:
                    # use group for GPU tensors
753
754
755
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
756
757
758
759
760
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
761
            metadata_list = self.broadcast_object(None, src=src)
762
763
            tensor_dict = {}
            async_handles = []
764
            for key, value in metadata_list:
765
                if isinstance(value, TensorMetadata):
766
767
768
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
769
770
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
771
                        tensor_dict[key] = tensor
772
773
774
775
776
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
777
                            src=self.ranks[src],
778
                            group=metadata_group,
779
780
                            async_op=True,
                        )
781
782
                    else:
                        # use group for GPU tensors
783
                        handle = torch.distributed.broadcast(
784
785
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
786
                    async_handles.append(handle)
787
                    tensor_dict[key] = tensor
788
                else:
789
                    tensor_dict[key] = value
790
791
792
793
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

794
795
796
797
798
799
800
801
802
803
804
805
806
807
    def _should_use_all_gather(
        self,
        key: str,
        numel: int,
        all_gather_group: "GroupCoordinator | None",
        all_gather_tensors: dict[str, bool] | None,
    ) -> bool:
        if all_gather_group is None:
            return False
        use_all_gather = numel % all_gather_group.world_size == 0
        if all_gather_tensors is not None:
            use_all_gather = all_gather_tensors.get(key, use_all_gather)
        return use_all_gather

808
809
    def send_tensor_dict(
        self,
810
811
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
812
        all_gather_group: "GroupCoordinator | None" = None,
813
814
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
815
816
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
832
833
834
835
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
        handles = self.isend_tensor_dict(
            tensor_dict,
            dst=dst,
            all_gather_group=all_gather_group,
            all_gather_tensors=all_gather_tensors,
        )
        for handle in handles:
            handle.wait()
        return None

    def isend_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: "GroupCoordinator | None" = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> list[Handle]:
        """Start sending `tensor_dict` without waiting for the tensors to arrive.

        NOTE: `dst` is the local rank of the destination rank; when None, the
        next rank in the group (wrapping around) is used.

        The metadata list is sent synchronously via `send_object`. Then one
        `torch.distributed.isend` is posted per non-empty tensor: CPU tensors
        go over the CPU group and all others over the device group. Tensors
        selected by `_should_use_all_gather` only send this rank's slice; the
        receiver rebuilds them with an all-gather.

        Returns:
            The handles of the posted sends, which the caller must wait on.
            Empty when torch.distributed is not initialized, the group has a
            single rank, or the (synchronous) custom CPU communicator is used.

        Raises:
            ValueError: if `use_cpu_custom_send_recv` is set but there is no
                device communicator.
        """
        # Mirror irecv_tensor_dict / send_tensor_dict: without an initialized
        # process group (or with a single rank) there is nothing to send.
        if not torch.distributed.is_initialized() or self.world_size <= 1:
            return []

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            # custom device communicator path is synchronous
            self.device_communicator.send_tensor_dict(  # type: ignore
                tensor_dict, dst
            )
            return []

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        # The metadata goes first so the receiver can allocate its buffers
        # (irecv_tensor_dict builds them from TensorMetadata entries).
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.send_object(metadata_list, dst=dst)

        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
        assert len(tensor_keys) == len(tensor_list)

        handles: list[Handle] = []
        for key, tensor in zip(tensor_keys, tensor_list):
            if tensor.numel() == 0:
                # Empty tensors are never sent; the receiver allocates them
                # locally from the metadata.
                continue

            if self._should_use_all_gather(
                key, tensor.numel(), all_gather_group, all_gather_tensors
            ):
                # Send only this rank's slice; the receiver reassembles the
                # full tensor with an all-gather over `all_gather_group`.
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            comm_group = metadata_group if tensor.is_cpu else group
            handle = torch.distributed.isend(
                tensor, dst=self.ranks[dst], group=comm_group
            )
            if tensor.is_cuda:
                # Keep the caching allocator from reusing this memory (possibly
                # a temporary slice) before work queued on the current stream
                # has finished.
                tensor.record_stream(torch.cuda.current_stream(tensor.device))
            handles.append(handle)

        return handles
902
903
904

    def recv_tensor_dict(
        self,
905
        src: int | None = None,
906
        all_gather_group: "GroupCoordinator | None" = None,
907
908
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
909
910
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
926
927
928
929
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
        tensor_dict, handles, postprocess = self.irecv_tensor_dict(
            src=src,
            all_gather_group=all_gather_group,
            all_gather_tensors=all_gather_tensors,
        )
        for handle in handles:
            handle.wait()
        for fn in postprocess:
            fn()
        return tensor_dict

    def irecv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: "GroupCoordinator | None" = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> tuple[
        dict[str, torch.Tensor | Any] | None,
        list[Handle],
        list[Callable[[], None]],
    ]:
        """Post non-blocking receives for a tensor dictionary.

        NOTE: `src` is the local rank of the source rank; when None, the
        previous rank in the group (wrapping around) is used.

        The metadata list is received synchronously via `recv_object`. Then
        one `torch.distributed.irecv` is posted per non-empty tensor, in
        metadata order: CPU tensors go over the CPU group and all others over
        the device group. Tensors selected by `_should_use_all_gather` receive
        only this rank's slice and are completed later by an all-gather.

        Returns:
            `(tensor_dict, handles, postprocess)`. Callers must wait on every
            handle and then call every postprocess function before using
            `tensor_dict`; until then, all-gathered entries hold only the local
            slice. Returns `(None, [], [])` when torch.distributed is not
            initialized or the group has a single rank. The custom CPU
            communicator path receives synchronously and returns no handles.

        Raises:
            ValueError: if `use_cpu_custom_send_recv` is set but there is no
                device communicator.
        """
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None, [], []
        if self.use_cpu_custom_send_recv:
            communicator = self.device_communicator
            if communicator is None:
                raise ValueError("No device communicator found")
            # The custom communicator completes the receive synchronously.
            received = communicator.recv_tensor_dict(src)  # type: ignore
            return received, [], []

        if all_gather_group is None:
            slice_count, slice_index = 1, 0
        else:
            slice_count = all_gather_group.world_size
            slice_index = all_gather_group.rank_in_group

        device_group = self.device_group
        cpu_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        entries = self.recv_object(src=src)
        received: dict[str, Any] = {}
        pending: list[Handle] = []
        finalizers: list[Callable[[], None]] = []

        for key, value in entries:
            if not isinstance(value, TensorMetadata):
                # Non-tensor values travel inside the metadata itself.
                received[key] = value
                continue

            buffer = torch.empty(value.size, dtype=value.dtype, device=value.device)
            if buffer.numel() == 0:
                # The sender skips empty tensors, so there is nothing to receive.
                received[key] = buffer
                continue

            gathered = self._should_use_all_gather(
                key, buffer.numel(), all_gather_group, all_gather_tensors
            )
            # With all-gather, receive directly into this rank's row of the
            # full-size buffer.
            target = (
                buffer.reshape(slice_count, -1)[slice_index] if gathered else buffer
            )
            pending.append(
                torch.distributed.irecv(
                    target,
                    src=self.ranks[src],
                    group=cpu_group if target.is_cpu else device_group,
                )
            )
            received[key] = target

            if gathered:
                # Default arguments freeze the per-iteration values.
                def _finish(
                    key: str = key,
                    piece: torch.Tensor = target,
                    shape: tuple[int, ...] = tuple(buffer.shape),
                    ag_group=all_gather_group,
                ) -> None:
                    assert ag_group is not None
                    received[key] = ag_group.all_gather(piece, dim=0).reshape(shape)

                finalizers.append(_finish)

        return received, pending, finalizers
1025

1026
1027
1028
1029
1030
1031
1032
1033
1034
    def barrier(self):
        """Block until every rank in this group has reached the barrier.

        The CPU group is used on purpose. NCCL's `barrier` on `device_group`
        is internally a broadcast over secretly created GPU tensors, which
        can easily disturb the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

1035
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
1036
        """Sends a tensor to the destination rank in a blocking way"""
1037
        """NOTE: `dst` is the local rank of the destination rank."""
1038
1039
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
1040
        self.device_communicator.send(tensor, dst)
1041

1042
    def recv(
1043
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
1044
    ) -> torch.Tensor:
1045
1046
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
1047
1048
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
1049
        return self.device_communicator.recv(size, dtype, src)
1050

1051
    def destroy(self):
        """Tear down the process groups and communicator owned by this group.

        Each process group attribute is destroyed and then deleted only if it
        is still present, so a repeated call skips them. The device
        communicator (if any) is destroyed, and the message-queue broadcaster
        reference is dropped.
        """
        for group_attr in ("device_group", "cpu_group"):
            if not hasattr(self, group_attr):
                continue
            torch.distributed.destroy_process_group(getattr(self, group_attr))
            delattr(self, group_attr)
        if self.device_communicator is not None:
            self.device_communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
1062

1063
1064
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let the device communicator, if there is one, prepare buffers for `model`."""
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
1066

1067
    def dispatch_router_logits(
1068
1069
1070
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
1071
        is_sequence_parallel: bool = False,
1072
1073
1074
1075
1076
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
1077
        if self.device_communicator is not None:
1078
            return self.device_communicator.dispatch_router_logits(
1079
1080
1081
1082
                hidden_states,
                router_logits,
                is_sequence_parallel,
                extra_tensors,
1083
            )
1084
1085
        else:
            return hidden_states, router_logits
1086

1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(
                hidden_states,
                topk_weights,
                topk_ids,
                is_sequence_parallel,
                extra_tensors,
            )
        else:
            return hidden_states, topk_weights, topk_ids

1109
1110
1111
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine hidden states through the device communicator.

        Returns `hidden_states` unchanged when there is no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
1116

1117

1118
# Coordinator spanning every rank; set by init_distributed_environment,
# _init_elastic_ep_world or _replace_active_groups.
_WORLD: GroupCoordinator | None = None
# Per-DP-replica world; only set by init_distributed_environment when
# nnodes_within_dp > 1.
_INNER_DP_WORLD: GroupCoordinator | None = None
# Number of nodes in the distributed environment, once detected/configured.
_NODE_COUNT: int | None = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting that it exists."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


1128
1129
1130
1131
1132
def get_inner_dp_world_group() -> GroupCoordinator:
    """Return the per-DP-replica world group coordinator, asserting it exists."""
    inner = _INNER_DP_WORLD
    assert inner is not None, "inner dp world group is not initialized"
    return inner


1133
1134
1135
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Build the coordinator named "world": a single group holding all `ranks`.

    The world group is created without a device communicator.
    """
    return GroupCoordinator(
        group_name="world",
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
    )


1145
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    use_device_communicator: bool = True,
) -> GroupCoordinator:
    """Create a GroupCoordinator for the given partition of ranks.

    Args:
        group_ranks: one list of global ranks per sub-group.
        local_rank: local rank of this process.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: forwarded to GroupCoordinator.
        group_name: optional name for the coordinator.
        use_device_communicator: forwarded to GroupCoordinator.
    """
    return GroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        use_device_communicator=use_device_communicator,
    )


1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    group_ports: list[list[int]],
    host: str,
    backend: str,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Create a StatelessGroupCoordinator for `group_ranks`.

    The local rank, global rank and global world size are taken from the
    current world group.
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    world = get_world_group()
    return StatelessGroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        group_ports=group_ports,
        host=host,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        local_rank=world.local_rank,
        global_rank=world.rank,
        global_world_size=world.world_size,
    )


def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Tear down the active DP/EP/WORLD/EPLB groups and install new ones.

    This is collective: every rank in the old groups must call it together,
    because each existing old group is destroyed (in DP, EP, WORLD, EPLB
    order). Pass `None` everywhere to tear down without replacement.
    `_NODE_COUNT` is overwritten with `node_count` too.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
    retiring = [g for g in (_DP, _EP, _WORLD, _EPLB) if g is not None]
    for old_group in retiring:
        old_group.destroy()
    _WORLD, _DP, _EP, _EPLB = world, dp, ep, eplb
    _NODE_COUNT = node_count


1212
# Tensor-parallel group coordinator; None until initialized.
_TP: GroupCoordinator | None = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group coordinator, asserting it exists."""
    tp = _TP
    assert tp is not None, "tensor model parallel group is not initialized"
    return tp


1220
# Decode context-parallel group coordinator; None until initialized.
_DCP: GroupCoordinator | None = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context-parallel group coordinator, asserting it exists."""
    dcp = _DCP
    assert dcp is not None, "decode context model parallel group is not initialized"
    return dcp


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

1231
# Pipeline-parallel group coordinator; None until initialized.
_PP: GroupCoordinator | None = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group coordinator, asserting it exists."""
    pp = _PP
    assert pp is not None, "pipeline model parallel group is not initialized"
    return pp


1239
# Data-parallel group coordinator; None until initialized.
_DP: GroupCoordinator | None = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group coordinator, asserting it exists."""
    dp = _DP
    assert dp is not None, "data parallel group is not initialized"
    return dp

1246

1247
# Expert-parallel group coordinator; None until initialized.
_EP: GroupCoordinator | None = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group coordinator, asserting it exists."""
    ep = _EP
    assert ep is not None, (
        "expert parallel group is not initialized. "
        "EP group is only created for MoE models with num_experts > 0. "
        "This function should only be called for MoE models."
    )
    return ep


1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
# EPLB group coordinator; None until initialized.
_EPLB: GroupCoordinator | None = None


def get_eplb_group() -> GroupCoordinator:
    """Return the EPLB group coordinator, asserting it exists."""
    eplb = _EPLB
    assert eplb is not None, (
        "EPLB group is not initialized. "
        "EPLB group is only created for MoE models when EPLB is enabled. "
        "Ensure parallel_config.enable_eplb is True."
    )
    return eplb


1271
1272
1273
1274
1275
1276
# Prefill context-parallel group coordinator; None until initialized.
_PCP: GroupCoordinator | None = None


def get_pcp_group() -> GroupCoordinator:
    """Return the prefill context-parallel group coordinator, asserting it exists."""
    pcp = _PCP
    assert pcp is not None, "prefill context parallel group is not initialized"
    return pcp
1277
1278


1279
@contextmanager
def graph_capture(device: torch.device):
    """Context manager to wrap the code that captures a CUDA graph.

    A fresh CUDA stream is created on `device` and wrapped in a
    `GraphCaptureContext`. The context is yielded while the tensor-parallel
    and then the pipeline-parallel groups' own `graph_capture` contexts are
    active. The purpose is to let those groups run some operations after the
    graph is captured and before it is replayed.

    Capturing on a dedicated stream, rather than the default one, explicitly
    separates the kernels being captured from other kernels that may be
    launched in the background on the default stream.
    """
    context = GraphCaptureContext(torch.cuda.Stream(device=device))
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1298

1299
# Module-level logger, named after this module.
logger = init_logger(__name__)
1300

1301
# Module-wide switch for custom all-reduce; toggled by set_custom_all_reduce().
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-wide `_ENABLE_CUSTOM_ALL_REDUCE` flag to `enable`."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1307

Zhuohan Li's avatar
Zhuohan Li committed
1308

1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    """Create `_WORLD` as a StatelessGroupCoordinator for elastic EP.

    The global rank spans all data-parallel replicas:
    ``data_parallel_rank * world_size + rank``. The global world size is
    ``parallel_config.world_size_across_dp``. `_NODE_COUNT` is derived from
    the new coordinator's TCP-store group.

    Raises:
        AssertionError: if `_WORLD` is already set, or if TP/PP spans more
            than one node (``nnodes_within_dp != 1``).
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    global _WORLD, _NODE_COUNT
    assert _WORLD is None, "world group already initialized"
    parallel_config = config.parallel_config
    # Validate the topology before reserving a stateless port or constructing
    # the coordinator, so a misconfiguration fails fast without doing that work.
    assert parallel_config.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    global_rank = parallel_config.data_parallel_rank * world_size + rank
    global_world_size = parallel_config.world_size_across_dp
    all_ranks = list(range(global_world_size))
    if global_rank in all_ranks:
        group_ranks = [all_ranks]
    else:
        # NOTE(review): a rank outside [0, global_world_size) gets one
        # single-rank group per rank; presumably for elastic scaling —
        # confirm against StatelessGroupCoordinator.
        group_ranks = [all_ranks[i : i + 1] for i in range(global_world_size)]
    group_ports = [parallel_config.get_next_stateless_world_group_port()]
    world = StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=parallel_config.data_parallel_master_ip,
        group_ports=group_ports,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world


1342
1343
1344
1345
1346
1347
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed and the vLLM world group(s).

    If a vLLM config is active with multiple nodes or data parallelism, the
    inputs are adjusted first. This does not happen with the
    "external_launcher" backend or with elastic EP. In that case
    `rank`/`world_size` are rewritten to span all DP replicas, and
    `distributed_init_method` points at the master address (multi-node) or
    the DP master address.

    Next, the default torch.distributed process group is created if needed,
    falling back to gloo when `backend` is unavailable. Finally `_WORLD` and
    `_NODE_COUNT` are set, plus `_INNER_DP_WORLD` when
    `nnodes_within_dp > 1`. With elastic EP, `_WORLD` is built by
    `_init_elastic_ep_world` instead.

    Args:
        world_size: number of ranks (per DP replica when DP adjustment runs).
        rank: rank of this process (within its DP replica when adjusted).
        distributed_init_method: torch.distributed init URL.
        local_rank: device-local rank; -1 derives it from `envs.LOCAL_RANK`
            for "env://" and from `rank` otherwise.
        backend: torch.distributed backend for the world group.
        timeout: optional timeout passed to process-group creation.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
    if (
        config is not None
        and config.parallel_config.distributed_executor_backend != "external_launcher"
        and (
            config.parallel_config.nnodes > 1
            or config.parallel_config.data_parallel_size > 1
        )
        and not enable_elastic_ep
    ):
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp

        # Use appropriate IP and port based on configuration
        if parallel_config.nnodes > 1:
            ip = parallel_config.master_addr
            port = parallel_config.master_port
            distributed_init_method = get_distributed_init_method(ip, port)
        else:
            ip = parallel_config.data_parallel_master_ip
            port = parallel_config.get_next_dp_init_port()
            distributed_init_method = get_distributed_init_method(ip, port)
            logger.debug(
                "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
                world_size,
                rank,
                distributed_init_method,
            )
    if not torch.distributed.is_initialized():
        logger.info(
            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
            world_size,
            rank,
            local_rank,
            distributed_init_method,
            backend,
        )
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )
        if enable_elastic_ep:
            tp_pp_cpu_group = torch.distributed.new_group(
                backend="gloo", timeout=timeout
            )
            if _node_count(tp_pp_cpu_group) > 1:
                # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
                # to initialize all DP/EP groups, hence all ranks within TP/PP group
                # must reside on the same node
                raise RuntimeError(
                    "Elastic EP is not yet supported with multi-node TP/PP"
                )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if enable_elastic_ep:
        _init_elastic_ep_world(config, local_rank, backend, rank, world_size)
        return
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        if config is not None and config.parallel_config.nnodes > 1:
            _NODE_COUNT = config.parallel_config.nnodes
        else:
            _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
    if config is not None and config.parallel_config.nnodes_within_dp > 1:
        # Rebind explicitly: the local `parallel_config` above is only bound
        # when the DP/multi-node adjustment branch ran. That branch is skipped
        # e.g. for the external_launcher backend, which made this an
        # UnboundLocalError.
        parallel_config = config.parallel_config
        if parallel_config.data_parallel_size > 1:
            world_size_inner_dp = parallel_config.world_size
            group_ranks = [
                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
                for dp_rank in range(parallel_config.data_parallel_size)
            ]
            _INNER_DP_WORLD = init_model_parallel_group(
                group_ranks,
                get_world_group().local_rank,
                backend,
                use_message_queue_broadcaster=True,
                group_name="inner_dp_world",
                use_device_communicator=False,
            )
        else:
            _INNER_DP_WORLD = _WORLD
1475
1476


Zhuohan Li's avatar
Zhuohan Li committed
1477
1478
1479
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
1480
    prefill_context_model_parallel_size: int = 1,
1481
1482
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
Zhuohan Li's avatar
Zhuohan Li committed
1483
1484
) -> None:
    """
1485
    Initialize model parallel groups.
Zhuohan Li's avatar
Zhuohan Li committed
1486
1487

    Arguments:
1488
1489
1490
1491
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
1492
        backend: name of torch distributed communication backend.
1493
1494

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
Zhuohan Li's avatar
Zhuohan Li committed
1495
1496
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
1497
1498
1499
1500
1501
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
Zhuohan Li's avatar
Zhuohan Li committed
1502
1503
1504
1505
1506
1507
1508
1509
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()

1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    data_parallel_size = config.parallel_config.data_parallel_size
    enable_elastic_ep = config.parallel_config.enable_elastic_ep
    if enable_elastic_ep:
        # Use stateless world group for global information
        world_size = get_world_group().world_size
        rank = get_world_group().rank
        backend = backend or "nccl"
        tp_pp_pcp_size = (
            tensor_model_parallel_size
            * pipeline_model_parallel_size
            * prefill_context_model_parallel_size
        )
        local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            tensor_model_parallel_size,
        )
    else:
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        backend = backend or torch.distributed.get_backend(
            get_world_group().device_group
        )
1536
1537
1538
1539
1540
1541
1542
1543

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
1544
1545
1546
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
1547
1548
1549
1550
1551
        -1,
        data_parallel_size,
        pipeline_model_parallel_size,
        prefill_context_model_parallel_size,
        tensor_model_parallel_size,
1552
    )  # noqa
1553

1554
1555
    # Build the tensor model-parallel groups.
    global _TP
1556
    assert _TP is None, "tensor model parallel group is already initialized"
1557
1558
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
1559
1560
1561
    if enable_elastic_ep:
        group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
1562
    # message queue broadcaster is only used in tensor model parallel group
1563
1564
1565
1566
1567
1568
1569
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )
1570

1571
1572
    # Build the DCP model-parallel groups.
    global _DCP
1573
    assert _DCP is None, "decode context model parallel group is already initialized"
1574
1575
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
1576
    # change by DCP, it simply reuses the GPUs of TP group, and split one
1577
    # TP group into tp_size//dcp_size DCP groups.
1578
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
1579
    group_ranks = [x.tolist() for x in group_ranks]
1580
1581
1582
1583
1584
    if enable_elastic_ep:
        group_ranks = local_all_ranks.reshape(
            -1, decode_context_model_parallel_size
        ).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
1585
1586
1587
1588
1589
1590
1591
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )
1592

1593
1594
1595
1596
1597
1598
1599
1600
    global _PCP
    assert _PCP is None, "prefill context parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(3, 4)
        .reshape(-1, prefill_context_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
1601
1602
1603
1604
1605
1606
1607
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(1, 2)
            .reshape(-1, prefill_context_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
1608
1609
1610
1611
    _PCP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
    )

1612
    # Build the pipeline model-parallel groups.
1613
    global _PP
1614
1615
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
1616
        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
1617
    )
1618
    group_ranks = [x.tolist() for x in group_ranks]
1619
1620
1621
1622
1623
1624
1625
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(0, 2)
            .reshape(-1, pipeline_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
1626
1627
1628
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )
1629

1630
    global _DP
1631
    assert _DP is None, "data parallel group is already initialized"
1632
    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
1633
    group_ranks = [x.tolist() for x in group_ranks]
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
    if enable_elastic_ep:
        parallel_config = config.parallel_config
        dp_ports = [
            parallel_config.get_next_stateless_dp_group_port() for _ in group_ranks
        ]
        _DP = _init_stateless_group(
            group_ranks,
            "dp",
            dp_ports,
            parallel_config.data_parallel_master_ip,
            backend,
        )
    else:
        _DP = init_model_parallel_group(
            group_ranks, get_world_group().local_rank, backend, group_name="dp"
        )
1650

1651
    global _EP
1652
    assert _EP is None, "expert parallel group is already initialized"
1653
    # Don't create EP group for dense models.
1654
    if config.model_config is None or config.model_config.is_moe:
1655
1656
1657
1658
1659
1660
1661
1662
1663
        group_ranks = (
            all_ranks.transpose(1, 2)
            .reshape(
                -1,
                data_parallel_size
                * prefill_context_model_parallel_size
                * tensor_model_parallel_size,
            )
            .unbind(0)
1664
        )
1665
        group_ranks = [x.tolist() for x in group_ranks]
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
        if enable_elastic_ep:
            parallel_config = config.parallel_config
            ep_ports = [
                parallel_config.get_next_stateless_ep_group_port() for _ in group_ranks
            ]
            _EP = _init_stateless_group(
                group_ranks,
                "ep",
                ep_ports,
                parallel_config.data_parallel_master_ip,
                backend,
            )
        else:
            _EP = init_model_parallel_group(
                group_ranks, get_world_group().local_rank, backend, group_name="ep"
            )
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693

        # Create EPLB group with the same ranks as EP if EPLB is enabled.
        # This is a separate process group to isolate EPLB communications
        # from MoE forward pass collectives and prevent deadlocks when
        # using torch.distributed in execution with torch.distributed in EPLB.
        global _EPLB
        assert _EPLB is None, "EPLB group is already initialized"
        if (
            config is not None
            and config.parallel_config is not None
            and config.parallel_config.enable_eplb
        ):
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
            if enable_elastic_ep:
                eplb_ports = [
                    parallel_config.get_next_stateless_eplb_group_port()
                    for _ in group_ranks
                ]
                _EPLB = _init_stateless_group(
                    group_ranks,
                    "eplb",
                    eplb_ports,
                    parallel_config.data_parallel_master_ip,
                    backend,
                )
            else:
                _EPLB = init_model_parallel_group(
                    group_ranks,
                    get_world_group().local_rank,
                    backend,
                    group_name="eplb",
                )
1713
    # If no EP group needed, _EP remains None
1714
    # If no EPLB group needed, _EPLB remains None
1715

1716
    logger.info_once(
1717
        "rank %s in world size %s is assigned as "
1718
        "DP rank %s, PP rank %s, PCP rank %s, "
1719
        "TP rank %s, EP rank %s, EPLB rank %s",
1720
1721
1722
1723
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
1724
        _PCP.rank_in_group,
1725
        _TP.rank_in_group,
1726
        _EP.rank_in_group if _EP is not None else "N/A",
1727
        _EPLB.rank_in_group if _EPLB is not None else "N/A",
1728
    )
1729

Zhuohan Li's avatar
Zhuohan Li committed
1730

def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-, pipeline- and prefill-context-parallel sizes are equal
    to the expected values if the model parallel groups are initialized.

    `backend` defaults to the world group's `backend` attribute when it has
    one, otherwise to the backend of the world group's device group. The
    decode context parallel size is only used for initialization; it is not
    re-checked when the groups already exist.

    Raises:
        AssertionError: if the groups are already initialized with a
            different TP, PP or PCP size.
    """
    world_group = get_world_group()
    if hasattr(world_group, "backend"):
        backend = backend or world_group.backend
    else:
        backend = backend or torch.distributed.get_backend(world_group.device_group)

    if not model_parallel_is_initialized():
        initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            decode_context_model_parallel_size,
            backend,
        )
        return

    assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
        "tensor parallel group already initialized, but of unexpected size. "
        f"got: {get_tensor_model_parallel_world_size()=} vs. "
        f"wanted: {tensor_model_parallel_size=}"
    )
    pp_world_size = get_pp_group().world_size
    assert pp_world_size == pipeline_model_parallel_size, (
        "pipeline parallel group already initialized, but of unexpected size. "
        f"got: {pp_world_size=} vs. "
        f"wanted: {pipeline_model_parallel_size=}"
    )
    pcp_world_size = get_pcp_group().world_size
    assert pcp_world_size == prefill_context_model_parallel_size, (
        "prefill context parallel group already initialized, but of unexpected size: "
        f"{pcp_world_size=} vs. "
        f"{prefill_context_model_parallel_size=}"
    )


def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.
    Traditional communication libraries like NCCL are almost
    model agnostic. However, emerging new communication libraries like
    MoE all2all (DeepEP) usually allocate the communication buffer
    based on the model shape for optimal performance.

    Each initialized group among TP, PCP, PP, DP, EP and EPLB (in that order)
    is asked to prepare its buffer; groups that are None are skipped.
    """
    for group in (_TP, _PCP, _PP, _DP, _EP, _EPLB):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


def model_parallel_is_initialized() -> bool:
    """Check if tensor and pipeline parallel groups are initialized."""
    return _TP is not None and _PP is not None


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if the tp group is already patched (no nesting).
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the current group before raising the flag: if the lookup
    # fails, the module must not be left stuck in the "patched" state.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size() -> int:
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank() -> int:
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def get_decode_context_model_parallel_world_size() -> int:
    """Return world size for the decode context model parallel group."""
    return get_dcp_group().world_size


def get_decode_context_model_parallel_rank() -> int:
    """Return my rank for the decode context model parallel group."""
    return get_dcp_group().rank_in_group


def get_node_count() -> int:
    """Return the total number of nodes in the distributed environment.

    Raises:
        AssertionError: if the node count has not been set yet, i.e. the
            distributed environment is not initialized.
    """
    assert _NODE_COUNT is not None, "distributed environment is not initialized"
    return _NODE_COUNT


def destroy_model_parallel():
    """Set the groups to none and destroy them.

    Destroys, in order, the TP, DCP, PCP, PP, DP, EP and EPLB groups that
    exist and resets each module-level handle to None.
    """
    global _TP
    if _TP is not None:
        _TP.destroy()
    _TP = None

    global _DCP
    if _DCP is not None:
        _DCP.destroy()
    _DCP = None

    global _PCP
    if _PCP is not None:
        _PCP.destroy()
    _PCP = None

    global _PP
    if _PP is not None:
        _PP.destroy()
    _PP = None

    global _DP
    if _DP is not None:
        _DP.destroy()
    _DP = None

    global _EP
    if _EP is not None:
        _EP.destroy()
    _EP = None

    global _EPLB
    if _EPLB is not None:
        _EPLB.destroy()
    _EPLB = None

def destroy_distributed_environment():
    """Destroy the world group and tear down torch.distributed.

    Also releases the inner-DP world group and resets the cached node count.
    The inner-DP group is only destroyed separately when it is its own group;
    when it aliases the world group, destroying the world group covers it.
    """
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if _INNER_DP_WORLD is not None and _INNER_DP_WORLD is not _WORLD:
        _INNER_DP_WORLD.destroy()
    _INNER_DP_WORLD = None
    if _WORLD is not None:
        _WORLD.destroy()
    _WORLD = None
    _NODE_COUNT = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all parallel groups and torch.distributed, then free memory.

    Args:
        shutdown_ray: if True, also shut down Ray (imported lazily).
    """
    # Reset environment variable cache
    envs.disable_envs_cache()
    # Ensure all objects are not frozen before cleanup
    gc.unfreeze()

    destroy_model_parallel()
    destroy_distributed_environment()
    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()
    from vllm.platforms import current_platform

    empty_cache = current_platform.empty_cache
    if empty_cache is not None:
        empty_cache()
    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")


def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    The source rank creates a shared-memory segment, writes a magic message
    into it and broadcasts the segment name. Every other rank tries to open
    the segment and checks for the message. The per-rank flags are then
    combined across the group.

    Args:
        pg: group to test. A torch `ProcessGroup` must not use the NCCL
            backend.
        source_rank: group-local rank that creates the segment.

    Returns:
        One entry per group rank, True when that rank could read the source
        rank's shared-memory segment.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor(
        [0] * world_size, dtype=torch.int32, device="cpu"
    )

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg
                    )
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg
                    )
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[: len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # Wait until every rank has tried to open the segment before the source
    # rank unlinks it.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    # Each rank set only its own slot, so summing the per-rank tensors yields
    # the full set of flags.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Tell whether this process holds global rank 0.

    In contrast to per-group checks such as
    `get_tensor_model_parallel_rank() == 0` or `get_pp_group().is_first_rank`,
    this looks at the rank across every parallelism dimension
    (PP, TP, DP, EP, ...).

    Returns:
        bool: True for global rank 0, and also True when torch.distributed is
              not initialized (single process) or when the lookup fails.
    """
    try:
        world = _WORLD
        if world is not None:
            # The world group is the most accurate source of the global rank.
            return world.is_first_rank
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0
        # No distributed context: this is the only process.
        return True
    except Exception:
        # Best effort: on any failure behave as the first rank.
        return True


def is_local_first_rank() -> bool:
    """
    Tell whether this process is local rank 0, i.e. the first rank on its node.

    Uses the world group when it exists. Otherwise it reads envs.LOCAL_RANK,
    and falls back to the global rank when that cannot be parsed as an int.
    Returns True when torch.distributed is not initialized or on any error.
    """
    try:
        world = _WORLD
        if world is not None:
            return world.local_rank == 0

        if not torch.distributed.is_initialized():
            return True

        # envs.LOCAL_RANK is set by env:// launchers such as torchrun
        try:
            local_rank = int(envs.LOCAL_RANK)  # type: ignore[arg-type]
        except Exception:
            return torch.distributed.get_rank() == 0
        return local_rank == 0
    except Exception:
        return True


def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Returns the total number of nodes in the process group.

    This calls the collective `in_the_same_node_as` once per discovered node,
    so every rank of `pg` must call it. The flags come from that collective,
    so all ranks see the same assignment and issue the same sequence of
    calls. The lowest unassigned rank starts a new node, and every still
    unassigned rank that shares memory with it joins that node.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        world_size = torch.distributed.get_world_size(group=pg)
    else:
        world_size = pg.world_size

    if world_size == 1:
        return 1

    # rank -> node id (1-based); 0 means "not assigned yet"
    node_assignment = [0] * world_size
    next_node_id = 0

    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # Already assigned to a node

        # Assign current rank to a new node
        next_node_id += 1
        node_assignment[current_rank] = next_node_id

        # Find all ranks on the same node as current_rank
        same_node_flags = in_the_same_node_as(pg, current_rank)
        for other_rank, is_same_node in enumerate(same_node_flags):
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    return next_node_id