# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
31
32
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
33
from datetime import timedelta
34
from multiprocessing import shared_memory
35
from typing import Any, Callable, Optional, Union
36
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
37
38

import torch
39
import torch.distributed
40
41
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
42
from torch.distributed import Backend, ProcessGroup
43
from typing_extensions import deprecated
Zhuohan Li's avatar
Zhuohan Li committed
44

45
import vllm.envs as envs
46
from vllm.distributed.device_communicators.base_device_communicator import (
47
48
    DeviceCommunicatorBase,
)
49
from vllm.distributed.utils import StatelessProcessGroup
50
from vllm.logger import init_logger
51
52
53
54
55
56
from vllm.utils import (
    direct_register_custom_op,
    get_distributed_init_method,
    resolve_obj_by_qualname,
    supports_custom_op,
)
57
58


59
60
61
@dataclass
class GraphCaptureContext:
    """Context object yielded by ``GroupCoordinator.graph_capture``.

    Attributes:
        stream: The CUDA stream that is made current while a graph is
            being captured.
    """

    stream: torch.cuda.Stream
62

63

64
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
65

66
67

def _split_tensor_dict(
68
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
69
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
70
71
72
73
74
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
75
76
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
77
78
79
80
81
82
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
83
            device = value.device.type
84
            metadata_list.append(
85
86
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
87
88
            tensor_list.append(value)
        else:
89
            metadata_list.append((key, value))
90
91
92
    return metadata_list, tensor_list


93
_group_name_counter: dict[str, int] = {}
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


109
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
110
111
112


def _register_group(group: "GroupCoordinator") -> None:
113
    _groups[group.unique_name] = weakref.ref(group)
114
115


116
117
118
119
120
121
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_reduce_out_place(tensor)
122
123


124
125
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    return torch.empty_like(tensor)
126

127

128
129
130
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
131
132
133
134
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
135
    return group._reduce_scatter_out_place(tensor, dim)
136
137


138
139
140
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
141
142
143
144
145
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] // world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


146
147
148
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
149
150
151
152
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
153
    return group._all_gather_out_place(tensor, dim)
154
155


156
157
158
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
159
160
161
162
163
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] * world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Eager implementation: forward every argument, positionally and
    unchanged, to ``torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter``.

    Wrapping the PyTorch op lets it be registered with the locally patched
    fake implementation, ``patched_fused_scaled_matmul_reduce_scatter_fake``.
    """
    fused_op = torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter
    return fused_op(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


248
if supports_custom_op():
    # (op name, eager implementation, fake/shape-only implementation)
    _custom_op_table = (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
        # TODO: Remove this once the pytorch fix
        # (https://github.com/pytorch/pytorch/pull/165086) gets released,
        # in either 2.9.1 or 2.10
        (
            "patched_fused_scaled_matmul_reduce_scatter",
            patched_fused_scaled_matmul_reduce_scatter,
            patched_fused_scaled_matmul_reduce_scatter_fake,
        ),
    )
    for _op_name, _op_func, _fake_impl in _custom_op_table:
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_fake_impl,
        )
    # Keep the module namespace identical to registering the ops one by one.
    del _custom_op_table, _op_name, _op_func, _fake_impl

276

277
278
279
280
281
282
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
289
    ranks: list[int]  # global ranks in the group
290
291
292
293
294
295
296
297
298
299
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
300
301
302
303
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
    device_communicator: Optional[DeviceCommunicatorBase]
304
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
305
306
307

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups for ``group_ranks`` and keep the one that
        contains this process.

        Args:
            group_ranks: Lists of global ranks, one list per group. Every
                process iterates over all of them (see the loop comment);
                this process keeps the group containing its own global rank.
            local_rank: Index used to build this process's device
                (e.g. ``cuda:{local_rank}``).
            torch_distributed_backend: Backend for the device group. The CPU
                group always uses ``gloo``.
            use_device_communicator: Build the platform's device communicator
                when the group has more than one member.
            use_message_queue_broadcaster: Build a shared-memory
                ``MessageQueue`` when the group has more than one member.
            group_name: Prefix for ``unique_name``; defaults to
                ``"anonymous"``.

        Raises:
            AssertionError: if this process's rank is in none of
                ``group_ranks``.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None

        for ranks in group_ranks:
            # `new_group` must be entered by every process of the job, even for
            # groups it is not a member of, so all groups are created here in
            # the same order on every rank.
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self_device_group = device_group
                self_cpu_group = cpu_group

        assert self_cpu_group is not None, (
            f"Rank {self.rank} is not in any of the group ranks {group_ranks}"
        )
        assert self_device_group is not None, (
            f"Rank {self.rank} is not in any of the group ranks {group_ranks}"
        )

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        # Stays None for single-member groups even if requested.
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): the positional args (1 << 22 == 4 MiB, 6) are
            # opaque from here; confirm their meaning in shm_broadcast.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # Route collectives through the `torch.ops.vllm.*` custom ops on
        # CUDA-alike and TPU platforms; call the communicator directly elsewhere.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
386

387
388
389
390
391
392
393
394
395
396
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

397
398
399
400
401
402
403
404
405
406
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
        self, graph_capture_context: Optional[GraphCaptureContext] = None
    ):
        """Context manager that makes a capture stream current.

        If no context is given, a fresh ``torch.cuda.Stream`` is created and
        wrapped in a new ``GraphCaptureContext``. Before entering, the stream
        waits on the current stream (when they differ). If the CUDA device
        communicator has a custom all-reduce communicator, its ``capture()``
        context is entered as well. Yields the ``GraphCaptureContext``.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        # only cuda uses this function,
        # so we don't abstract it into the base class
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        ca_capture = nullcontext()
        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            custom_ar = communicator.ca_comm
            if custom_ar is not None:
                ca_capture = custom_ar.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), ca_capture:
            yield graph_capture_context
