# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- run any code that uses the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import Any, Optional
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup
44
from typing_extensions import deprecated
Zhuohan Li's avatar
Zhuohan Li committed
45

46
import vllm.envs as envs
47
from vllm.distributed.device_communicators.base_device_communicator import (
48
49
    DeviceCommunicatorBase,
)
50
from vllm.distributed.utils import StatelessProcessGroup
51
from vllm.logger import init_logger
52
53
54
from vllm.utils import (
    get_distributed_init_method,
)
55
from vllm.utils.import_utils import resolve_obj_by_qualname
56
57
58
59
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    supports_custom_op,
)
60
61


62
63
64
@dataclass
class GraphCaptureContext:
    """State shared by a graph-capture session.

    Holds the stream on which `GroupCoordinator.graph_capture` runs the
    captured work.
    """

    stream: torch.cuda.Stream
65

66

67
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
68

69
70

def _split_tensor_dict(
71
    tensor_dict: dict[str, torch.Tensor | Any],
72
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
73
74
75
76
77
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
78
79
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
80
81
82
83
84
85
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
86
            device = value.device.type
87
            metadata_list.append(
88
89
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
90
91
            tensor_list.append(value)
        else:
92
            metadata_list.append((key, value))
93
94
95
    return metadata_list, tensor_list


96
_group_name_counter: dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return ``name`` suffixed with a per-name counter.

    Each call with the same ``name`` yields the next index. Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.get(name, 0)
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


112
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Record ``group`` in ``_groups`` under its ``unique_name``.

    Only a weak reference is stored, so the registry does not keep the
    coordinator alive. Once the coordinator is garbage-collected, the stored
    reference returns ``None``.
    """
    ref = weakref.ref(group)
    _groups[group.unique_name] = ref
117
118


119
120
121
122
123
124
def _lookup_group(group_name: str) -> "GroupCoordinator":
    """Resolve the live GroupCoordinator registered as ``group_name``.

    Custom ops cannot receive the coordinator object itself, so they are
    given its unique name and resolve it here.

    Raises:
        ValueError: if no group was registered under ``group_name``, or the
            registered group has already been garbage-collected.
    """
    # Explicit checks instead of `assert`, so the error survives `python -O`.
    group_ref = _groups.get(group_name)
    if group_ref is None:
        raise ValueError(f"Group {group_name} is not found.")
    group = group_ref()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group


def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce on the registered group ``group_name``."""
    return _lookup_group(group_name)._all_reduce_out_place(tensor)


def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake impl of ``all_reduce``: an uninitialized tensor like the input."""
    return torch.empty_like(tensor)


def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place reduce-scatter along ``dim`` on group ``group_name``.

    ``world_size`` is unused here; it exists so the fake impl can compute
    the output shape.
    """
    return _lookup_group(group_name)._reduce_scatter_out_place(tensor, dim)


def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake impl of ``reduce_scatter``: ``dim`` shrunk by ``world_size``."""
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] // world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place all-gather along ``dim`` on group ``group_name``.

    ``world_size`` is unused here; it exists so the fake impl can compute
    the output shape.
    """
    return _lookup_group(group_name)._all_gather_out_place(tensor, dim)


def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake impl of ``all_gather``: ``dim`` grown by ``world_size``."""
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] * world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Forward every argument, positionally and in order, to
    ``torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter``.

    Registered as a vLLM custom op so that its fake impl (defined above)
    can be substituted during tracing.
    """
    op_args = (
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(*op_args)


251
if supports_custom_op():
    # Register each out-of-place collective together with its fake impl,
    # which only builds an output tensor of the right shape.
    # TODO: Remove the patched_fused_scaled_matmul_reduce_scatter entry once
    # the pytorch fix (https://github.com/pytorch/pytorch/pull/165086) gets
    # released, in either 2.9.1 or 2.10
    for _op_name, _op_func, _fake_impl in (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
        (
            "patched_fused_scaled_matmul_reduce_scatter",
            patched_fused_scaled_matmul_reduce_scatter,
            patched_fused_scaled_matmul_reduce_scatter_fake,
        ),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_fake_impl,
        )
    # Keep the loop variables out of the module namespace.
    del _op_name, _op_func, _fake_impl

279

280
281
282
283
284
285
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
292
    ranks: list[int]  # global ranks in the group
293
294
295
296
297
298
299
300
301
302
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
303
304
305
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
306
307
    device_communicator: DeviceCommunicatorBase | None
    mq_broadcaster: Any | None  # shared memory broadcaster
308
309
310

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create process groups for ``group_ranks`` and keep the one that
        contains the calling process.

        Args:
            group_ranks: lists of global ranks, one list per group.
                - For every list, a device group (``torch_distributed_backend``)
                  and a CPU group (``gloo``) are created, in order.
                - The pair whose list contains this process's global rank is
                  kept.
                - ``torch.distributed.new_group`` is called for every list on
                  every rank, so all processes should pass the same
                  ``group_ranks``.
            local_rank: local rank, used to select this process's device.
            torch_distributed_backend: backend for the device groups.
            use_device_communicator: whether to create the platform device
                communicator. It is only created when the group has more
                than one rank.
            use_message_queue_broadcaster: whether to create a shared-memory
                ``MessageQueue`` broadcaster. It is only created when the
                group has more than one rank.
            group_name: base name used to derive ``unique_name``. Defaults to
                ``"anonymous"``.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self_device_group = device_group
                self_cpu_group = cpu_group

        assert self_cpu_group is not None, (
            f"Rank {self.rank} is not a member of any group in {group_ranks}"
        )
        assert self_device_group is not None, (
            f"Rank {self.rank} is not a member of any group in {group_ranks}"
        )

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # `current_platform` is already in scope from the import above.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
389

390
391
392
393
394
395
396
397
398
399
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

400
401
402
403
404
405
406
407
408
409
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Run the body on a capture stream and yield the capture context.

        If no context is given, a new ``torch.cuda.Stream`` is created and
        wrapped in a fresh ``GraphCaptureContext``. The capture stream first
        waits on the current stream. When the device communicator has a
        ``ca_comm``, its ``capture()`` context is entered as well.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        comm = self.device_communicator
        if comm is not None:
            assert isinstance(comm, CudaCommunicator)
            ca_comm = comm.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        launching_stream = torch.cuda.current_stream()
        if launching_stream != stream:
            stream.wait_stream(launching_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context
453
454
455

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
456
457
458
459
460
461
462
463
464
465
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
466
467
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
468
469
470
471
472
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

473
        if self.use_custom_op_call:
474
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
475
476
        else:
            return self._all_reduce_out_place(input_)
477

478
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
479
480
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
481
        return self.device_communicator.all_reduce(input_)
482
483
484
485
486
487
488

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
489
490
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
491

492
        if self.use_custom_op_call:
493
494
495
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
496
497
498
        else:
            return self._all_gather_out_place(input_, dim)

499
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
500
501
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
502
        return self.device_communicator.all_gather(input_, dim)
503

504
505
    def all_gatherv(
        self,
506
        input_: torch.Tensor | list[torch.Tensor],
507
        dim: int = 0,
508
        sizes: list[int] | None = None,
509
    ):
510
511
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
512
513
        return self.device_communicator.all_gatherv(input_, dim, sizes)

514
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
515
516
517
518
519
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
520
521
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
522

523
        if self.use_custom_op_call:
524
525
526
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
527
528
529
        else:
            return self._reduce_scatter_out_place(input_, dim)

530
    def reduce_scatterv(
531
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
532
    ) -> torch.Tensor:
533
534
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
535
536
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

537
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
538
539
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
540
541
        return self.device_communicator.reduce_scatter(input_, dim)

542
543
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
544
    ) -> torch.Tensor | None:
545
546
547
548
549
550
551
552
553
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
554
555
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
556
        return self.device_communicator.gather(input_, dst, dim)
557
558
559
560
561
562
563
564
565
566
567

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
568
569
570
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
571
572
        return input_

573
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
574
575
576
577
578
579
580
581
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
582
583
584
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
585
        if self.rank_in_group == src:
586
587
588
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
589
590
591
            return obj
        else:
            recv = [None]
592
593
594
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
595
596
            return recv[0]

597
    def broadcast_object_list(
598
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
599
    ):
600
601
602
603
604
605
606
607
608
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
609
610
611
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
612
613
        return obj_list

614
615
616
617
618
619
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

620
        assert dst != self.rank_in_group, (
621
            "Invalid destination rank. Destination rank is the same "
622
623
            "as the current rank."
        )
624
625
626
627

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

628
629
630
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
631
632
633

        # Send object size

634
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
635
636

        # Send object
637
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
638
639
640
641
642
643
644
645
646

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

647
        assert src != self.rank_in_group, (
648
649
650
651
652
653
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
654
655
656
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
657
658
659
660
661

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
662
663
            device="cpu",
        )
664

665
666
667
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
668
669

        assert rank_object == rank_size, (
670
671
            "Received object sender rank does not match the size sender rank."
        )
672
673
674
675
676

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

677
678
    def broadcast_tensor_dict(
        self,
679
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
680
        src: int = 0,
681
682
683
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
684
685
686
687
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
688
        if not torch.distributed.is_initialized() or self.world_size == 1:
689
690
691
692
693
694
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

695
696
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
697
            metadata_list: list[tuple[Any, Any]] = []
698
699
700
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
701
702
703
704
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
705
            self.broadcast_object(metadata_list, src=src)
706
707
708
709
710
711
712
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
713
714
715
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
716
717
                else:
                    # use group for GPU tensors
718
719
720
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
721
722
723
724
725
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
726
            metadata_list = self.broadcast_object(None, src=src)
727
728
            tensor_dict = {}
            async_handles = []
729
            for key, value in metadata_list:
730
                if isinstance(value, TensorMetadata):
731
732
733
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
734
735
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
736
                        tensor_dict[key] = tensor
737
738
739
740
741
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
742
                            src=self.ranks[src],
743
                            group=metadata_group,
744
745
                            async_op=True,
                        )
746
747
                    else:
                        # use group for GPU tensors
748
                        handle = torch.distributed.broadcast(
749
750
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
751
                    async_handles.append(handle)
752
                    tensor_dict[key] = tensor
753
                else:
754
                    tensor_dict[key] = value
755
756
757
758
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

759
760
    def send_tensor_dict(
        self,
761
762
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
763
        all_gather_group: Optional["GroupCoordinator"] = None,
764
765
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
766
767
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
783
784
785
786
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
787
788
789
790
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
791

792
793
794
795
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
796
            dst = (self.rank_in_group + 1) % self.world_size
797
798
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

799
        if self.use_cpu_custom_send_recv:
800
801
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
802
            self.device_communicator.send_tensor_dict(  # type: ignore
803
804
                tensor_dict, dst
            )
805
806
            return None

807
        metadata_list: list[tuple[Any, Any]] = []
808
809
810
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
811
812
813
814
815
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
816

817
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
818
819
820
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
821
822
823
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
824
825

            # send-allgather: send only a slice, then do allgather.
826
827
828
829
830
831
832
833
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            use_all_gather = (
                all_gather_tensors.get(key, use_all_gather)
                if all_gather_tensors
                else use_all_gather
            )
834
            if use_all_gather:
835
836
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

837
838
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
839
840
841
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
842
843
            else:
                # use group for GPU tensors
844
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
845
846
847
848
        return None

    def recv_tensor_dict(
        self,
849
        src: int | None = None,
850
        all_gather_group: Optional["GroupCoordinator"] = None,
851
852
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
853
854
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
870
871
872
873
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
874
875
876
877
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
878

879
880
881
882
        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
883
            src = (self.rank_in_group - 1) % self.world_size
884
885
        assert src < self.world_size, f"Invalid src rank ({src})"

886
        if self.use_cpu_custom_send_recv:
887
888
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
889
            return self.device_communicator.recv_tensor_dict(  # type: ignore
890
891
                src
            )
892

893
        recv_metadata_list = self.recv_object(src=src)
894
        tensor_dict: dict[str, Any] = {}
895
896
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
897
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
898
899
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
900
                    tensor_dict[key] = tensor
901
                    continue
902
903

                # send-allgather: send only a slice, then do allgather.
904
905
906
907
908
909
910
911
912
                use_all_gather = (
                    all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0
                )
                use_all_gather = (
                    all_gather_tensors.get(key, use_all_gather)
                    if all_gather_tensors
                    else use_all_gather
                )
913
914
915

                if use_all_gather:
                    orig_shape = tensor.shape
916
                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
917

918
919
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
920
921
922
                    torch.distributed.recv(
                        tensor, src=self.ranks[src], group=metadata_group
                    )
923
924
                else:
                    # use group for GPU tensors
925
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
926
927
928
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
929
930
                        tensor, dim=0
                    )
931
932
                    tensor = tensor.reshape(orig_shape)

933
                tensor_dict[key] = tensor
934
            else:
935
                tensor_dict[key] = value
936
937
        return tensor_dict

938
939
940
941
942
943
944
945
946
    def barrier(self):
        """Block until every rank of this group reaches the barrier.

        The CPU group is used deliberately. A barrier on the NCCL
        `device_group` is internally a broadcast over secretly created GPU
        tensors, which can easily disturb the current device.
        """
        cpu_pg = self.cpu_group
        torch.distributed.barrier(group=cpu_pg)

947
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
948
        """Sends a tensor to the destination rank in a blocking way"""
949
        """NOTE: `dst` is the local rank of the destination rank."""
950
951
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
952
        self.device_communicator.send(tensor, dst)
953

954
    def recv(
955
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
956
    ) -> torch.Tensor:
957
958
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
959
960
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
961
        return self.device_communicator.recv(size, dtype, src)
962

963
    def destroy(self):
        """Release the process groups and communicators held by this group.

        `device_group` and `cpu_group` are each destroyed and then removed
        from the instance, but only if present. The device communicator, if
        any, is destroyed, and the message-queue broadcaster reference is
        dropped.
        """
        for pg_attr in ("device_group", "cpu_group"):
            if not hasattr(self, pg_attr):
                continue
            torch.distributed.destroy_process_group(getattr(self, pg_attr))
            delattr(self, pg_attr)

        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()

        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
974

975
976
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let the device communicator set up buffers for `model`.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
978
979

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Route `hidden_states`/`router_logits` through the device communicator.

        Without a device communicator the inputs are returned unchanged as a
        `(hidden_states, router_logits)` tuple.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states, router_logits
        return communicator.dispatch(
            hidden_states, router_logits, is_sequence_parallel
        )
991

992
993
994
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Delegate `combine` to the device communicator, if there is one.

        Without a device communicator `hidden_states` is returned as-is.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
999

1000

1001
1002
# Coordinator spanning every process. Set by `init_distributed_environment`
# and cleared by `destroy_distributed_environment`.
_WORLD: GroupCoordinator | None = None
# Node count detected when the world group was created.
_NODE_COUNT: int | None = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting it exists."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


1010
1011
1012
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Build the coordinator named "world" containing `ranks` as one group.

    The world group is created without a device communicator.
    """
    return GroupCoordinator(
        group_name="world",
        group_ranks=[ranks],
        use_device_communicator=False,
        local_rank=local_rank,
        torch_distributed_backend=backend,
    )


1022
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
) -> GroupCoordinator:
    """Create a coordinator for one family of model-parallel groups.

    Args:
        group_ranks: One list of ranks per group.
        local_rank: Local rank of this process.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: Forwarded to `GroupCoordinator`.
        group_name: Optional label such as "tp", "pp", "dp" or "ep".

    Unlike the world group, these groups always enable the device
    communicator.
    """
    return GroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        use_device_communicator=True,
    )


1039
# Tensor-parallel coordinator. Set by `initialize_model_parallel`; may be
# swapped temporarily by `patch_tensor_parallel_group`.
_TP: GroupCoordinator | None = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group, asserting it exists."""
    tp = _TP
    assert tp is not None, "tensor model parallel group is not initialized"
    return tp


1047
1048
1049
1050
1051
@deprecated(
    "`get_tensor_model_parallel_group` has been replaced with "
    "`get_tp_group` and may be removed after v0.12. Please use "
    "`get_tp_group` instead."
)
def get_tensor_model_parallel_group():
    """Deprecated alias; forwards to `get_tp_group`."""
    tp_group = get_tp_group()
    return tp_group

1055

1056
# Decode-context-parallel coordinator. Set by `initialize_model_parallel`.
_DCP: GroupCoordinator | None = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel group, asserting it exists."""
    dcp = _DCP
    assert dcp is not None, "decode context model parallel group is not initialized"
    return dcp


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

1067
# Pipeline-parallel coordinator. Set by `initialize_model_parallel`.
_PP: GroupCoordinator | None = None

# Data-parallel coordinator. Set by `initialize_model_parallel`.
_DP: GroupCoordinator | None = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group, asserting it exists."""
    dp = _DP
    assert dp is not None, "data parallel group is not initialized"
    return dp

1076

1077
# Expert-parallel coordinator. Set by `initialize_model_parallel`.
_EP: GroupCoordinator | None = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group, asserting it exists."""
    ep = _EP
    assert ep is not None, "expert parallel group is not initialized"
    return ep


1085
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group, asserting it exists."""
    pp = _PP
    assert pp is not None, "pipeline model parallel group is not initialized"
    return pp
1088
1089


1090
1091
1092
1093
1094
@deprecated(
    "`get_pipeline_model_parallel_group` has been replaced with "
    "`get_pp_group` and may be removed in v0.12. Please use "
    "`get_pp_group` instead."
)
def get_pipeline_model_parallel_group():
    """Deprecated alias; forwards to `get_pp_group`."""
    pp_group = get_pp_group()
    return pp_group
1097
1098


1099
@contextmanager
def graph_capture(device: torch.device):
    """Wrap code that captures a CUDA graph.

    A new `torch.cuda.Stream` on `device` is packed into a
    `GraphCaptureContext`. The context is handed first to the TP group's and
    then to the PP group's `graph_capture` context managers, and is yielded
    to the caller.

    The intent is to let some operations run after the graph is captured
    but before it is replayed. It also keeps capture on a stream separate
    from the default stream, so captured kernels are distinguishable from
    kernels possibly launched in the background on the default stream.
    """
    capture_context = GraphCaptureContext(torch.cuda.Stream(device=device))
    # Nested form of `with A, B:` -- the TP context is entered before the PP
    # group is looked up, exactly as in the single-statement form.
    with get_tp_group().graph_capture(capture_context):
        with get_pp_group().graph_capture(capture_context):
            yield capture_context

1118

1119
logger = init_logger(__name__)

# Module-level custom all-reduce switch; updated via `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag to `enable`."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1127

Zhuohan Li's avatar
Zhuohan Li committed
1128

1129
1130
1131
1132
1133
1134
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed and the vLLM world group.

    Steps:
      1. If the current vLLM config has data_parallel_size > 1 and the
         executor backend is not "external_launcher", rewrite `rank`,
         `world_size` and `distributed_init_method` so that all DP replicas
         join one process group.
      2. If torch.distributed is not initialized yet, create the default
         process group. Fall back to gloo when `backend` is unavailable.
      3. On first call, create `_WORLD` and record `_NODE_COUNT`. Later
         calls only assert that the world size is unchanged.

    Args:
        world_size: Number of processes (before DP adjustment).
        rank: Global rank of this process (before DP adjustment).
        distributed_init_method: Init URL passed to `init_process_group`.
        local_rank: Local rank. -1 means derive it: from `envs.LOCAL_RANK`
            when the init method is "env://", otherwise from `rank`.
        backend: Requested torch.distributed backend.
        timeout: Optional timeout forwarded to `init_process_group`.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    needs_dp_adjustment = (
        vllm_config is not None
        and vllm_config.parallel_config.data_parallel_size > 1
        and vllm_config.parallel_config.distributed_executor_backend
        != "external_launcher"
    )
    if needs_dp_adjustment:
        pc = vllm_config.parallel_config
        # Offset the rank by the DP rank (using the pre-adjustment world
        # size), then widen the world size to cover every DP replica.
        rank = pc.data_parallel_rank * world_size + rank
        world_size = pc.world_size_across_dp
        distributed_init_method = get_distributed_init_method(
            pc.data_parallel_master_ip, pc.get_next_dp_init_port()
        )
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size,
            rank,
            distributed_init_method,
        )

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # Unset local rank usually means a single-node run, where the
        # global rank can stand in for the local rank.
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank

    global _WORLD, _NODE_COUNT
    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
        return

    all_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_ranks, local_rank, backend)
    _NODE_COUNT = _node_count(_WORLD.cpu_group)
    logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
1207
1208


Zhuohan Li's avatar
Zhuohan Li committed
1209
1210
1211
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of GPUs in each decode
            context parallel (DCP) group. DCP reuses the GPUs of the TP
            groups, so this must not exceed `tensor_model_parallel_size`.
            `None` is treated as 1.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)
    # The signature allows None; without this `reshape(-1, None)` below
    # would raise a TypeError.
    if decode_context_model_parallel_size is None:
        decode_context_model_parallel_size = 1

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
    )  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    global _DCP
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    global _DP
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp"
    )

    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(1, 2)
        .reshape(-1, data_parallel_size * tensor_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep"
    )

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group,
    )
