parallel_state.py 62.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
9
10
11
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.

- run any code that uses the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import Any, Optional
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup
44
from typing_extensions import deprecated
Zhuohan Li's avatar
Zhuohan Li committed
45

46
import vllm.envs as envs
47
from vllm.distributed.device_communicators.base_device_communicator import (
48
49
    DeviceCommunicatorBase,
)
50
from vllm.distributed.utils import StatelessProcessGroup
51
from vllm.logger import init_logger
52
from vllm.utils.import_utils import resolve_obj_by_qualname
53
from vllm.utils.network_utils import get_distributed_init_method
54
55
56
57
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    supports_custom_op,
)
58
59


60
61
62
@dataclass
class GraphCaptureContext:
    """Context object yielded by `GroupCoordinator.graph_capture`.

    It carries the CUDA stream that `graph_capture` makes current for the
    duration of the `with` block.
    """

    # Stream that `graph_capture` switches to while the context is active.
    stream: torch.cuda.Stream
63

64

65
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
66

67
68

def _split_tensor_dict(
69
    tensor_dict: dict[str, torch.Tensor | Any],
70
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
71
72
73
74
75
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
76
77
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
78
79
80
81
82
83
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
84
            device = value.device.type
85
            metadata_list.append(
86
87
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
88
89
            tensor_list.append(value)
        else:
90
            metadata_list.append((key, value))
91
92
93
    return metadata_list, tensor_list


94
# Number of names already handed out for each base name.
_group_name_counter: dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return `name` suffixed with a per-name running index.

    Every call for the same `name` yields the next index, starting at 0:

    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.get(name, 0)
    # Remember that this index has been used so the next call gets index + 1.
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


110
# Registry of group coordinators keyed by their unique name. Values are weak
# references, so the registry does not keep a coordinator alive; calling the
# stored reference returns the coordinator, or None once it has been collected.
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Record `group` in the module-level registry under `group.unique_name`.

    Only a weak reference is stored. An existing entry with the same name is
    overwritten.
    """
    _groups[group.unique_name] = weakref.ref(group)
115
116


117
118
119
120
121
122
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Run an out-of-place all-reduce on the group registered as `group_name`.

    The group is looked up by name in the module-level registry and the work
    is delegated to its `_all_reduce_out_place` method.

    Raises:
        AssertionError: if no group was ever registered under `group_name`.
        ValueError: if the group was registered but no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; dereference to get the coordinator.
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")
123
124


125
126
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation paired with `all_reduce` at op registration.

    No communication happens: the result is an uninitialized tensor created
    with `torch.empty_like(tensor)`. `group_name` is accepted only to mirror
    the real op's signature.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
127

128

129
130
131
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Run an out-of-place reduce-scatter on the group named `group_name`.

    The group is looked up by name in the module-level registry and the work
    is delegated to its `_reduce_scatter_out_place` method. `world_size` is
    not used here; the fake implementation with the same signature uses it to
    compute the output shape.

    Raises:
        AssertionError: if no group was ever registered under `group_name`.
        ValueError: if the group was registered but no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; dereference to get the coordinator.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._reduce_scatter_out_place(tensor, dim)
137
138


139
140
141
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation paired with `reduce_scatter` at op registration.

    No communication happens. The result is an uninitialized tensor with the
    dtype and device of `tensor`, whose size along `dim` is
    `tensor.shape[dim] // world_size`; all other dimensions are unchanged.
    `group_name` is accepted only to mirror the real op's signature.
    """
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] // world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


147
148
149
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Run an out-of-place all-gather on the group named `group_name`.

    The group is looked up by name in the module-level registry and the work
    is delegated to its `_all_gather_out_place` method. `world_size` is not
    used here; the fake implementation with the same signature uses it to
    compute the output shape.

    Raises:
        AssertionError: if no group was ever registered under `group_name`.
        ValueError: if the group was registered but no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; dereference to get the coordinator.
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_gather_out_place(tensor, dim)
155
156


157
158
159
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation paired with `all_gather` at op registration.

    No communication happens. The result is an uninitialized tensor with the
    dtype and device of `tensor`, whose size along `dim` is
    `tensor.shape[dim] * world_size`; all other dimensions are unchanged.
    `group_name` is accepted only to mirror the real op's signature.
    """
    new_shape = list(tensor.shape)
    new_shape[dim] = tensor.shape[dim] * world_size
    return torch.empty(new_shape, dtype=tensor.dtype, device=tensor.device)


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Forward to `torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter`.

    Every argument is passed through positionally and unchanged, in the order
    of this signature; the op's result is returned as-is. This wrapper exists
    so the op can be registered under a vLLM name together with a separate
    fake implementation.
    """
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


249
# Register the collective wrappers as custom ops, each paired with a fake
# implementation that only produces an output of the right shape/dtype.
# Registration is skipped entirely when custom ops are not supported.
if supports_custom_op():
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        fake_impl=all_reduce_fake,
    )

    direct_register_custom_op(
        op_name="reduce_scatter",
        op_func=reduce_scatter,
        fake_impl=reduce_scatter_fake,
    )

    direct_register_custom_op(
        op_name="all_gather",
        op_func=all_gather,
        fake_impl=all_gather_fake,
    )

    # TODO: Remove this once the pytorch fix
    # (https://github.com/pytorch/pytorch/pull/165086) gets released,
    # in either 2.9.1 or 2.10
    direct_register_custom_op(
        op_name="patched_fused_scaled_matmul_reduce_scatter",
        op_func=patched_fused_scaled_matmul_reduce_scatter,
        fake_impl=patched_fused_scaled_matmul_reduce_scatter_fake,
    )