452
453
454

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
455
456
457
458
459
460
461
462
463
464
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
465
466
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
467
468
469
470
471
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

472
        if self.use_custom_op_call:
473
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
474
475
        else:
            return self._all_reduce_out_place(input_)
476

477
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
478
479
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
480
        return self.device_communicator.all_reduce(input_)
481
482
483
484
485
486
487

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
488
489
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
490

491
        if self.use_custom_op_call:
492
493
494
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
495
496
497
        else:
            return self._all_gather_out_place(input_, dim)

498
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
499
500
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
501
        return self.device_communicator.all_gather(input_, dim)
502

503
504
505
506
507
508
    def all_gatherv(
        self,
        input_: Union[torch.Tensor, list[torch.Tensor]],
        dim: int = 0,
        sizes: Optional[list[int]] = None,
    ):
509
510
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
511
512
        return self.device_communicator.all_gatherv(input_, dim, sizes)

513
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
514
515
516
517
518
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
519
520
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
521

522
        if self.use_custom_op_call:
523
524
525
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
526
527
528
        else:
            return self._reduce_scatter_out_place(input_, dim)

529
530
531
    def reduce_scatterv(
        self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
    ) -> torch.Tensor:
532
533
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
534
535
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

536
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
537
538
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
539
540
        return self.device_communicator.reduce_scatter(input_, dim)

541
542
543
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> Optional[torch.Tensor]:
544
545
546
547
548
549
550
551
552
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
553
554
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
555
        return self.device_communicator.gather(input_, dst, dim)
556
557
558
559
560
561
562
563
564
565
566

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
567
568
569
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
570
571
        return input_

572
573
574
575
576
577
578
579
580
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
581
582
583
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
584
        if self.rank_in_group == src:
585
586
587
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
588
589
590
            return obj
        else:
            recv = [None]
591
592
593
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
594
595
            return recv[0]

596
597
598
    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None
    ):
599
600
601
602
603
604
605
606
607
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
608
609
610
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
611
612
        return obj_list

613
614
615
616
617
618
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

619
        assert dst != self.rank_in_group, (
620
            "Invalid destination rank. Destination rank is the same "
621
622
            "as the current rank."
        )
623
624
625
626

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

627
628
629
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
630
631
632

        # Send object size

633
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
634
635

        # Send object
636
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
637
638
639
640
641
642
643
644
645

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

646
        assert src != self.rank_in_group, (
647
648
649
650
651
652
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
653
654
655
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
656
657
658
659
660

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
661
662
            device="cpu",
        )
663

664
665
666
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
667
668

        assert rank_object == rank_size, (
669
670
            "Received object sender rank does not match the size sender rank."
        )
671
672
673
674
675

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

