parallel_state.py 74.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
9
10
11
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- run any code that uses the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import TYPE_CHECKING, Any, Protocol
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup, Store
Zhuohan Li's avatar
Zhuohan Li committed
44

45
import vllm.envs as envs
46
from vllm.distributed.device_communicators.base_device_communicator import (
47
48
    DeviceCommunicatorBase,
)
49
50
51
52
from vllm.distributed.utils import (
    StatelessProcessGroup,
    get_cached_tcp_store_client,
)
53
from vllm.logger import init_logger
54
from vllm.utils.import_utils import resolve_obj_by_qualname
55
from vllm.utils.network_utils import get_distributed_init_method
56
from vllm.utils.system_utils import suppress_stdout
57
58
59
from vllm.utils.torch_utils import (
    direct_register_custom_op,
)
60

61
62
63
if TYPE_CHECKING:
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

64

65
66
67
@dataclass
class GraphCaptureContext:
    """State yielded by ``GroupCoordinator.graph_capture``.

    ``stream`` is the CUDA stream made current for the duration of the
    capture block.
    """

    stream: torch.cuda.Stream
68

69

70
# Picklable stand-in for a tensor inside broadcast metadata: the device type
# (e.g. "cuda", without an index), the dtype and the size, which is enough for
# a receiving rank to allocate a matching buffer.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
71

72

73
74
75
76
77
78
79
80
class Handle(Protocol):
    """Minimal async work handle used by P2P send/recv methods.

    This is a structural type: any object providing both methods below
    satisfies it.
    """

    def is_completed(self) -> bool:
        """Whether the pending operation has completed."""
        ...

    def wait(self) -> None:
        """Wait for the pending operation."""
        ...