1337

Zhuohan Li's avatar
Zhuohan Li committed
1338

1339
1340
1341
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Create the model parallel groups, or validate existing ones.

    If the TP/PP groups do not exist yet, delegate to
    `initialize_model_parallel` with the given sizes. Otherwise only assert
    that the existing tensor-parallel and pipeline-parallel world sizes
    match the requested ones. The decode-context-parallel size is not
    checked in that case.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        return

    initialize_model_parallel(
        tensor_model_parallel_size,
        pipeline_model_parallel_size,
        decode_context_model_parallel_size,
        backend,
    )
1370
1371


1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Classic libraries such as NCCL are largely model agnostic, but newer
    ones such as the MoE all2all library DeepEP typically size their
    communication buffers from the model shape. Each initialized group
    (TP, PP, DP, EP, in that order) is given the chance to do so.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1389
def model_parallel_is_initialized():
    """Return True once both the tensor and pipeline parallel groups exist."""
    return not (_TP is None or _PP is None)
1392
1393


1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
# True while `patch_tensor_parallel_group` has swapped `_TP`.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or if the tensor
            parallel group has not been initialized (from `get_tp_group`).
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group *before* setting the flag. If get_tp_group()
    # raises, the flag must stay False; otherwise every later call would
    # wrongly fail the "already patched" assertion above.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1422