676
677
    def broadcast_tensor_dict(
        self,
678
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
679
680
        src: int = 0,
        group: Optional[ProcessGroup] = None,
681
        metadata_group: Optional[ProcessGroup] = None,
682
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
683
684
685
686
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
687
        if not torch.distributed.is_initialized() or self.world_size == 1:
688
689
690
691
692
693
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

694
695
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
696
            metadata_list: list[tuple[Any, Any]] = []
697
698
699
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
700
701
702
703
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
704
            self.broadcast_object(metadata_list, src=src)
705
706
707
708
709
710
711
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
712
713
714
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
715
716
                else:
                    # use group for GPU tensors
717
718
719
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
720
721
722
723
724
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
725
            metadata_list = self.broadcast_object(None, src=src)
726
727
            tensor_dict = {}
            async_handles = []
728
            for key, value in metadata_list:
729
                if isinstance(value, TensorMetadata):
730
731
732
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
733
734
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
735
                        tensor_dict[key] = tensor
736
737
738
739
740
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
741
                            src=self.ranks[src],
742
                            group=metadata_group,
743
744
                            async_op=True,
                        )
745
746
                    else:
                        # use group for GPU tensors
747
                        handle = torch.distributed.broadcast(
748
749
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
750
                    async_handles.append(handle)
751
                    tensor_dict[key] = tensor
752
                else:
753
                    tensor_dict[key] = value
754
755
756
757
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

758
759
    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: Optional[dict[str, bool]] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.

        NOTE: `dst` is the local rank of the destination rank. If it is None,
        the next rank in this group (wrapping around) is used.

        The metadata list produced by `_split_tensor_dict` is sent first via
        `send_object`; the tensors are then sent one by one, CPU tensors over
        the CPU group and all others over the device group. Empty tensors are
        not transmitted.

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis. The receiver should be given the same settings so both
            sides agree on which tensors are sliced.

        Returns:
            `tensor_dict` itself when torch.distributed is not initialized or
            this group has a single rank (nothing is sent); None otherwise.

        Raises:
            ValueError: if custom CPU send/recv is enabled but this group has
                no device communicator.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        if self.use_cpu_custom_send_recv:
            # The whole dictionary is delegated to the device communicator.
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            self.device_communicator.send_tensor_dict(  # type: ignore
                tensor_dict, dst
            )
            return None

        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)

        # Pair each tensor with its key; relies on `_split_tensor_dict`
        # emitting tensors in dict iteration order (checked by the assert).
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            if all_gather_tensors:
                use_all_gather = all_gather_tensors.get(key, use_all_gather)
            # An override cannot enable the optimization without a group to
            # all-gather over; keep this consistent with `recv_tensor_dict`.
            use_all_gather = use_all_gather and all_gather_group is not None

            if use_all_gather:
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)

        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: Optional[dict[str, bool]] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.

        NOTE: `src` is the local rank of the source rank. If it is None, the
        previous rank in this group (wrapping around) is used.

        The metadata list is received first via `recv_object`; for every
        `TensorMetadata` entry an empty tensor is allocated and filled by a
        point-to-point receive (CPU tensors over the CPU group, others over
        the device group). Non-tensor entries are copied as-is.

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis. It should match the sender's settings.

        Returns:
            None when torch.distributed is not initialized or this group has a
            single rank; otherwise the received dictionary.

        Raises:
            ValueError: if custom CPU send/recv is enabled but this group has
                no device communicator.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.use_cpu_custom_send_recv:
            # The whole dictionary is delegated to the device communicator.
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            return self.device_communicator.recv_tensor_dict(  # type: ignore
                src
            )

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if not isinstance(value, TensorMetadata):
                tensor_dict[key] = value
                continue

            tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
            if tensor.numel() == 0:
                # Skip receiving empty tensors (the sender skips them too).
                tensor_dict[key] = tensor
                continue

            # send-allgather: receive only a slice, then do allgather.
            use_all_gather = (
                all_gather_group is not None
                and tensor.numel() % all_gather_size == 0
            )
            if all_gather_tensors:
                use_all_gather = all_gather_tensors.get(key, use_all_gather)
            # An override cannot enable the optimization without a group to
            # all-gather over; previously this crashed on `None.all_gather`.
            use_all_gather = use_all_gather and all_gather_group is not None

            orig_shape = tensor.shape
            if use_all_gather:
                # Receive into this rank's slice of the full buffer.
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.recv(
                    tensor, src=self.ranks[src], group=metadata_group
                )
            else:
                # use group for GPU tensors
                torch.distributed.recv(tensor, src=self.ranks[src], group=group)

            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0
                )
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor
        return tensor_dict