277

278
279
280
281
282
283
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
    device_communicator: DeviceCommunicatorBase | None
    mq_broadcaster: Any | None  # shared memory broadcaster
306
307
308

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create the process groups and communicators for this coordinator.

        Args:
            group_ranks: One list of global ranks per group. A device group
                and a gloo CPU group are created for every entry; the entry
                containing the current global rank becomes this coordinator's
                group.
            local_rank: Device index used to build `self.device`.
            torch_distributed_backend: Backend for the device process groups.
            use_device_communicator: Whether to build the platform's device
                communicator (only done when the group has more than one rank).
            use_message_queue_broadcaster: Whether to build a message-queue
                broadcaster on the CPU group (only when more than one rank).
            group_name: Base name used to derive `self.unique_name`;
                defaults to "anonymous".

        Raises:
            AssertionError: if the current global rank is in none of
                `group_ranks`.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        self_device_group = None
        self_cpu_group = None

        # Groups are created for every entry of `group_ranks`, including those
        # that do not contain this rank; only the one containing it is kept.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self_device_group = device_group
                self_cpu_group = cpu_group

        assert self_cpu_group is not None, (
            f"Rank {self.rank} is not part of any group in group_ranks."
        )
        assert self_device_group is not None, (
            f"Rank {self.rank} is not part of any group in group_ranks."
        )

        self.cpu_group = self_cpu_group
        self.device_group = self_device_group

        # Imported locally; one import serves the whole method.
        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            # The platform names its communicator class by qualified name.
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # Collectives go through the registered torch custom ops only on
        # CUDA-like and TPU platforms; elsewhere the methods are called directly.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
387

388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
    def create_mq_broadcaster(
        self, writer_rank=0, external_writer_handle=None, blocking=True
    ):
        """Build a new `MessageQueue` on this group's CPU process group.

        The queue is created with the same two positional sizing arguments
        (`1 << 22` and `6`) that `__init__` uses for `mq_broadcaster`;
        `writer_rank`, `external_writer_handle` and `blocking` are forwarded
        unchanged. The result is returned and not stored on `self`.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        queue = MessageQueue.create_from_process_group(
            self.cpu_group,
            1 << 22,
            6,
            writer_rank=writer_rank,
            external_writer_handle=external_writer_handle,
            blocking=blocking,
        )
        return queue

    def create_single_reader_mq_broadcasters(
        self, reader_rank_in_group=0, blocking=False
    ):
        """Build single-reader message queue(s) on this group's CPU group.

        `reader_rank_in_group` is an index into `self.ranks`; it is converted
        to the reader's global rank before being handed to
        `MessageQueue.create_from_process_group_single_reader`, whose result
        is returned unchanged.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        # Translate the in-group index into the global rank of the reader.
        reader_global_rank = self.ranks[reader_rank_in_group]
        return MessageQueue.create_from_process_group_single_reader(
            self.cpu_group,
            1 << 22,
            6,
            reader_rank=reader_global_rank,
            blocking=blocking,
        )