81
def _split_tensor_dict(
82
    tensor_dict: dict[str, torch.Tensor | Any],
83
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
84
85
86
87
88
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
89
90
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
91
92
93
94
95
96
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
97
            device = value.device.type
98
            metadata_list.append(
99
100
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
101
102
            tensor_list.append(value)
        else:
103
            metadata_list.append((key, value))
104
105
106
    return metadata_list, tensor_list


107
_group_name_counter: dict[str, int] = {}
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


123
_groups: dict[str, Callable[[], "GroupCoordinator | None"]] = {}
124
125
126


def _register_group(group: "GroupCoordinator") -> None:
127
    _groups[group.unique_name] = weakref.ref(group)
128
129


130
131
132
133
134
135
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Real implementation behind ``torch.ops.vllm.all_reduce``.

    Looks up the live coordinator registered as ``group_name`` and runs its
    out-of-place all-reduce.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_reduce_out_place(tensor)
136
137


138
139
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation for ``torch.ops.vllm.all_reduce``.

    All-reduce preserves the input's shape, dtype and device, so an
    uninitialized tensor like ``tensor`` stands in for the result.
    ``group_name`` is not used.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
140

141

142
143
144
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Real implementation behind ``torch.ops.vllm.reduce_scatter``.

    ``world_size`` is not needed here (the coordinator knows its own size);
    it is part of the op signature so the fake implementation can compute
    the output shape.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._reduce_scatter_out_place(tensor, dim)
150
151


152
153
154
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation for ``torch.ops.vllm.reduce_scatter``.

    Returns an uninitialized tensor whose ``dim`` extent is divided (floor)
    by ``world_size``; dtype and device follow ``tensor``.
    """
    out_shape = [*tensor.shape]
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


160
161
162
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Real implementation behind ``torch.ops.vllm.all_gather``.

    ``world_size`` is unused here; it is part of the op signature so the fake
    implementation can compute the output shape.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_gather_out_place(tensor, dim)
168
169


170
171
172
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation for ``torch.ops.vllm.all_gather``.

    Returns an uninitialized tensor whose ``dim`` extent is multiplied by
    ``world_size``; dtype and device follow ``tensor``.
    """
    out_shape = [*tensor.shape]
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Fake implementation paired with
    ``patched_fused_scaled_matmul_reduce_scatter`` at op registration.

    Flattens ``A`` to 2-D and computes ``torch._scaled_mm`` with ``B`` and the
    scales. It views the product as ``output_shape[:-1] + [B.shape[1]]`` and
    reduce-scatters it over ``group_name`` along ``orig_scatter_dim`` using
    functional collectives.

    NOTE(review): despite its "fake" role, this body calls real ops
    (``torch._scaled_mm``, ``funcol.reduce_scatter_tensor``). Presumably it
    is only invoked on fake/meta tensors during tracing — confirm.
    ``scatter_dim_after_maybe_reshape`` mirrors the real op's signature but
    is unused here.

    Raises:
        ValueError: if ``A_scale`` has more than one element and its leading
            dims differ from ``A``'s (row-wise scaling), or if ``A_scale`` is
            empty.
    """
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        # Row-wise scaling: A_scale must share A's leading dims, and is
        # flattened the same way A is flattened for the matmul below.
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        # Only reachable when A_scale has zero elements.
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    # 2-D scaled matmul on A with all leading dims collapsed into rows.
    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    # Restore the caller-provided leading dims; the last dim is B's columns.
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    # Functional collectives return an async tensor; wait for the result.
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


262
263
264
265
266
# Expose the collectives above as `torch.ops.vllm.<op_name>` custom ops, each
# paired with its fake implementation. Registration order is preserved.
for _op_name, _op_func, _fake_impl in (
    ("all_reduce", all_reduce, all_reduce_fake),
    ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
    ("all_gather", all_gather, all_gather_fake),
    # TODO: Remove this once the pytorch fix
    # (https://github.com/pytorch/pytorch/pull/165086) gets released,
    # in either 2.9.1 or 2.10
    (
        "patched_fused_scaled_matmul_reduce_scatter",
        patched_fused_scaled_matmul_reduce_scatter,
        patched_fused_scaled_matmul_reduce_scatter_fake,
    ),
):
    direct_register_custom_op(
        op_name=_op_name,
        op_func=_op_func,
        fake_impl=_fake_impl,
    )
# Don't leak the loop variables into the module namespace.
del _op_name, _op_func, _fake_impl
288

289

290
291
292
293
294
295
class GroupCoordinator:
    """
    Wrapper around the PyTorch ProcessGroups of one group of processes.
    A PyTorch ProcessGroup is tied to a single communication backend
        (NCCL, Gloo, MPI, ...).
    GroupCoordinator owns the communication operations among the processes
        in the group and handles both CPU-side and device-side
        communication.
    """

    # Instance attributes (assigned in __init__):
    rank: int  # global rank of this process
    ranks: list[int]  # global ranks of all members of the group
    world_size: int  # number of processes in the group
    # `local_rank` vs `rank_in_group`, for a group of 4 spread over 2 nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # node-local rank, used to choose the device
    rank_in_group: int  # index of this process within `ranks`
    cpu_group: ProcessGroup  # process group for CPU communication
    device_group: ProcessGroup  # process group for device communication
    # set when use_device_communicator=True and world_size > 1, else None
    device_communicator: DeviceCommunicatorBase | None
    mq_broadcaster: Any | None  # shared-memory message-queue broadcaster
318
319
320

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create every group in ``group_ranks`` and keep the one containing
        this process.

        Args:
            group_ranks: Lists of global ranks, one list per group. Every
                process calls ``torch.distributed.new_group`` for every list
                (as PyTorch requires); only the group containing this rank
                is kept.
            local_rank: Node-local rank, used to choose the device.
            torch_distributed_backend: Backend for ``device_group``.
            use_device_communicator: Build the platform's device
                communicator (only when the group has more than one process).
            use_message_queue_broadcaster: Build a shared-memory
                MessageQueue over ``cpu_group`` (only when world_size > 1).
            group_name: Base name for ``unique_name``; "anonymous" if unset.
        """
        # Register under a unique name so custom ops can resolve this group.
        self.unique_name = _get_unique_name(group_name or "anonymous")
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        my_device_group = None
        my_cpu_group = None
        for member_ranks in group_ranks:
            # Device-side group using the requested backend.
            dev_pg = torch.distributed.new_group(
                member_ranks, backend=torch_distributed_backend
            )
            # Gloo group so processes can coordinate directly through the CPU.
            with suppress_stdout():
                host_pg = torch.distributed.new_group(member_ranks, backend="gloo")
            if self.rank not in member_ranks:
                continue
            self.ranks = member_ranks
            self.world_size = len(member_ranks)
            self.rank_in_group = member_ranks.index(self.rank)
            my_device_group = dev_pg
            my_cpu_group = host_pg

        assert my_cpu_group is not None
        assert my_device_group is not None
        self.cpu_group = my_cpu_group
        self.device_group = my_device_group

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            device_spec = f"cuda:{local_rank}"
        elif current_platform.is_xpu():
            device_spec = f"xpu:{local_rank}"
        elif current_platform.is_out_of_tree():
            device_spec = f"{current_platform.device_name}:{local_rank}"
        else:
            device_spec = "cpu"
        self.device = torch.device(device_spec)

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            communicator_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = communicator_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # TODO(#35915): Remove is_tpu() check once tpu_inference
        # overrides use_custom_op_collectives() to return True.
        self.use_custom_op_call = (
            current_platform.is_tpu() or current_platform.use_custom_op_collectives()
        )

        # Short-circuits: torch.ops._C is only probed on CPU platforms.
        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
400

401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
    def create_mq_broadcaster(
        self, writer_rank=0, external_writer_handle=None, blocking=True
    ):
        """Build a shared-memory MessageQueue over this group's CPU group.

        Uses the same positional sizing arguments (``1 << 22``, ``6``) as the
        broadcaster built in ``__init__``, with a configurable writer.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        writer_options = {
            "writer_rank": writer_rank,
            "external_writer_handle": external_writer_handle,
            "blocking": blocking,
        }
        return MessageQueue.create_from_process_group(
            self.cpu_group, 1 << 22, 6, **writer_options
        )

    def create_single_reader_mq_broadcasters(
        self, reader_rank_in_group=0, blocking=False
    ):
        """Build a single-reader shared-memory MessageQueue over ``cpu_group``.

        ``reader_rank_in_group`` indexes into this group and is converted to
        the matching global rank before being passed on.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        reader_global_rank = self.ranks[reader_rank_in_group]
        return MessageQueue.create_from_process_group_single_reader(
            self.cpu_group,
            1 << 22,
            6,
            reader_rank=reader_global_rank,
            blocking=blocking,
        )

428
429
430
431
432
433
434
435
436
437
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

438
439
440
441
442
443
444
445
446
447
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Context manager that sets up a stream for CUDA graph capture.

        Yields a GraphCaptureContext whose ``stream`` is current inside the
        block. A new ``torch.cuda.Stream`` is created when no context is
        given. If the CUDA device communicator has a ``ca_comm``, its
        ``capture()`` context is entered as well.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        # Only CUDA uses this function, so it is not abstracted into the
        # base communicator class.
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        ca_context = nullcontext()
        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            ca = communicator.ca_comm
            if ca is not None:
                ca_context = ca.capture()  # type: ignore

        # Make the capture stream wait for work already queued on the current
        # stream so initialization finishes before capture begins.
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), ca_context:
            yield graph_capture_context
491
492
493

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
494
495
496
497
498
499
500
501
502
503
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
504
505
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
506
507
508
509
510
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

511
        if self.use_custom_op_call:
512
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
513
514
        else:
            return self._all_reduce_out_place(input_)
515

516
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
517
518
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
519
        return self.device_communicator.all_reduce(input_)
520
521
522
523
524
525
526

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
527
528
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
529

530
        if self.use_custom_op_call:
531
532
533
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
534
535
536
        else:
            return self._all_gather_out_place(input_, dim)

537
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
538
539
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
540
        return self.device_communicator.all_gather(input_, dim)
541

542
543
    def all_gatherv(
        self,
544
        input_: torch.Tensor | list[torch.Tensor],
545
        dim: int = 0,
546
        sizes: list[int] | None = None,
547
    ):
548
549
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
550
551
        return self.device_communicator.all_gatherv(input_, dim, sizes)

552
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
553
554
555
556
557
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
558
559
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
560

561
        if self.use_custom_op_call:
562
563
564
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
565
566
567
        else:
            return self._reduce_scatter_out_place(input_, dim)

568
    def reduce_scatterv(
569
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
570
    ) -> torch.Tensor:
571
572
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
573
574
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

575
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
576
577
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
578
579
        return self.device_communicator.reduce_scatter(input_, dim)

580
581
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
582
    ) -> torch.Tensor | None:
583
584
585
586
587
588
589
590
591
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
592
593
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
594
        return self.device_communicator.gather(input_, dst, dim)
595
596
597
598
599
600
601
602
603
604
605

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
606
607
608
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
609
610
        return input_

611
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
612
613
614
615
616
617
618
619
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
620
621
622
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
623
        if self.rank_in_group == src:
624
625
626
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
627
628
629
            return obj
        else:
            recv = [None]
630
631
632
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
633
634
            return recv[0]

635
    def broadcast_object_list(
636
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
637
    ):
638
639
640
641
642
643
644
645
646
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
647
648
649
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
650
651
        return obj_list

652
653
654
655
656
657
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

658
        assert dst != self.rank_in_group, (
659
            "Invalid destination rank. Destination rank is the same "
660
661
            "as the current rank."
        )
662
663
664
665

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

666
667
668
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
669
670
671

        # Send object size

672
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
673
674

        # Send object
675
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
676
677
678
679
680
681
682
683
684

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

685
        assert src != self.rank_in_group, (
686
687
688
689
690
691
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
692
693
694
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
695
696
697
698
699

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
700
701
            device="cpu",
        )
702

703
704
705
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
706
707

        assert rank_object == rank_size, (
708
709
            "Received object sender rank does not match the size sender rank."
        )
710
711
712
713
714

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

715
716
    def broadcast_tensor_dict(
        self,
717
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
718
        src: int = 0,
719
720
721
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
722
723
724
725
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
726
        if not torch.distributed.is_initialized() or self.world_size == 1:
727
728
729
730
731
732
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

733
734
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
735
            metadata_list: list[tuple[Any, Any]] = []
736
737
738
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
739
740
741
742
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
743
            self.broadcast_object(metadata_list, src=src)
744
745
746
747
748
749
750
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
751
752
753
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
754
755
                else:
                    # use group for GPU tensors
756
757
758
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
759
760
761
762
763
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
764
            metadata_list = self.broadcast_object(None, src=src)
765
766
            tensor_dict = {}
            async_handles = []
767
            for key, value in metadata_list:
768
                if isinstance(value, TensorMetadata):
769
770
771
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
772
773
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
774
                        tensor_dict[key] = tensor
775
776
777
778
779
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
780
                            src=self.ranks[src],
781
                            group=metadata_group,
782
783
                            async_op=True,
                        )
784
785
                    else:
                        # use group for GPU tensors