937
938
939
940
941
942
943
944
945
    def barrier(self):
        """Block until every rank in this group has reached the barrier.

        The CPU group is used on purpose. Do not switch this to
        `device_group`: an NCCL barrier is internally a broadcast on secretly
        created GPU tensors, which can easily mess up the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

946
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a blocking way.

        NOTE: `dst` is the local rank of the destination rank. It is forwarded
        unchanged (including None) to the device communicator.

        Raises:
            ValueError: if this group has no device communicator.
        """
        # The NOTE used to live in a second string literal, which Python
        # evaluates as a no-op expression and omits from `__doc__`.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send(tensor, dst)
952

953
954
955
    def recv(
        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
    ) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank. It is forwarded
        unchanged (including None) to the device communicator, together with
        the expected `size` and `dtype`.

        Returns:
            Whatever the device communicator's `recv` returns.

        Raises:
            ValueError: if this group has no device communicator.
        """
        # The NOTE used to live in a second string literal, which Python
        # evaluates as a no-op expression and omits from `__doc__`.
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv(size, dtype, src)
961

962
    def destroy(self):
        """Tear down the process groups and helpers owned by this coordinator.

        Each of `device_group` and `cpu_group` that is still present is passed
        to `torch.distributed.destroy_process_group` and then deleted from the
        instance. Afterwards the device communicator (if any) is destroyed and
        the message-queue broadcaster reference (if any) is dropped.
        """
        for attr_name in ("device_group", "cpu_group"):
            if not hasattr(self, attr_name):
                continue
            torch.distributed.destroy_process_group(getattr(self, attr_name))
            delattr(self, attr_name)

        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()

        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
973

974
975
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward `model` to the device communicator so it can set up buffers.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
977
978

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch `hidden_states`/`router_logits` through the device communicator.

        Without a device communicator, both inputs are returned unchanged as a
        `(hidden_states, router_logits)` tuple.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states, router_logits
        return communicator.dispatch(
            hidden_states, router_logits, is_sequence_parallel
        )
990

991
992
993
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine `hidden_states` through the device communicator.

        Without a device communicator, `hidden_states` is returned unchanged.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
998

999
1000

# Coordinator for the group spanning all processes; set by
# `init_distributed_environment` and reset by `destroy_distributed_environment`.
_WORLD: Optional[GroupCoordinator] = None
# Node count detected when the world group was created (see `get_node_count`).
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting that it exists."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


1009
1010
1011
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Create the coordinator for the "world" group.

    `ranks` becomes the single group of this coordinator, and no device
    communicator is attached to it.
    """
    return GroupCoordinator(
        group_name="world",
        use_device_communicator=False,
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
    )


1021
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a coordinator for one kind of model-parallel group.

    Args:
        group_ranks: every group of this kind, each given as a list of global
            ranks.
        local_rank: this process's local rank.
        backend: torch distributed backend name.
        use_message_queue_broadcaster: forwarded to `GroupCoordinator`.
        group_name: optional name for the group (e.g. "tp", "pp").

    Unlike the world group, these coordinators are created with a device
    communicator.
    """
    return GroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
    )


1038
1039
1040
1041
# Tensor-parallel group coordinator; populated by `initialize_model_parallel`.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group coordinator, asserting it exists."""
    tp_group = _TP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group


@deprecated(
    "`get_tensor_model_parallel_group` has been replaced with "
    "`get_tp_group` and may be removed after v0.12. Please use "
    "`get_tp_group` instead."
)
def get_tensor_model_parallel_group():
    """Deprecated alias of `get_tp_group`."""
    return get_tp_group()

1054