1423
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1425
1426
1427
1428


def get_tensor_model_parallel_rank():
    """Return this process's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1430
1431


1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
def get_decode_context_model_parallel_world_size():
    """Return the number of ranks in the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank():
    """Return this process's rank inside the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1442
def get_node_count() -> int:
    """Return the node count recorded by `init_distributed_environment`.

    Asserts that the distributed environment is initialized, i.e. that the
    count has been recorded and not yet cleared.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1448
def destroy_model_parallel():
    """Set the groups to none and destroy them.

    Groups are torn down in the order TP, PP, DCP, DP, EP. Each global is
    reset only after its own `destroy()` returns. If one raises, that group
    and the ones after it are left untouched.
    """
    global _TP, _PP, _DCP, _DP, _EP

    def _teardown(group):
        # Destroy `group` if it is set; the return value resets the global.
        if group:
            group.destroy()
        return None

    _TP = _teardown(_TP)
    _PP = _teardown(_PP)
    _DCP = _teardown(_DCP)
    _DP = _teardown(_DP)
    _EP = _teardown(_EP)

1476
1477

def destroy_distributed_environment():
    """Destroy the world group and, if still active, torch.distributed.

    Resets `_WORLD` and `_NODE_COUNT` to None.
    """
    global _WORLD, _NODE_COUNT
    world_group = _WORLD
    if world_group:
        world_group.destroy()
    _WORLD, _NODE_COUNT = None, None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1485