415
416
417
418
419
420
421
422
423
424
    @property
    def first_rank(self) -> int:
        """Return the global rank of the first process in the group,
        i.e. the first entry of `self.ranks`."""
        return self.ranks[0]

    @property
    def last_rank(self) -> int:
        """Return the global rank of the last process in the group,
        i.e. the last entry of `self.ranks`."""
        return self.ranks[-1]

425
426
427
428
429
430
431
432
433
434
    @property
    def is_first_rank(self) -> bool:
        """Return whether the caller is the first process in the group,
        i.e. whether its global rank equals `first_rank`."""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self) -> bool:
        """Return whether the caller is the last process in the group,
        i.e. whether its global rank equals `last_rank`."""
        return self.rank == self.last_rank

435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Context manager that switches to a dedicated CUDA stream.

        Args:
            graph_capture_context: Existing context whose stream should be
                reused. When None, a new `torch.cuda.Stream` is created and
                wrapped in a fresh `GraphCaptureContext`.

        Yields:
            The `GraphCaptureContext` in use. While the `with` body runs, its
            stream is the current CUDA stream and, if the device communicator
            exposes a non-None `ca_comm`, that object's `capture()` context is
            active as well.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context
478
479
480

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
481
482
483
484
485
486
487
488
489
490
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
491
492
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
493
494
495
496
497
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

498
        if self.use_custom_op_call:
499
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
500
501
        else:
            return self._all_reduce_out_place(input_)
502

503
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
504
505
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
506
        return self.device_communicator.all_reduce(input_)
507
508
509
510
511
512
513

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
514
515
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
516

517
        if self.use_custom_op_call:
518
519
520
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
521
522
523
        else:
            return self._all_gather_out_place(input_, dim)

524
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
525
526
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
527
        return self.device_communicator.all_gather(input_, dim)
528

529
530
    def all_gatherv(
        self,
531
        input_: torch.Tensor | list[torch.Tensor],
532
        dim: int = 0,
533
        sizes: list[int] | None = None,
534
    ):
535
536
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
537
538
        return self.device_communicator.all_gatherv(input_, dim, sizes)

539
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
540
541
542
543
544
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
545
546
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
547

548
        if self.use_custom_op_call:
549
550
551
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
552
553
554
        else:
            return self._reduce_scatter_out_place(input_, dim)

555
    def reduce_scatterv(
556
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
557
    ) -> torch.Tensor:
558
559
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
560
561
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

562
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
563
564
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
565
566
        return self.device_communicator.reduce_scatter(input_, dim)

567
568
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
569
    ) -> torch.Tensor | None:
570
571
572
573
574
575
576
577
578
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
579
580
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
581
        return self.device_communicator.gather(input_, dst, dim)
582
583
584
585
586
587
588
589
590
591
592

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
593
594
595
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
596
597
        return input_

598
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
599
600
601
602
603
604
605
606
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
607
608
609
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
610
        if self.rank_in_group == src:
611
612
613
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
614
615
616
            return obj
        else:
            recv = [None]
617
618
619
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
620
621
            return recv[0]

622
    def broadcast_object_list(
623
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
624
    ):
625
626
627
628
629
630
631
632
633
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
634
635
636
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
637
638
        return obj_list

639
640
641
642
643
644
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

645
        assert dst != self.rank_in_group, (
646
            "Invalid destination rank. Destination rank is the same "
647
648
            "as the current rank."
        )
649
650
651
652

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

653
654
655
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
656
657
658

        # Send object size

659
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
660
661

        # Send object
662
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
663
664
665
666
667
668
669
670
671

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

672
        assert src != self.rank_in_group, (
673
674
675
676
677
678
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
679
680
681
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
682
683
684
685
686

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
687
688
            device="cpu",
        )
689

690
691
692
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
693
694

        assert rank_object == rank_size, (
695
696
            "Received object sender rank does not match the size sender rank."
        )
697
698
699
700
701

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

702
703
    def broadcast_tensor_dict(
        self,
704
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
705
        src: int = 0,
706
707
708
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
709
710
711
712
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
713
        if not torch.distributed.is_initialized() or self.world_size == 1:
714
715
716
717
718
719
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

720
721
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
722
            metadata_list: list[tuple[Any, Any]] = []
723
724
725
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
726
727
728
729
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
730
            self.broadcast_object(metadata_list, src=src)
731
732
733
734
735
736
737
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
738
739
740
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
741
742
                else:
                    # use group for GPU tensors
743
744
745
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
746
747
748
749
750
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
751
            metadata_list = self.broadcast_object(None, src=src)
752
753
            tensor_dict = {}
            async_handles = []
754
            for key, value in metadata_list:
755
                if isinstance(value, TensorMetadata):
756
757
758
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
759
760
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
761
                        tensor_dict[key] = tensor
762
763
764
765
766
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
767
                            src=self.ranks[src],
768
                            group=metadata_group,
769
770
                            async_op=True,
                        )