1055
1056
1057
1058
# Decode-context-parallel group coordinator; populated by
# `initialize_model_parallel`.
_DCP: Optional[GroupCoordinator] = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel group coordinator, asserting it exists."""
    dcp_group = _DCP
    assert dcp_group is not None, (
        "decode context model parallel group is not initialized"
    )
    return dcp_group


# Old name retained so existing callers keep working.
get_context_model_parallel_group = get_dcp_group

1066
1067
# Pipeline-parallel group coordinator; populated by `initialize_model_parallel`.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group coordinator, asserting it exists."""
    pp_group = _PP
    assert pp_group is not None, "pipeline model parallel group is not initialized"
    return pp_group


@deprecated(
    "`get_pipeline_model_parallel_group` has been replaced with "
    "`get_pp_group` and may be removed in v0.12. Please use "
    "`get_pp_group` instead."
)
def get_pipeline_model_parallel_group():
    """Deprecated alias of `get_pp_group`."""
    return get_pp_group()


# Data-parallel group coordinator; populated by `initialize_model_parallel`.
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group coordinator, asserting it exists."""
    dp_group = _DP
    assert dp_group is not None, "data parallel group is not initialized"
    return dp_group


# Expert-parallel group coordinator; populated by `initialize_model_parallel`.
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group coordinator, asserting it exists."""
    ep_group = _EP
    assert ep_group is not None, "expert parallel group is not initialized"
    return ep_group
1096
1097


1098
@contextmanager
def graph_capture(device: torch.device):
    """Context manager that must surround code capturing a CUDA graph.

    It builds a `GraphCaptureContext` around a fresh CUDA stream on `device`
    and yields it while the TP group's and then the PP group's
    `graph_capture(context)` contexts are active. The stream is set as the
    current CUDA stream on entry and reset to the default stream on exit.
    Running the capture on a separate stream explicitly distinguishes the
    kernels being captured from kernels that may be launched in the
    background on the default stream. The wrapping also lets some operations
    run after the graph is captured and before it is replayed.
    """
    capture_context = GraphCaptureContext(torch.cuda.Stream(device=device))
    with get_tp_group().graph_capture(capture_context):
        with get_pp_group().graph_capture(capture_context):
            yield capture_context

1117

1118
# Module-level logger, used below for distributed-setup diagnostics.
logger = init_logger(__name__)
1119

1120
# Module-wide switch for custom all-reduce; changed via `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-wide custom all-reduce flag."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1126

Zhuohan Li's avatar
Zhuohan Li committed
1127