1486


1487
1488
1489
1490
1491
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all parallel state and release cached memory.

    Destroys the model parallel groups and the distributed environment,
    optionally shuts Ray down, runs the garbage collector, and then empties
    the platform cache (when the platform provides one) and, on non-CPU
    platforms, torch's host cache.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()
    from vllm.platforms import current_platform

    platform_empty_cache = current_platform.empty_cache
    if platform_empty_cache is not None:
        platform_empty_cache()
    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")


def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    Protocol:
    1. The source rank creates a 128-byte shared-memory segment, writes a
       marker into it and broadcasts the segment name.
    2. Every other rank attaches to the segment by name and checks the marker.
    3. After a barrier, the source rank unlinks the segment.
    4. The per-rank flags are combined so that every rank returns the same
       list.

    Args:
        pg: The group to test. A `torch.distributed.ProcessGroup` must not use
            the NCCL backend (asserted). Any other object is treated as a
            `StatelessProcessGroup` and used via `rank`, `world_size`,
            `broadcast_obj` and `barrier`.
        source_rank: Rank *within the group* (0 .. world_size - 1) that
            creates the shared-memory segment.

    Returns:
        One bool per group rank. Entry `i` is True if rank `i` found the
        source rank's marker.

    Failures do not raise. OSErrors while creating or opening the segment are
    suppressed, and other exceptions are logged. In both cases the affected
    rank's entry stays False.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        # NOTE(review): if the source rank hits an OSError before it
        # broadcasts (e.g. while creating the segment), the broadcast is
        # skipped. The other ranks would then be left waiting in theirs.
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    # broadcast_object_list expects the *global* source rank.
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg
                    )
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg
                    )
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                # Merely opening a segment with this name is not taken as proof.
                # The marker contents must match as well.
                if shm.buf[: len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Wait for every rank to finish with the segment before it is unlinked.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    # Each rank set only its own slot, so summing across ranks yields a 0/1
    # flag per rank, and every rank ends up with the same combined tensor.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Lookup order:
    1. The world group's `is_first_rank`, if `_WORLD` is set.
    2. True if torch.distributed is not initialized.
    3. `torch.distributed.get_rank() == 0` otherwise.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Also returns True if distributed is not initialized (single
              process) or if any error occurs during the check.
    """
    try:
        # If world group is available, use it for the most accurate check.
        # (_WORLD is only read here, so no `global` declaration is needed.)
        if _WORLD is not None:
            return _WORLD.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0

    except Exception:
        # Best effort: if anything goes wrong, assume this is the first rank
        return True


def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are grouped greedily. The lowest rank not yet assigned starts a new
    node, and `in_the_same_node_as` (a collective) marks every other rank that
    shares memory with it. That helper returns the same list on every rank, so
    all ranks walk this loop identically and issue the same sequence of
    collectives.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        world_size = torch.distributed.get_world_size(group=pg)
    else:
        world_size = pg.world_size

    if world_size == 1:
        return 1

    # node_assignment[r] is the 1-based node id of rank r; 0 means unassigned.
    node_assignment = [0] * world_size
    next_node_id = 0

    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # Already assigned to a node

        # Assign current rank to a new node
        next_node_id += 1
        node_assignment[current_rank] = next_node_id

        # Find all ranks on the same node as current_rank
        same_node_flags = in_the_same_node_as(pg, current_rank)
        for other_rank, is_same_node in enumerate(same_node_flags):
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    return next_node_id