771
772
                    else:
                        # use group for GPU tensors
773
                        handle = torch.distributed.broadcast(
774
775
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
776
                    async_handles.append(handle)
777
                    tensor_dict[key] = tensor
778
                else:
779
                    tensor_dict[key] = value
780
781
782
783
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

784
785
    def send_tensor_dict(
        self,
786
787
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
788
        all_gather_group: Optional["GroupCoordinator"] = None,
789
790
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
791
792
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
808
809
810
811
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
812
813
814
815
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
816

817
818
819
820
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
821
            dst = (self.rank_in_group + 1) % self.world_size
822
823
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

824
        if self.use_cpu_custom_send_recv:
825
826
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
827
            self.device_communicator.send_tensor_dict(  # type: ignore
828
829
                tensor_dict, dst
            )
830
831
            return None

832
        metadata_list: list[tuple[Any, Any]] = []
833
834
835
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
836
837
838
839
840
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
841

842
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
843
844
845
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
846
847
848
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
849
850

            # send-allgather: send only a slice, then do allgather.
851
852
853
854
855
856
857
858
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            use_all_gather = (
                all_gather_tensors.get(key, use_all_gather)
                if all_gather_tensors
                else use_all_gather
            )
859
            if use_all_gather:
860
861
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

862
863
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
864
865
866
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
867
868
            else:
                # use group for GPU tensors
869
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
870
871
872
873
        return None

    def recv_tensor_dict(
        self,
874
        src: int | None = None,
875
        all_gather_group: Optional["GroupCoordinator"] = None,
876
877
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
878
879
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
895
896
897
898
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
899
900
901
902
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
903

904
905
906
907
        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
908
            src = (self.rank_in_group - 1) % self.world_size
909
910
        assert src < self.world_size, f"Invalid src rank ({src})"

911
        if self.use_cpu_custom_send_recv:
912
913
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
914
            return self.device_communicator.recv_tensor_dict(  # type: ignore
915
916
                src
            )
917

918
        recv_metadata_list = self.recv_object(src=src)
919
        tensor_dict: dict[str, Any] = {}
920
921
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
922
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
923
924
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
925
                    tensor_dict[key] = tensor
926
                    continue
927
928

                # send-allgather: send only a slice, then do allgather.
929
930
931
932
933
934
935
936
937
                use_all_gather = (
                    all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0
                )
                use_all_gather = (
                    all_gather_tensors.get(key, use_all_gather)
                    if all_gather_tensors
                    else use_all_gather
                )
938
939
940

                if use_all_gather:
                    orig_shape = tensor.shape
941
                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
942

943
944
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
945
946
947
                    torch.distributed.recv(
                        tensor, src=self.ranks[src], group=metadata_group
                    )
948
949
                else:
                    # use group for GPU tensors
950
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
951
952
953
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
954
955
                        tensor, dim=0
                    )
956
957
                    tensor = tensor.reshape(orig_shape)

958
                tensor_dict[key] = tensor
959
            else:
960
                tensor_dict[key] = value
961
962
        return tensor_dict

963
964
965
966
967
968
969
970
971
    def barrier(self):
        """Block until every rank in the group has reached this call.

        NOTE: the barrier deliberately runs on the CPU group rather than
        `device_group`. `barrier` in NCCL is terrible because it is
        internally a broadcast operation with secretly created GPU tensors,
        which makes it easy to mess up the current device.
        """
        cpu_pg = self.cpu_group
        torch.distributed.barrier(group=cpu_pg)

972
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
973
        """Sends a tensor to the destination rank in a blocking way"""
974
        """NOTE: `dst` is the local rank of the destination rank."""
975
976
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
977
        self.device_communicator.send(tensor, dst)
978

979
    def recv(
980
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
981
    ) -> torch.Tensor:
982
983
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
984
985
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
986
        return self.device_communicator.recv(size, dtype, src)
987

988
    def destroy(self):
        """Tear down the process groups and communicator owned by this group.

        Each process group attribute that still exists is destroyed and then
        removed from the instance, the device communicator (if any) is
        destroyed, and the message-queue broadcaster reference is dropped.
        """
        # Device group first, then CPU group.
        for pg_attr in ("device_group", "cpu_group"):
            if hasattr(self, pg_attr):
                torch.distributed.destroy_process_group(getattr(self, pg_attr))
                delattr(self, pg_attr)

        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()

        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
999

1000
1001
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward `model` to the device communicator's buffer preparation.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
1003
1004

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch hidden states and router logits via the device communicator.

        Without a device communicator the two inputs are returned untouched.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states, router_logits
        return communicator.dispatch(
            hidden_states, router_logits, is_sequence_parallel
        )
1016

1017
1018
1019
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine hidden states via the device communicator.

        Without a device communicator `hidden_states` is returned untouched.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
1024

1025

1026
# World-wide group coordinator; created by `init_distributed_environment` and
# reset to None by `destroy_distributed_environment`.
_WORLD: GroupCoordinator | None = None
1027
# Group coordinator covering the ranks inside one data-parallel replica; only
# assigned by `init_distributed_environment` when `nnodes_within_dp > 1`.
_INNER_DP_WORLD: GroupCoordinator | None = None
1028
# Number of nodes in the distributed environment; populated by
# `init_distributed_environment` and read through `get_node_count`.
_NODE_COUNT: int | None = None
1029
1030
1031


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting it has been created."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


1036
1037
1038
1039
1040
def get_inner_dp_world_group() -> GroupCoordinator:
    """Return the inner-DP world coordinator, asserting it has been created."""
    inner_world = _INNER_DP_WORLD
    assert inner_world is not None, "inner dp world group is not initialized"
    return inner_world


1041
1042
1043
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Build the coordinator for the world group.

    All `ranks` form a single group named "world"; no device communicator is
    attached to it.
    """
    world_coordinator = GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
    )
    return world_coordinator


1053
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    use_device_communicator: bool = True,
) -> GroupCoordinator:
    """Build a coordinator for one family of model-parallel groups.

    Every argument is forwarded to `GroupCoordinator` under the matching
    keyword (`backend` becomes `torch_distributed_backend`).
    """
    coordinator = GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=use_device_communicator,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return coordinator


1071
# Tensor model parallel group; set by `initialize_model_parallel` and cleared
# by `destroy_model_parallel`.
_TP: GroupCoordinator | None = None
1072
1073
1074


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel coordinator, asserting it exists."""
    tp = _TP
    assert tp is not None, "tensor model parallel group is not initialized"
    return tp


1079
1080
1081
1082
1083
@deprecated(
    "`get_tensor_model_parallel_group` has been replaced with "
    "`get_tp_group` and may be removed after v0.12. Please use "
    "`get_tp_group` instead."
)
def get_tensor_model_parallel_group():
    """Deprecated alias that returns whatever `get_tp_group` returns."""
    tp_group = get_tp_group()
    return tp_group

1087

1088
# Decode context model parallel group; set by `initialize_model_parallel` and
# cleared by `destroy_model_parallel`.
_DCP: GroupCoordinator | None = None
1089
1090
1091


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel coordinator, asserting it exists."""
    dcp = _DCP
    assert dcp is not None, "decode context model parallel group is not initialized"
    return dcp


# Kept for backward compatibility: the legacy name is bound to the same
# function object as `get_dcp_group`.
get_context_model_parallel_group = get_dcp_group

1099
# Pipeline model parallel group; set by `initialize_model_parallel` and
# cleared by `destroy_model_parallel`.
_PP: GroupCoordinator | None = None
1100

1101
# Data parallel group; set by `initialize_model_parallel` and cleared by
# `destroy_model_parallel`.
_DP: GroupCoordinator | None = None
1102
1103
1104


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel coordinator, asserting it exists."""
    dp = _DP
    assert dp is not None, "data parallel group is not initialized"
    return dp

1108

1109
# Expert parallel group; set by `initialize_model_parallel` and cleared by
# `destroy_model_parallel`.
_EP: GroupCoordinator | None = None
1110
1111
1112


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel coordinator, asserting it exists."""
    ep = _EP
    assert ep is not None, "expert parallel group is not initialized"
    return ep


1117
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel coordinator, asserting it exists."""
    pp = _PP
    assert pp is not None, "pipeline model parallel group is not initialized"
    return pp
1120
1121


1122
1123
1124
1125
1126
@deprecated(
    "`get_pipeline_model_parallel_group` has been replaced with "
    "`get_pp_group` and may be removed in v0.12. Please use "
    "`get_pp_group` instead."
)
def get_pipeline_model_parallel_group():
    """Deprecated alias that returns whatever `get_pp_group` returns."""
    pp_group = get_pp_group()
    return pp_group
1129
1130


1131
@contextmanager
def graph_capture(device: torch.device):
    """
    Context manager that should wrap the code capturing a CUDA graph.

    Its main purpose is to ensure that some operations will be run after the
    graph is captured, before the graph is replayed. It yields a
    `GraphCaptureContext` object which contains the necessary data for the
    graph capture. Currently, it only contains the stream that the graph
    capture is running on. This stream is set to the current CUDA stream when
    the context manager is entered and reset to the default stream when the
    context manager is exited. This is to ensure that the graph capture is
    running on a separate stream from the default stream, in order to
    explicitly distinguish the kernels to capture from other kernels possibly
    launched on background in the default stream.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    # Enter the TP group's capture scope first, then the PP group's; both
    # share the same context object.
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1150

1151
# Module-level logger, named after this module.
logger = init_logger(__name__)
1152

1153
# Module-wide switch, on by default; changed via `set_custom_all_reduce`.
# NOTE(review): presumably consulted when group communicators are built —
# the reader of this flag is not visible in this part of the file.
_ENABLE_CUSTOM_ALL_REDUCE = True
1154
1155


1156
1157
1158
def set_custom_all_reduce(enable: bool):
    """Set the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag to `enable`."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1159

Zhuohan Li's avatar
Zhuohan Li committed
1160

1161
1162
1163
1164
1165
1166
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed and the module-level world group(s).

    Steps, in order:

    1. If the current vLLM config describes a multi-node deployment, or data
       parallelism without the external launcher, `rank`, `world_size` and
       `distributed_init_method` are rewritten to span all DP replicas.
    2. If torch.distributed is not yet initialized, the default process group
       is created (falling back to gloo when `backend` is unavailable).
    3. `_WORLD` and `_NODE_COUNT` are populated on first call; on later calls
       the existing world size is only checked for consistency.
    4. When a DP replica spans more than one node, `_INNER_DP_WORLD` is set.

    Args:
        world_size: Number of ranks before any DP adjustment.
        rank: Rank of this process before any DP adjustment.
        distributed_init_method: Init-method URL passed to
            `torch.distributed.init_process_group`.
        local_rank: Local rank of this process. With -1 it is taken from
            `envs.LOCAL_RANK` for "env://" and from `rank` otherwise.
        backend: torch.distributed backend name for the world group.
        timeout: Timeout passed to `torch.distributed.init_process_group`.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    # Bind once so every later section sees the same (possibly missing)
    # parallel config, instead of relying on a name that only exists when
    # one of the branches below happened to run.
    parallel_config = config.parallel_config if config is not None else None

    if parallel_config is not None and parallel_config.nnodes > 1:
        ip = parallel_config.master_addr
        rank = parallel_config.data_parallel_rank * world_size + rank
        world_size = parallel_config.world_size_across_dp
        port = parallel_config.master_port
        distributed_init_method = get_distributed_init_method(ip, port)
    elif (
        parallel_config is not None
        and parallel_config.data_parallel_size > 1
        and parallel_config.distributed_executor_backend != "external_launcher"
    ):
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp
        ip = parallel_config.data_parallel_master_ip
        port = parallel_config.get_next_dp_init_port()
        distributed_init_method = get_distributed_init_method(ip, port)
        logger.debug(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size,
            rank,
            distributed_init_method,
        )

    if not torch.distributed.is_initialized():
        logger.info(
            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
            world_size,
            rank,
            local_rank,
            distributed_init_method,
            backend,
        )
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank

    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        if parallel_config is not None and parallel_config.nnodes > 1:
            # Trust the configured node count for multi-node deployments.
            _NODE_COUNT = parallel_config.nnodes
        else:
            _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )

    if parallel_config is not None and parallel_config.nnodes_within_dp > 1:
        if parallel_config.data_parallel_size > 1:
            # One group per DP replica, each holding that replica's
            # contiguous block of world ranks.
            world_size_inner_dp = parallel_config.world_size
            group_ranks = [
                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
                for dp_rank in range(parallel_config.data_parallel_size)
            ]
            _INNER_DP_WORLD = init_model_parallel_group(
                group_ranks,
                get_world_group().local_rank,
                backend,
                use_message_queue_broadcaster=True,
                group_name="inner_dp_world",
                use_device_communicator=False,
            )
        else:
            # Without data parallelism the replica is the whole world.
            _INNER_DP_WORLD = _WORLD
1274
1275


Zhuohan Li's avatar
Zhuohan Li committed
1276
1277
1278
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of GPUs in each decode
            context parallel (DCP) group. `None` is treated as 1.
        backend: name of torch distributed communication backend. When
            omitted, the backend of the world group's device group is used.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _TP, _DCP, _PP, _DP, _EP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    # The signature admits `None`, which `reshape` below cannot take; fall
    # back to the default of one rank per DCP group.
    if decode_context_model_parallel_size is None:
        decode_context_model_parallel_size = 1

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
    )  # noqa

    # Build the tensor model-parallel groups.
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the pipeline model-parallel groups.
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    # Build the data-parallel groups.
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp"
    )

    # Build the expert-parallel groups: each spans the DP and TP dimensions
    # for one pipeline stage.
    assert _EP is None, "expert parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(1, 2)
        .reshape(-1, data_parallel_size * tensor_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep"
    )

    logger.info_once(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group,
    )