1128
1129
1130
1131
1132
1133
1134
1135
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: Optional[timedelta] = None,
):
    """Initialize torch.distributed and the module-level world group.

    When the current vLLM config uses data parallelism (size > 1) and the
    executor backend is not "external_launcher", `rank`, `world_size` and
    `distributed_init_method` are rewritten so that the world spans every DP
    replica. The torch default process group is created only if it does not
    exist yet, falling back to gloo when `backend` is unavailable. `_WORLD`
    and `_NODE_COUNT` are set on the first call; if `_WORLD` already exists it
    is only checked against torch's world size.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config

    vllm_config = get_current_vllm_config()
    if (
        vllm_config is not None
        and vllm_config.parallel_config.data_parallel_size > 1
        and vllm_config.parallel_config.distributed_executor_backend
        != "external_launcher"
    ):
        dp_config = vllm_config.parallel_config
        # Offset the rank by the DP rank using the per-replica world size,
        # and only then widen the world size to cover all DP replicas.
        rank += dp_config.data_parallel_rank * world_size
        world_size = dp_config.world_size_across_dp
        distributed_init_method = get_distributed_init_method(
            dp_config.data_parallel_master_ip, dp_config.get_next_dp_init_port()
        )
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size,
            rank,
            distributed_init_method,
        )

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )

    # torch ProcessGroup does not expose the local rank
    # (see https://github.com/pytorch/pytorch/issues/122816). When it was not
    # given, use LOCAL_RANK for env:// and otherwise the rank itself (the
    # usual single-node setting).
    if local_rank == -1:
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank

    global _WORLD, _NODE_COUNT
    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
        return

    ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(ranks, local_rank, backend)
    _NODE_COUNT = _node_count(_WORLD.cpu_group)
    logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
1206
1207


Zhuohan Li's avatar
Zhuohan Li committed
1208
1209
1210
def _ranks_to_group_lists(ranks: torch.Tensor, group_size: int) -> list[list[int]]:
    """Chunk `ranks` (in memory order) into consecutive lists of `group_size`."""
    return [group.tolist() for group in ranks.reshape(-1, group_size).unbind(0)]


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of GPUs in each decode
            context parallel (DCP) group. DCP reuses the GPUs of the TP group
            and splits each TP group into consecutive sub-groups of this size,
            so it must not exceed `tensor_model_parallel_size`. None is
            treated as 1.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if decode_context_model_parallel_size is None:
        # The signature allows None; treat it as "no DCP" instead of passing
        # None to reshape(), which would fail with an opaque error.
        decode_context_model_parallel_size = 1

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
    )  # noqa
    local_rank = get_world_group().local_rank

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    # The TP group (and the DCP group below) use the message queue broadcaster.
    _TP = init_model_parallel_group(
        _ranks_to_group_lists(all_ranks, tensor_model_parallel_size),
        local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    global _DCP
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    _DCP = init_model_parallel_group(
        _ranks_to_group_lists(all_ranks, decode_context_model_parallel_size),
        local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    _PP = init_model_parallel_group(
        _ranks_to_group_lists(all_ranks.transpose(2, 3), pipeline_model_parallel_size),
        local_rank,
        backend,
        group_name="pp",
    )

    # Build the data-parallel groups.
    global _DP
    assert _DP is None, "data parallel group is already initialized"
    _DP = init_model_parallel_group(
        _ranks_to_group_lists(all_ranks.transpose(1, 3), data_parallel_size),
        local_rank,
        backend,
        group_name="dp",
    )

    # Build the expert-parallel groups: DP x TP ranks within each PP stage.
    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    _EP = init_model_parallel_group(
        _ranks_to_group_lists(
            all_ranks.transpose(1, 2),
            data_parallel_size * tensor_model_parallel_size,
        ),
        local_rank,
        backend,
        group_name="ep",
    )

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group,
    )
1336

Zhuohan Li's avatar
Zhuohan Li committed
1337

1338
1339
1340
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """Create the model parallel groups, or validate the existing ones.

    If the groups are not initialized yet, `initialize_model_parallel` is
    called with the given sizes. Otherwise the current TP and PP world sizes
    are asserted to equal the requested ones; the DCP size is not checked.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        return

    initialize_model_parallel(
        tensor_model_parallel_size,
        pipeline_model_parallel_size,
        decode_context_model_parallel_size,
        backend,
    )
1369
1370


1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Let every initialized group prepare model-dependent communication buffers.

    Traditional communication libraries like NCCL are almost model agnostic,
    but emerging ones like MoE all2all (DeepEP) usually allocate their buffers
    based on the model shape for optimal performance. The TP, PP, DP and EP
    groups (in that order) are visited; uninitialized ones are skipped.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is None:
            continue
        group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1388
def model_parallel_is_initialized():
    """Report whether both the TP and the PP groups exist."""
    return all(group is not None for group in (_TP, _PP))
1391
1392


1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
# True while `patch_tensor_parallel_group` has replaced the global TP group.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: "GroupCoordinator"):
    """Patch the tp group temporarily until this context exits.

    This is for draft workers of speculative decoding to run the draft model
    with a different tp degree from that of the target model workers. Patches
    cannot be nested.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator to install.

    Raises:
        AssertionError: if a patch is already active, or if the current TP
            group cannot be fetched via `get_tp_group`.
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group *before* marking the state as patched: if
    # get_tp_group() fails, the flag must not stay set, otherwise every later
    # call would be rejected as "already patched".
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1421
1422
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1429
1430


1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
def get_decode_context_model_parallel_world_size():
    """Return the number of ranks in the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank():
    """Return this process's rank within the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1441
def get_node_count() -> int:
    """Return the node count recorded when the world group was created."""
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1447
def destroy_model_parallel():
    """Set the groups to none and destroy them.

    The TP, PP, DCP, DP and EP groups are handled in that order: each truthy
    group is destroyed and its module-level slot is reset to None before the
    next one is visited.
    """
    module_state = globals()
    for slot in ("_TP", "_PP", "_DCP", "_DP", "_EP"):
        coordinator = module_state[slot]
        if coordinator:
            coordinator.destroy()
        module_state[slot] = None

1475
1476

def destroy_distributed_environment():
    """Destroy the world group, then torch's default process group if present."""
    global _WORLD, _NODE_COUNT
    world = _WORLD
    if world:
        world.destroy()
    _WORLD, _NODE_COUNT = None, None

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1484
1485