786
                        handle = torch.distributed.broadcast(
787
788
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
789
                    async_handles.append(handle)
790
                    tensor_dict[key] = tensor
791
                else:
792
                    tensor_dict[key] = value
793
794
795
796
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

797
798
799
800
801
802
803
804
805
806
807
808
809
810
    def _should_use_all_gather(
        self,
        key: str,
        numel: int,
        all_gather_group: "GroupCoordinator | None",
        all_gather_tensors: dict[str, bool] | None,
    ) -> bool:
        if all_gather_group is None:
            return False
        use_all_gather = numel % all_gather_group.world_size == 0
        if all_gather_tensors is not None:
            use_all_gather = all_gather_tensors.get(key, use_all_gather)
        return use_all_gather

811
812
    def send_tensor_dict(
        self,
813
814
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
815
        all_gather_group: "GroupCoordinator | None" = None,
816
817
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
818
819
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
835
836
837
838
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
        handles = self.isend_tensor_dict(
            tensor_dict,
            dst=dst,
            all_gather_group=all_gather_group,
            all_gather_tensors=all_gather_tensors,
        )
        for handle in handles:
            handle.wait()
        return None

    def isend_tensor_dict(
        self,
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
        all_gather_group: "GroupCoordinator | None" = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> list[Handle]:
        """Start sending `tensor_dict` to `dst` and return the pending handles.

        The metadata list produced by `_split_tensor_dict` is sent
        synchronously with `send_object`. Every non-empty tensor is then
        posted with `torch.distributed.isend`: CPU tensors go over
        `cpu_group`, all others over `device_group`. Callers must `wait()` on
        each returned handle before reusing the tensors.

        `dst` defaults to the next rank in the group. See `send_tensor_dict`
        for the meaning of `all_gather_group` / `all_gather_tensors`.

        Returns an empty list when no asynchronous send was issued:
          - torch.distributed is not initialized,
          - the group has at most one rank, or
          - the synchronous custom CPU send path is used.

        Raises:
            ValueError: if the custom CPU send path is enabled but no device
                communicator exists.
        """
        # Same bypass as `send_tensor_dict` / `irecv_tensor_dict`. Without an
        # initialized process group (or a peer) there is nothing to send, and
        # `send_object` below would fail.
        if not torch.distributed.is_initialized() or self.world_size <= 1:
            return []

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            # custom device communicator path is synchronous
            self.device_communicator.send_tensor_dict(  # type: ignore
                tensor_dict, dst
            )
            return []

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )

        group = self.device_group
        metadata_group = self.cpu_group

        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        self.send_object(metadata_list, dst=dst)

        # Assumes `_split_tensor_dict` yields tensors in dict iteration order,
        # so these keys line up with `tensor_list`. The assert only checks the
        # count.
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
        assert len(tensor_keys) == len(tensor_list)

        handles: list[Handle] = []
        for key, tensor in zip(tensor_keys, tensor_list):
            # Empty tensors are not transferred; the receiver allocates them
            # from the metadata alone.
            if tensor.numel() == 0:
                continue

            if self._should_use_all_gather(
                key, tensor.numel(), all_gather_group, all_gather_tensors
            ):
                # Send only this rank's slice; the receiver all-gathers the rest.
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            comm_group = metadata_group if tensor.is_cpu else group
            handle = torch.distributed.isend(
                tensor, dst=self.ranks[dst], group=comm_group
            )
            if tensor.is_cuda:
                # Mark the tensor as in use on the current stream so the CUDA
                # caching allocator does not hand its memory out prematurely.
                tensor.record_stream(torch.cuda.current_stream(tensor.device))
            handles.append(handle)

        return handles
905
906
907

    def recv_tensor_dict(
        self,
908
        src: int | None = None,
909
        all_gather_group: "GroupCoordinator | None" = None,
910
911
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
912
913
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
929
930
931
932
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
        tensor_dict, handles, postprocess = self.irecv_tensor_dict(
            src=src,
            all_gather_group=all_gather_group,
            all_gather_tensors=all_gather_tensors,
        )
        for handle in handles:
            handle.wait()
        for fn in postprocess:
            fn()
        return tensor_dict

    def irecv_tensor_dict(
        self,
        src: int | None = None,
        all_gather_group: "GroupCoordinator | None" = None,
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> tuple[
        dict[str, torch.Tensor | Any] | None,
        list[Handle],
        list[Callable[[], None]],
    ]:
        """Post non-blocking receives for a tensor dictionary from `src`.

        The metadata list is received synchronously via `recv_object`. A
        buffer is then allocated for every tensor entry and filled with
        `torch.distributed.irecv`. `src` defaults to the previous rank in the
        group. See `recv_tensor_dict` for `all_gather_group` and
        `all_gather_tensors`.

        Returns:
            `(tensor_dict, handles, postprocess)`. The caller must `wait()` on
            every handle and then call every postprocess function, in that
            order. Entries received as all-gather slices are replaced by their
            full tensors only by the postprocess calls.

            `(None, [], [])` is returned when torch.distributed is not
            initialized or the group has one rank. The synchronous custom CPU
            path returns its received dictionary with two empty lists.

        Raises:
            ValueError: if the custom CPU receive path is enabled but no
                device communicator exists.
        """
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None, [], []

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            # The custom communicator receives synchronously: nothing to wait
            # on and nothing to post-process.
            received = self.device_communicator.recv_tensor_dict(  # type: ignore
                src
            )
            return received, [], []

        if all_gather_group is None:
            all_gather_size, all_gather_rank = 1, 0
        else:
            all_gather_size = all_gather_group.world_size
            all_gather_rank = all_gather_group.rank_in_group

        device_group = self.device_group
        cpu_group = self.cpu_group

        metadata = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        handles: list[Handle] = []
        postprocess: list[Callable[[], None]] = []

        def _gather_into_dict(
            key: str, piece: torch.Tensor, shape: tuple[int, ...]
        ) -> Callable[[], None]:
            # Build a callback that all-gathers `piece` across
            # `all_gather_group` and stores the full tensor under `key`.
            def _finish() -> None:
                assert all_gather_group is not None
                gathered = all_gather_group.all_gather(piece, dim=0)
                tensor_dict[key] = gathered.reshape(shape)

            return _finish

        for key, value in metadata:
            if not isinstance(value, TensorMetadata):
                tensor_dict[key] = value
                continue

            buffer = torch.empty(value.size, dtype=value.dtype, device=value.device)
            if buffer.numel() == 0:
                # Empty tensors need no transfer.
                tensor_dict[key] = buffer
                continue

            use_gather = self._should_use_all_gather(
                key, buffer.numel(), all_gather_group, all_gather_tensors
            )
            # With the all-gather optimization, only this rank's slice arrives
            # point-to-point; the full tensor is rebuilt in post-processing.
            if use_gather:
                target = buffer.reshape(all_gather_size, -1)[all_gather_rank]
            else:
                target = buffer
            comm_group = cpu_group if target.is_cpu else device_group
            handles.append(
                torch.distributed.irecv(target, src=self.ranks[src], group=comm_group)
            )
            if use_gather:
                postprocess.append(
                    _gather_into_dict(key, target, tuple(buffer.shape))
                )
            tensor_dict[key] = target

        return tensor_dict, handles, postprocess
1029

1030
1031
1032
1033
1034
1035
1036
1037
1038
    def barrier(self):
        """Block until every rank in this group reaches the barrier.

        The barrier deliberately runs on `cpu_group`, not `device_group`. An
        NCCL barrier is internally a broadcast on hidden GPU tensors, which
        makes it easy to mess up the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

1039
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
1040
        """Sends a tensor to the destination rank in a blocking way"""
1041
        """NOTE: `dst` is the local rank of the destination rank."""
1042
1043
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
1044
        self.device_communicator.send(tensor, dst)
1045

1046
    def recv(
1047
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
1048
    ) -> torch.Tensor:
1049
1050
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
1051
1052
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
1053
        return self.device_communicator.recv(size, dtype, src)
1054

1055
    def destroy(self):
        """Tear down this coordinator's process groups and communicators.

        Each of `device_group` and `cpu_group` is destroyed and its attribute
        deleted only while still present, so a repeated call skips them. The
        device communicator (if any) is destroyed afterwards, and the
        message-queue broadcaster reference is cleared.
        """
        for attr in ("device_group", "cpu_group"):
            if hasattr(self, attr):
                torch.distributed.destroy_process_group(getattr(self, attr))
                delattr(self, attr)
        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
1066

1067
1068
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let the device communicator prepare its buffers for `model`.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
1070

1071
    def dispatch_router_logits(
1072
1073
1074
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
1075
        is_sequence_parallel: bool = False,
1076
1077
1078
1079
1080
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
1081
        if self.device_communicator is not None:
1082
            return self.device_communicator.dispatch_router_logits(
1083
1084
1085
1086
                hidden_states,
                router_logits,
                is_sequence_parallel,
                extra_tensors,
1087
            )
1088
1089
        else:
            return hidden_states, router_logits
1090

1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
    def dispatch(
        self,
        hidden_states: torch.Tensor,
        topk_weights: torch.Tensor,
        topk_ids: torch.Tensor,
        is_sequence_parallel: bool = False,
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor, torch.Tensor, list[torch.Tensor]]
        | tuple[torch.Tensor, torch.Tensor, torch.Tensor]
    ):
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(
                hidden_states,
                topk_weights,
                topk_ids,
                is_sequence_parallel,
                extra_tensors,
            )
        else:
            return hidden_states, topk_weights, topk_ids

1113
1114
1115
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine `hidden_states` via the device communicator.

        Returns `hidden_states` unchanged when there is no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
1120

1121

1122
# Group spanning every rank; assigned during distributed initialization.
_WORLD: GroupCoordinator | None = None
# Per-DP-replica world group; only assigned by `init_distributed_environment`
# when the config reports `nnodes_within_dp > 1` (it may then alias `_WORLD`).
_INNER_DP_WORLD: GroupCoordinator | None = None
# Number of nodes detected (or configured) for the distributed environment.
_NODE_COUNT: int | None = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting it is initialized."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


def get_inner_dp_world_group() -> GroupCoordinator:
    """Return the inner data-parallel world group, asserting it is initialized."""
    inner = _INNER_DP_WORLD
    assert inner is not None, "inner dp world group is not initialized"
    return inner


1137
1138
1139
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Create the "world" coordinator: one group over `ranks`.

    The world group is built without a device communicator.
    """
    single_group = [ranks]
    return GroupCoordinator(
        group_name="world",
        group_ranks=single_group,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
    )


1149
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    use_device_communicator: bool = True,
) -> GroupCoordinator:
    """Create a `GroupCoordinator` for the rank partition `group_ranks`.

    All arguments are forwarded unchanged; `backend` becomes the
    coordinator's `torch_distributed_backend`.
    """
    return GroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        use_device_communicator=use_device_communicator,
    )


1167
1168
1169
1170
1171
def _init_stateless_group(
    group_ranks: list[list[int]],
    group_name: str,
    host: str,
    backend: str,
    coord_store: Store,
    use_device_communicator: bool = True,
) -> "StatelessGroupCoordinator":
    """Create a StatelessGroupCoordinator with the given parameters.

    Local rank, global rank and global world size come from the current
    world group, so `get_world_group()` must already be initialized.
    `host` and `coord_store` are passed through unchanged.
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    world_group = get_world_group()
    return StatelessGroupCoordinator(
        group_ranks=group_ranks,
        group_name=group_name,
        local_rank=world_group.local_rank,
        global_rank=world_group.rank,
        global_world_size=world_group.world_size,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        host=host,
        coord_store=coord_store,
    )


def _replace_active_groups(
    *,
    world: GroupCoordinator | None,
    dp: GroupCoordinator | None,
    ep: GroupCoordinator | None,
    eplb: GroupCoordinator | None,
    node_count: int | None,
) -> None:
    """Destroy the current DP/EP/WORLD/EPLB groups and install replacements.

    Destruction is collective: every rank in the old groups must call this
    together. Passing ``None`` for everything tears the groups down without
    replacing them. Old groups are destroyed in DP, EP, WORLD, EPLB order.
    """
    global _WORLD, _DP, _EP, _EPLB, _NODE_COUNT
    previous = (_DP, _EP, _WORLD, _EPLB)
    for old_group in previous:
        if old_group is None:
            continue
        old_group.destroy()
    _WORLD, _DP, _EP, _EPLB = world, dp, ep, eplb
    _NODE_COUNT = node_count


1216
# Tensor-model-parallel group coordinator (None until initialized).
_TP: GroupCoordinator | None = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-model-parallel group, asserting it is initialized."""
    tp = _TP
    assert tp is not None, "tensor model parallel group is not initialized"
    return tp


1224
# Decode-context-model-parallel group coordinator (None until initialized).
_DCP: GroupCoordinator | None = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode-context-parallel group, asserting it is initialized."""
    dcp = _DCP
    assert dcp is not None, "decode context model parallel group is not initialized"
    return dcp


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

1235
# Pipeline-model-parallel group coordinator (None until initialized).
_PP: GroupCoordinator | None = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-model-parallel group, asserting it is initialized."""
    pp = _PP
    assert pp is not None, "pipeline model parallel group is not initialized"
    return pp


1243
# Data-parallel group coordinator (None until initialized).
_DP: GroupCoordinator | None = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group, asserting it is initialized."""
    dp = _DP
    assert dp is not None, "data parallel group is not initialized"
    return dp

1250

1251
# Expert-parallel group coordinator (None until initialized).
_EP: GroupCoordinator | None = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group, asserting it is initialized.

    Per the assertion message, this group only exists for MoE models.
    """
    ep = _EP
    assert ep is not None, (
        "expert parallel group is not initialized. "
        "EP group is only created for MoE models with num_experts > 0. "
        "This function should only be called for MoE models."
    )
    return ep


1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
# EPLB group coordinator (None until initialized).
_EPLB: GroupCoordinator | None = None


def get_eplb_group() -> GroupCoordinator:
    """Return the EPLB group, asserting it is initialized.

    Per the assertion message, it only exists when EPLB is enabled.
    """
    eplb = _EPLB
    assert eplb is not None, (
        "EPLB group is not initialized. "
        "EPLB group is only created for MoE models when EPLB is enabled. "
        "Ensure parallel_config.enable_eplb is True."
    )
    return eplb


1275
1276
1277
1278
1279
1280
# Prefill-context-parallel group coordinator (None until initialized).
_PCP: GroupCoordinator | None = None


def get_pcp_group() -> GroupCoordinator:
    """Return the prefill-context-parallel group, asserting it is initialized."""
    pcp = _PCP
    assert pcp is not None, "prefill context parallel group is not initialized"
    return pcp
1281
1282


1283
@contextmanager
def graph_capture(device: torch.device):
    """Context manager to wrap around code that captures a CUDA graph.

    Yields a `GraphCaptureContext` that holds a freshly created CUDA stream
    on `device`. While the block runs, the TP group's and PP group's own
    `graph_capture` contexts are entered with that context, TP first.

    The intent is for capture to happen on this dedicated stream rather than
    the default stream. Kernels being captured are then kept apart from work
    launched in the background on the default stream. The stream switching
    itself is handled by the group-level contexts.

    The wrapper also gives a place for operations that must run after the
    graph is captured but before it is replayed.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1302

1303
# Module logger.
logger = init_logger(__name__)

# Module-level switch for custom all-reduce; toggled via
# `set_custom_all_reduce`. Enabled by default.
_ENABLE_CUSTOM_ALL_REDUCE = True
1306
1307


1308
1309
1310
def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1311

Zhuohan Li's avatar
Zhuohan Li committed
1312

1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
def _init_elastic_ep_world(
    config, local_rank: int, backend: str, rank: int, world_size: int
) -> None:
    """Create the global world group for elastic expert parallelism.

    The world is a `StatelessGroupCoordinator` over all ranks across data
    parallel replicas (`world_size_across_dp`). It is hosted at
    `data_parallel_master_ip` and uses the TCP store client returned by
    `get_cached_tcp_store_client` as its coordination store. Sets `_WORLD`
    and `_NODE_COUNT`; asserts the world group was not already initialized.

    Args:
        config: vLLM config; only `config.parallel_config` is read.
        local_rank: Local rank of this process.
        backend: torch.distributed backend name.
        rank: Rank within this DP replica; offset by
            `data_parallel_rank * world_size` to obtain the global rank.
        world_size: Number of ranks per DP replica.
    """
    from vllm.distributed.stateless_coordinator import StatelessGroupCoordinator

    global _WORLD, _NODE_COUNT
    assert _WORLD is None, "world group already initialized"
    pc = config.parallel_config
    global_rank = pc.data_parallel_rank * world_size + rank
    global_world_size = pc.world_size_across_dp
    all_ranks = list(range(global_world_size))
    if global_rank in all_ranks:
        # Usual case: a single group containing every rank.
        group_ranks = [all_ranks]
    else:
        # Otherwise fall back to one singleton group per rank.
        group_ranks = [[r] for r in all_ranks]

    coord_store = get_cached_tcp_store_client(
        pc.data_parallel_master_ip, pc._coord_store_port
    )
    world = StatelessGroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
        host=pc.data_parallel_master_ip,
        coord_store=coord_store,
        global_rank=global_rank,
        global_world_size=global_world_size,
    )
    assert pc.nnodes_within_dp == 1, (
        "Elastic EP is not supported with multi-node TP/PP"
    )
    _NODE_COUNT = _node_count(world.tcp_store_group)
    _WORLD = world


1348
1349
1350
1351
1352
1353
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed (if needed) and the global world group.

    Args:
        world_size: Number of ranks for this engine (-1 by default). When a
            config is active, the executor backend is not "external_launcher",
            elastic EP is off, and either `nnodes > 1` or
            `data_parallel_size > 1`, it is replaced by `world_size_across_dp`.
        rank: Rank of this process. In that same case it is offset by
            `data_parallel_rank * world_size`.
        distributed_init_method: torch init method. In that same case it is
            rebuilt from the master address/port (multi-node) or the DP
            master IP and next DP init port.
        local_rank: Local rank. When -1, taken from `envs.LOCAL_RANK` for
            "env://" or set to `rank` otherwise.
        backend: torch.distributed backend. Falls back to "gloo" when the
            requested backend is unavailable; the fallback is only checked
            when torch.distributed is not yet initialized.
        timeout: Optional timeout for `init_process_group` (and for the
            elastic-EP gloo group).

    With elastic EP enabled, the world group is created by
    `_init_elastic_ep_world` and the function returns early.

    Otherwise:
      - `_WORLD` and `_NODE_COUNT` are set, or an existing `_WORLD` is
        checked against the current world size.
      - `_INNER_DP_WORLD` is set when the config reports
        `nnodes_within_dp > 1`.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    enable_elastic_ep = config is not None and config.parallel_config.enable_elastic_ep
    if (
        config is not None
        and config.parallel_config.distributed_executor_backend != "external_launcher"
        and (
            config.parallel_config.nnodes > 1
            or config.parallel_config.data_parallel_size > 1
        )
        and not enable_elastic_ep
    ):
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp

        # Use appropriate IP and port based on configuration
        if parallel_config.nnodes > 1:
            ip = parallel_config.master_addr
            port = parallel_config.master_port
            distributed_init_method = get_distributed_init_method(ip, port)
        else:
            ip = parallel_config.data_parallel_master_ip
            port = parallel_config.get_next_dp_init_port()
            distributed_init_method = get_distributed_init_method(ip, port)
            logger.debug(
                "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
                world_size,
                rank,
                distributed_init_method,
            )
    if not torch.distributed.is_initialized():
        logger.info(
            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
            world_size,
            rank,
            local_rank,
            distributed_init_method,
            backend,
        )
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )
        if enable_elastic_ep:
            tp_pp_cpu_group = torch.distributed.new_group(
                backend="gloo", timeout=timeout
            )
            if _node_count(tp_pp_cpu_group) > 1:
                # NOTE(yongji): StatelessGroupCoordinator uses data_parallel_master_ip
                # to initialize all DP/EP groups, hence all ranks within TP/PP group
                # must reside on the same node
                raise RuntimeError(
                    "Elastic EP is not yet supported with multi-node TP/PP"
                )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if enable_elastic_ep:
        _init_elastic_ep_world(config, local_rank, backend, rank, world_size)
        return
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        if config is not None and config.parallel_config.nnodes > 1:
            _NODE_COUNT = config.parallel_config.nnodes
        else:
            _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
    if config is not None and config.parallel_config.nnodes_within_dp > 1:
        # Bind the parallel config here explicitly. The earlier assignment
        # only happens on the DP/multi-node branch, which is skipped e.g. for
        # the "external_launcher" backend. Relying on it would raise
        # UnboundLocalError in that case.
        parallel_config = config.parallel_config
        if parallel_config.data_parallel_size > 1:
            world_size_inner_dp = parallel_config.world_size
            # One group per DP replica, each covering that replica's
            # contiguous block of global ranks.
            group_ranks = [
                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
                for dp_rank in range(parallel_config.data_parallel_size)
            ]
            _INNER_DP_WORLD = init_model_parallel_group(
                group_ranks,
                get_world_group().local_rank,
                backend,
                use_message_queue_broadcaster=True,
                group_name="inner_dp_world",
                use_device_communicator=False,
            )
        else:
            _INNER_DP_WORLD = _WORLD
1481
1482


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        prefill_context_model_parallel_size: number of GPUs used for prefill
            context parallelism.
        decode_context_model_parallel_size: size of each decode context
            parallel group. DCP reuses the GPUs of the TP group (it does not
            change the world size). ``None`` is treated as 1.
        backend: name of torch distributed communication backend.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    The DP size and the elastic-EP / EPLB / MoE switches are read from the
    current vLLM config. With elastic EP enabled, the TP/DCP/PCP/PP groups
    are built from per-replica local ranks, and the DP/EP/EPLB groups are
    created as stateless groups.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()

    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    parallel_config = config.parallel_config
    data_parallel_size = parallel_config.data_parallel_size
    enable_elastic_ep = parallel_config.enable_elastic_ep

    if decode_context_model_parallel_size is None:
        # The signature allows None; treat it as "no DCP" instead of letting
        # the reshape below fail with a TypeError.
        decode_context_model_parallel_size = 1

    coord_store: Store | None = None
    if enable_elastic_ep:
        coord_store = get_cached_tcp_store_client(
            parallel_config.data_parallel_master_ip,
            parallel_config._coord_store_port,
        )
        # Use stateless world group for global information
        world_size = get_world_group().world_size
        rank = get_world_group().rank
        backend = backend or "nccl"
        tp_pp_pcp_size = (
            tensor_model_parallel_size
            * pipeline_model_parallel_size
            * prefill_context_model_parallel_size
        )
        # Rank layout of a single DP replica (PP x PCP x TP). In elastic EP
        # mode the TP/DCP/PCP/PP groups below are built from these ranks.
        local_all_ranks = torch.arange(tp_pp_pcp_size).reshape(
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            tensor_model_parallel_size,
        )
    else:
        world_size = torch.distributed.get_world_size()
        rank = torch.distributed.get_rank()
        backend = backend or torch.distributed.get_backend(
            get_world_group().device_group
        )

    # the layout order is: ExternalDP x DP x PP x PCP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1,
        data_parallel_size,
        pipeline_model_parallel_size,
        prefill_context_model_parallel_size,
        tensor_model_parallel_size,
    )  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = local_all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    global _DCP
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = local_all_ranks.reshape(
            -1, decode_context_model_parallel_size
        ).unbind(0)
        group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the prefill context-parallel groups (PCP axis moved last).
    global _PCP
    assert _PCP is None, "prefill context parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(3, 4)
        .reshape(-1, prefill_context_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(1, 2)
            .reshape(-1, prefill_context_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
    _PCP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
    )

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        group_ranks = (
            local_all_ranks.transpose(0, 2)
            .reshape(-1, pipeline_model_parallel_size)
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    # Build the data-parallel groups (DP axis moved last).
    global _DP
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    if enable_elastic_ep:
        _DP = _init_stateless_group(
            group_ranks,
            "dp",
            parallel_config.data_parallel_master_ip,
            backend,
            coord_store=coord_store,
        )
    else:
        _DP = init_model_parallel_group(
            group_ranks, get_world_group().local_rank, backend, group_name="dp"
        )

    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    # Don't create EP group for dense models.
    if config.model_config is None or config.model_config.is_moe:
        # EP groups span DP x PCP x TP within one pipeline stage: swap the DP
        # and PP axes so PP becomes an outer dimension, then flatten the rest.
        group_ranks = (
            all_ranks.transpose(1, 2)
            .reshape(
                -1,
                data_parallel_size
                * prefill_context_model_parallel_size
                * tensor_model_parallel_size,
            )
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
        if enable_elastic_ep:
            _EP = _init_stateless_group(
                group_ranks,
                "ep",
                parallel_config.data_parallel_master_ip,
                backend,
                coord_store=coord_store,
            )
        else:
            _EP = init_model_parallel_group(
                group_ranks, get_world_group().local_rank, backend, group_name="ep"
            )

        # Create EPLB group with the same ranks as EP if EPLB is enabled.
        # This is a separate process group to isolate EPLB communications
        # from MoE forward pass collectives and prevent deadlocks when
        # using torch.distributed in execution with torch.distributed in EPLB.
        global _EPLB
        assert _EPLB is None, "EPLB group is already initialized"
        # `config` / `parallel_config` were already dereferenced above, so
        # they are known to be non-None here.
        if parallel_config.enable_eplb:
            if enable_elastic_ep:
                _EPLB = _init_stateless_group(
                    group_ranks,
                    "eplb",
                    parallel_config.data_parallel_master_ip,
                    backend,
                    coord_store=coord_store,
                )
            else:
                _EPLB = init_model_parallel_group(
                    group_ranks,
                    get_world_group().local_rank,
                    backend,
                    group_name="eplb",
                )
    # If no EP group needed, _EP remains None
    # If no EPLB group needed, _EPLB remains None

    logger.info_once(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, PCP rank %s, "
        "TP rank %s, EP rank %s, EPLB rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _PCP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group if _EP is not None else "N/A",
        _EPLB.rank_in_group if _EPLB is not None else "N/A",
    )


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Create the model parallel groups on first use; otherwise validate them.

    When the groups do not exist yet, they are built via
    ``initialize_model_parallel`` with the given sizes. When they already
    exist, the TP, PP and PCP world sizes are asserted to match the requested
    values. The DCP size is only forwarded on first initialization and is not
    re-checked.

    If ``backend`` is falsy, it defaults to the world group's ``backend``
    attribute when present, else to the torch backend of its device group.
    """
    world_group = get_world_group()
    if not backend:
        if hasattr(world_group, "backend"):
            backend = world_group.backend
        else:
            backend = torch.distributed.get_backend(world_group.device_group)

    if model_parallel_is_initialized():
        # Groups already exist: only verify that their sizes are as expected.
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        pcp_world_size = get_pcp_group().world_size
        assert pcp_world_size == prefill_context_model_parallel_size, (
            "prefill context parallel group already initialized, but of unexpected size: "
            f"{pcp_world_size=} vs. "
            f"{prefill_context_model_parallel_size=}"
        )
    else:
        initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            decode_context_model_parallel_size,
            backend,
        )


def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Let each existing parallel group allocate model-dependent buffers.

    Traditional communication libraries like NCCL are almost model agnostic,
    but emerging ones such as MoE all2all (DeepEP) usually size their
    communication buffer from the model shape for optimal performance.

    The TP, PCP, PP, DP, EP and EPLB groups are visited in that order. Groups
    that have not been created (still ``None``) are skipped. The DCP group is
    not visited here.
    """
    for group in (_TP, _PCP, _PP, _DP, _EP, _EPLB):
        if group is None:
            continue
        group.prepare_communication_buffer_for_model(model)


def model_parallel_is_initialized():
    """Return True once both the TP and the PP groups have been created."""
    return all(group is not None for group in (_TP, _PP))


# True while `patch_tensor_parallel_group` is active; prevents nested patching.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    The original TP group is restored and the patched flag cleared on exit,
    even if the body of the ``with`` block raises.
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the current group *before* setting the flag: if this lookup
    # raises, the module must not be left permanently marked as patched
    # (which would make every later call fail the assertion above).
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size() -> int:
    """Number of ranks in the tensor model parallel (TP) group."""
    tp_group = get_tp_group()
    return tp_group.world_size


def get_tensor_model_parallel_rank() -> int:
    """Index of this process within its tensor model parallel (TP) group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group


def get_decode_context_model_parallel_world_size() -> int:
    """Number of ranks in the decode context parallel (DCP) group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank() -> int:
    """Index of this process within its decode context parallel (DCP) group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


def get_node_count() -> int:
    """Number of nodes recorded when the distributed environment was set up.

    Raises AssertionError while the cached node count is still ``None``,
    i.e. before the distributed environment has been initialized.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


def destroy_model_parallel():
    """Destroy every model parallel group and reset its global to None.

    Groups are torn down one at a time in the order TP, DCP, PCP, PP, DP, EP,
    EPLB. Each global is cleared right after its own group is destroyed.
    """
    global _TP, _DCP, _PCP, _PP, _DP, _EP, _EPLB

    def _teardown(group):
        # Destroy the group if it exists; the None result clears the global.
        if group:
            group.destroy()
        return None

    _TP = _teardown(_TP)
    _DCP = _teardown(_DCP)
    _PCP = _teardown(_PCP)
    _PP = _teardown(_PP)
    _DP = _teardown(_DP)
    _EP = _teardown(_EP)
    _EPLB = _teardown(_EPLB)

def destroy_distributed_environment():
    """Tear down the world group and the default torch process group.

    The cached node count is reset as well. The torch process group is only
    destroyed if torch.distributed reports it as initialized.
    """
    global _WORLD, _NODE_COUNT
    world = _WORLD
    if world:
        world.destroy()
    _WORLD = None
    _NODE_COUNT = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM's distributed state and release cached memory.

    Steps, in order:
      1. disable the environment-variable cache and unfreeze the GC;
      2. destroy the model parallel groups, then the distributed environment;
      3. optionally shut down Ray;
      4. run a GC pass;
      5. on non-CPU platforms, empty the accelerator cache and torch's
         host-side cache (``torch._C._host_emptyCache``, when available).
    """
    envs.disable_envs_cache()  # reset the environment variable cache
    gc.unfreeze()  # objects must not stay frozen during cleanup

    destroy_model_parallel()
    destroy_distributed_environment()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()

    from vllm.platforms import current_platform

    if current_platform.is_cpu():
        return

    torch.accelerator.empty_cache()
    try:
        torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")


def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    Protocol:
    - The source rank creates a shared-memory segment containing a magic
      message and broadcasts the segment name.
    - Every other rank tries to open the segment and compare the message.
    - The per-rank 0/1 flags are then summed across the group.

    The source always broadcasts, sending ``None`` if it could not create the
    segment, so the other ranks never block waiting for a name. Errors are
    logged and ignored; a rank that fails reports False.

    Args:
        pg: a non-NCCL ``ProcessGroup`` or a ``StatelessProcessGroup``.
        source_rank: rank *within ``pg``* that creates the segment.

    Returns:
        A list of length ``world_size``; entry ``i`` is True when rank ``i``
        read the source's segment.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    def _broadcast_name(name: str | None) -> str | None:
        # Broadcast the segment name from `source_rank`. Non-source ranks
        # pass None and get back the source's value.
        if isinstance(pg, ProcessGroup):
            box = [name]
            torch.distributed.broadcast_object_list(
                box, src=ranks[source_rank], group=pg
            )
            return box[0]
        return pg.broadcast_obj(name, src=source_rank)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor(
        [0] * world_size, dtype=torch.int32, device="cpu"
    )

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                assert shm.buf is not None, "Buffer was not created"
                shm.buf[: len(magic_message)] = magic_message
                name = shm.name
            # Broadcast even on failure (name is None); otherwise every other
            # rank would block forever in the matching broadcast.
            _broadcast_name(name)
            if name is not None:
                is_in_the_same_node[rank] = 1
        else:
            name = _broadcast_name(None)
            if name is not None:
                # try to open the shared memory segment
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is not
                    # created by the process. The following patch is a workaround.
                    with patch(
                        "multiprocessing.resource_tracker.register",
                        lambda *args, **kwargs: None,
                    ):
                        shm = shared_memory.SharedMemory(name=name)
                    assert shm.buf is not None, "Buffer was not opened"
                    if shm.buf[: len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Make sure every rank has finished reading before the source unlinks.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        # Sum the flags by broadcasting each rank's tensor in turn.
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Lookup order:
      1. the world group's ``is_first_rank`` when the world group exists;
      2. True when torch.distributed is not initialized;
      3. otherwise whether torch's global rank is 0.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process),
              and also if any error occurs while checking.
    """
    try:
        world = _WORLD
        if world is not None:
            return world.is_first_rank
        if not torch.distributed.is_initialized():
            return True
        return torch.distributed.get_rank() == 0
    except Exception:
        # Best-effort check: any failure is treated as being the first rank.
        return True


def is_local_first_rank() -> bool:
    """
    Check if the current process is the first local rank (rank 0 on its node).

    Lookup order:
      1. the world group's ``local_rank`` when the world group exists;
      2. True when torch.distributed is not initialized;
      3. ``envs.LOCAL_RANK``;
      4. torch's global rank, if reading ``LOCAL_RANK`` fails.

    Any unexpected error makes the function return True.
    """
    try:
        world = _WORLD
        if world is not None:
            return world.local_rank == 0

        if not torch.distributed.is_initialized():
            return True

        # note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun)
        try:
            local_rank = int(envs.LOCAL_RANK)  # type: ignore[arg-type]
        except Exception:
            return torch.distributed.get_rank() == 0
        return local_rank == 0
    except Exception:
        return True


def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are grouped greedily:
    - The lowest rank not yet assigned to a node starts a new node.
    - ``in_the_same_node_as`` then marks every rank sharing memory with it.

    ``in_the_same_node_as`` is a collective whose result is aggregated
    across the group. Every rank therefore sees the same assignments and
    issues the same sequence of calls, so all ranks of ``pg`` must call this.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # rank -> node id; 0 means "not assigned yet", real ids start at 1
    node_of = [0] * world_size
    num_nodes = 0

    for leader in range(world_size):
        if node_of[leader]:
            continue
        num_nodes += 1
        node_of[leader] = num_nodes
        for peer, same_node in enumerate(in_the_same_node_as(pg, leader)):
            if same_node and not node_of[peer]:
                node_of[peer] = num_nodes

    return num_nodes