1404

Zhuohan Li's avatar
Zhuohan Li committed
1405

1406
1407
1408
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Create the model parallel groups, or validate the existing ones.

    If the groups do not exist yet they are created with the given sizes.
    If they already exist, the tensor-parallel and pipeline-parallel world
    sizes are asserted to match the requested values.
    """
    if not backend:
        backend = torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        # Groups already exist: only check that their sizes are as expected.
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        return

    initialize_model_parallel(
        tensor_model_parallel_size,
        pipeline_model_parallel_size,
        decode_context_model_parallel_size,
        backend,
    )
1437
1438


1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost
    model agnostic. However, emerging new communication libraries like
    MoE all2all (DeepEP) usually allocate the communication buffer
    based on the model shape for optimal performance.

    The model is handed to the TP, PP, DP and EP groups, in that order,
    skipping any group that has not been created.
    """
    for coordinator in (_TP, _PP, _DP, _EP):
        if coordinator is None:
            continue
        coordinator.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1456
def model_parallel_is_initialized():
    """Report whether both the tensor and pipeline parallel groups exist."""
    return not (_TP is None or _PP is None)
1459
1460


1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
# True while `patch_tensor_parallel_group` has temporarily replaced `_TP`.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Nested use is rejected with an assertion. The previous group and the
    patched flag are restored on exit, even if the body raises.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group *before* flipping the flag: `get_tp_group`
    # asserts when TP is uninitialized, and a failure here must not leave
    # `_TP_STATE_PATCHED` stuck at True with no `finally` to reset it.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1489