1486
1487
1488
1489
1490
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all distributed state, then release cached memory.

    Destroys the model-parallel groups and the distributed environment,
    optionally shuts down Ray, and runs the garbage collector. It then calls
    the platform's `empty_cache` hook (when defined). On non-CPU platforms it
    also calls `torch._C._host_emptyCache()`, logging a warning if that
    attribute is missing.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()

    from vllm.platforms import current_platform

    platform_empty_cache = current_platform.empty_cache
    if platform_empty_cache is not None:
        platform_empty_cache()

    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
1504
1505


def in_the_same_node_as(
    pg: Union[ProcessGroup, StatelessProcessGroup], source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    Every rank of ``pg`` must call this with the same ``source_rank``.

    1. The source rank creates a small shared-memory segment, writes a magic
       message into it and broadcasts the segment name.
    2. Every other rank tries to open that segment by name and checks for
       the message.
    3. Per-rank results are then combined across the group.

    Args:
        pg: A non-NCCL ``ProcessGroup`` or a ``StatelessProcessGroup``.
        source_rank: Rank, local to ``pg``, that creates the segment.

    Returns:
        A list with one entry per group rank. Entry ``i`` is True when group
        rank ``i`` saw the segment written by ``source_rank``.
    """
    is_torch_pg = isinstance(pg, ProcessGroup)
    if is_torch_pg:
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)
        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    def _broadcast_from_source(obj):
        # Broadcast `obj` from source_rank; returns the value seen on this rank.
        if is_torch_pg:
            payload = [obj]
            torch.distributed.broadcast_object_list(
                payload, src=ranks[source_rank], group=pg
            )
            return payload[0]
        return pg.broadcast_obj(obj, src=source_rank)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            shm_name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                shm_name = shm.name
            # Always broadcast, even when creation failed (name is None).
            # The other ranks are blocked in the matching broadcast and
            # would otherwise hang forever.
            _broadcast_from_source(shm_name)
            if shm_name is not None:
                is_in_the_same_node[rank] = 1
        else:
            # try to open the shared memory segment
            name = _broadcast_from_source(None)
            if name is not None:
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is not
                    # created by the process. The following patch is a workaround.
                    with patch(
                        "multiprocessing.resource_tracker.register",
                        lambda *args, **kwargs: None,
                    ):
                        shm = shared_memory.SharedMemory(name=name)
                    if shm.buf[: len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # Make sure every rank is done reading before the source unlinks.
    if is_torch_pg:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    # Each rank only set its own slot, so summing across ranks merges them.
    if is_torch_pg:
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Lookup order: the world group's ``is_first_rank`` when ``_WORLD`` is set,
    otherwise ``torch.distributed.get_rank() == 0`` when torch distributed is
    initialized. This is deliberately best-effort: any exception is treated
    as "first rank".

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process).
    """
    try:
        # If world group is available, use it for the most accurate check
        if _WORLD is not None:
            return _WORLD.is_first_rank

        # If torch distributed is not initialized, assume single process
        if not torch.distributed.is_initialized():
            return True

        # Fallback to torch's global rank
        return torch.distributed.get_rank() == 0

    except Exception:
        # If anything goes wrong, assume this is the first rank
        return True


def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Returns the total number of nodes in the process group.

    This is a collective: it calls ``in_the_same_node_as`` once for each
    rank not yet assigned to a node, so every rank of ``pg`` must call it
    together. ``in_the_same_node_as`` returns the same aggregated result on
    all ranks, so every rank takes the same path through the loop and makes
    the same sequence of collective calls.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        world_size = torch.distributed.get_world_size(group=pg)
    else:
        world_size = pg.world_size

    if world_size == 1:
        return 1

    # Build node assignment map; 0 means "not assigned yet", ids start at 1.
    node_assignment = [0] * world_size  # rank -> node_id
    next_node_id = 0

    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # Already assigned to a node

        # Assign current rank to a new node
        next_node_id += 1
        node_assignment[current_rank] = next_node_id

        # Find all ranks on the same node as current_rank
        same_node_flags = in_the_same_node_as(pg, current_rank)
        for other_rank, is_same_node in enumerate(same_node_flags):
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    return next_node_id