1490
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1492
1493
1494
1495


def get_tensor_model_parallel_rank():
    """Return this process's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1497
1498


1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
def get_decode_context_model_parallel_world_size():
    """Return the number of ranks in the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank():
    """Return this process's rank inside the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1509
def get_node_count() -> int:
    """Return the node count recorded when the environment was initialized."""
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1515
def destroy_model_parallel():
    """Destroy every model parallel group and reset its global to None.

    Groups are handled in the order TP, PP, DCP, DP, EP.
    """
    module_state = globals()
    for global_name in ("_TP", "_PP", "_DCP", "_DP", "_EP"):
        coordinator = module_state[global_name]
        if coordinator:
            coordinator.destroy()
        module_state[global_name] = None

1543
1544

def destroy_distributed_environment():
    """Destroy the world group and the torch.distributed process group.

    Calls `destroy()` on the module-level `_WORLD` coordinator if it is set,
    resets `_WORLD` and `_NODE_COUNT` to None, and finally destroys the
    default torch.distributed process group if one is initialized.
    """
    global _WORLD, _NODE_COUNT

    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    _NODE_COUNT = None

    # Only tear down torch.distributed if it was actually initialized.
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1552
1553


1554
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down the distributed state and release cached memory.

    Steps, in order: unfreeze the garbage collector, destroy the model
    parallel groups and the distributed environment, optionally shut down
    Ray, run a garbage collection, call the current platform's `empty_cache`
    hook when it is not None, and on non-CPU platforms call
    `torch._C._host_emptyCache()`.

    Args:
        shutdown_ray: When True, import Ray lazily and call `ray.shutdown()`.
    """
    # Ensure no objects are frozen before cleanup, so gc.collect() below
    # can reclaim them.
    gc.unfreeze()

    destroy_model_parallel()
    destroy_distributed_environment()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()

    from vllm.platforms import current_platform

    empty_cache = current_platform.empty_cache
    if empty_cache is not None:
        empty_cache()

    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        # The private torch hook is missing on this torch build;
        # treat it as best-effort and only warn.
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
1575
1576


1577
def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    The source rank creates a small shared-memory segment, writes a marker
    into it and broadcasts the segment name. Every other rank tries to open
    the segment by name and checks for the marker. Each rank records its own
    result, and the per-rank results are then combined across the group.

    Args:
        pg: The group to test. A torch `ProcessGroup` must use a non-NCCL
            backend; any other object is used through its `rank`,
            `world_size`, `broadcast_obj` and `barrier` members.
        source_rank: Rank inside the group that creates the shared-memory
            segment.

    Returns:
        list[bool]: One entry per rank in the group; True where the combined
        flag for that rank equals 1.
    """
    is_torch_pg = isinstance(pg, ProcessGroup)

    if is_torch_pg:
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor(
        [0] * world_size, dtype=torch.int32, device="cpu"
    )

    magic_message = b"magic_message"
    shm = None

    try:
        # OSError (e.g. the segment does not exist on this node) simply
        # leaves this rank's flag at 0.
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                if is_torch_pg:
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg
                    )
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # receive the segment name, then try to open the segment
                if is_torch_pg:
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg
                    )
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[: len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Wait until every rank has finished probing before the segment is
    # unlinked below.
    if is_torch_pg:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if is_torch_pg:
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        # No all_reduce available here: every rank broadcasts its local
        # tensor in turn and the results are summed.
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
1667
1668


1669
1670
def is_global_first_rank() -> bool:
    """
    Tell whether this process is global rank 0 across every parallelism
    dimension (PP, TP, DP, EP, etc.).

    This differs from group-scoped checks such as
    `get_tensor_model_parallel_rank() == 0` or `get_pp_group().is_first_rank`,
    which only look at one parallel group.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Also True when distributed is not initialized (single process)
              or when any error occurs while checking.
    """
    try:
        world = _WORLD
        if world is None:
            # No world group yet: consult torch.distributed directly, and
            # treat an uninitialized backend as a single process.
            if torch.distributed.is_initialized():
                return torch.distributed.get_rank() == 0
            return True

        # The world group gives the most accurate answer when it exists.
        return world.is_first_rank
    except Exception:
        # Best effort: on any failure, behave as the first rank.
        return True


1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
def is_local_first_rank() -> bool:
    """
    Tell whether this process is local rank 0 on its node.

    Returns:
        bool: True when the local rank is 0. Also True when distributed is
              not initialized or when any error occurs while checking.
    """
    try:
        world = _WORLD
        # an initialized world group takes precedence
        if world is not None:
            return world.local_rank == 0

        if not torch.distributed.is_initialized():
            return True

        # Otherwise use the environment-provided local rank.
        # note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun)
        try:
            local_rank = int(envs.LOCAL_RANK)  # type: ignore[arg-type]
        except Exception:
            # LOCAL_RANK is unusable; fall back to the global rank.
            return torch.distributed.get_rank() == 0
        return local_rank == 0
    except Exception:
        return True


1723
def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are visited in ascending order. Each rank that has not yet been
    placed on a node opens a new node, and `in_the_same_node_as` is then used
    to place every other unassigned rank that shares a node with it.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        world_size = torch.distributed.get_world_size(group=pg)
    else:
        world_size = pg.world_size

    # A single-rank group trivially spans one node; no probing needed.
    if world_size == 1:
        return 1

    # Build node assignment map: rank -> node_id, where 0 means "unassigned"
    # (node ids start at 1).
    node_assignment = [0] * world_size
    next_node_id = 0

    for current_rank in range(world_size):
        if node_assignment[current_rank] != 0:
            continue  # Already assigned to a node

        # Assign current rank to a new node
        next_node_id += 1
        node_assignment[current_rank] = next_node_id

        # Find all ranks on the same node as current_rank
        same_node_flags = in_the_same_node_as(pg, current_rank)
        for other_rank, is_same_node in enumerate(same_node_flags):
            if is_same_node and node_assignment[other_rank] == 0:
                node_assignment[other_rank] = next_node_id

    return next_node_id