parallel_state.py 64.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
9
10
11
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
13
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
14
15
16
17
18
19
20
21
22
23
24
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
from collections import namedtuple
31
from collections.abc import Callable
32
33
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
34
from datetime import timedelta
35
from multiprocessing import shared_memory
36
from typing import Any, Optional
37
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
38
39

import torch
40
import torch.distributed
41
42
import torch.distributed._functional_collectives as funcol
import torch.distributed._symmetric_memory
43
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
44

45
import vllm.envs as envs
46
from vllm.distributed.device_communicators.base_device_communicator import (
47
48
    DeviceCommunicatorBase,
)
49
from vllm.distributed.utils import StatelessProcessGroup
50
from vllm.logger import init_logger
51
from vllm.utils.import_utils import resolve_obj_by_qualname
52
from vllm.utils.network_utils import get_distributed_init_method
53
from vllm.utils.system_utils import suppress_stdout
54
55
56
57
from vllm.utils.torch_utils import (
    direct_register_custom_op,
    supports_custom_op,
)
58
59


60
61
62
@dataclass
class GraphCaptureContext:
    """State handed to code running inside a graph-capture session.

    `GroupCoordinator.graph_capture` builds one of these around a fresh
    `torch.cuda.Stream` when the caller does not pass one in, and yields it
    from inside a `torch.cuda.stream(stream)` scope.
    """

    # stream that is made current for the duration of the capture
    stream: torch.cuda.Stream
63

64

65
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
66

67
68

def _split_tensor_dict(
69
    tensor_dict: dict[str, torch.Tensor | Any],
70
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
71
72
73
74
75
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
76
77
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
78
79
80
81
82
83
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
84
            device = value.device.type
85
            metadata_list.append(
86
87
                (key, TensorMetadata(device, value.dtype, value.size()))
            )
88
89
            tensor_list.append(value)
        else:
90
            metadata_list.append((key, value))
91
92
93
    return metadata_list, tensor_list


94
_group_name_counter: dict[str, int] = {}
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


110
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
111
112
113


def _register_group(group: "GroupCoordinator") -> None:
114
    _groups[group.unique_name] = weakref.ref(group)
115
116


117
118
119
120
121
122
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce dispatched by group name.

    Looks the coordinator up in the module registry and forwards to its
    `_all_reduce_out_place`.

    Raises:
        AssertionError: if `group_name` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; calling one yields None once the
    # coordinator has been collected.
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_reduce_out_place(tensor)
123
124


125
126
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation of `all_reduce`: no communication is performed.

    Returns an uninitialized tensor created with `torch.empty_like(tensor)`.
    `group_name` is accepted only to mirror the real signature.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
127

128

129
130
131
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place reduce-scatter dispatched by group name.

    Looks the coordinator up in the module registry and forwards to its
    `_reduce_scatter_out_place(tensor, dim)`. `world_size` is not used here;
    only the fake implementation needs it to compute the output shape.

    Raises:
        AssertionError: if `group_name` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; a dead one returns None.
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._reduce_scatter_out_place(tensor, dim)
137
138


139
140
141
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation of `reduce_scatter`: only the output shape matters.

    Returns an uninitialized tensor with the dtype and device of `tensor`
    whose extent along `dim` is the input extent floor-divided by
    `world_size`. `group_name` is accepted only to mirror the real signature.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


147
148
149
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place all-gather dispatched by group name.

    Looks the coordinator up in the module registry and forwards to its
    `_all_gather_out_place(tensor, dim)`. `world_size` is not used here; only
    the fake implementation needs it to compute the output shape.

    Raises:
        AssertionError: if `group_name` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    # The registry stores weak references; a dead one returns None.
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_gather_out_place(tensor, dim)
155
156


157
158
159
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation of `all_gather`: only the output shape matters.

    Returns an uninitialized tensor with the dtype and device of `tensor`
    whose extent along `dim` is the input extent multiplied by `world_size`.
    `group_name` is accepted only to mirror the real signature.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
def patched_fused_scaled_matmul_reduce_scatter_fake(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Fake implementation for `patched_fused_scaled_matmul_reduce_scatter`.

    Validates `A_scale`, runs `torch._scaled_mm` on `A` flattened to 2-D,
    views the product as `(*output_shape[:-1], B.shape[1])`, and
    reduce-scatters it over `group_name` along `orig_scatter_dim` through the
    functional collectives API.

    `scatter_dim_after_maybe_reshape` is part of the signature but is not
    read in this body.

    Raises:
        ValueError: if `A_scale` has more than one element and its leading
            dims differ from those of `A`, or if it has no elements.
    """
    # Copied from
    # https://github.com/pytorch/pytorch/blob/50c338c2da905062449e4d9ac807832d1b5cd90e/torch/distributed/_symmetric_memory/__init__.py#L1189
    if A_scale.numel() > 1:
        # More than one scale value: treated as row-wise scaling, so the
        # scale must line up with the leading dims of A.
        if A_scale.shape[:-1] != A.shape[:-1]:
            raise ValueError(
                "For row-wise scaling, the leading dims of A_scale "
                "must match the leading dims of A "
                f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
            )
        # Flatten the scale the same way A is flattened below.
        A_scale = A_scale.flatten(0, -2).contiguous()
    elif A_scale.numel() != 1:
        # Reached only when the scale has zero elements.
        raise ValueError(
            "Invalid A_scale shape "
            f"(A shape: {A.shape}, A_scale shape: {A_scale.shape})"
        )

    C = torch._scaled_mm(
        A.flatten(0, -2).contiguous(),
        B,
        A_scale,
        B_scale,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )
    # Restore the leading dims requested by the caller.
    C = C.view(*output_shape[:-1], B.shape[1])
    res = funcol.reduce_scatter_tensor(
        C,
        reduce_op,
        orig_scatter_dim,  # need original scatter dim for 3D+ output tensor here
        group_name,
    )
    # Wait on the functional collective's result before returning it.
    res = funcol.wait_tensor(res)
    return res


def patched_fused_scaled_matmul_reduce_scatter(
    A: torch.Tensor,
    B: torch.Tensor,
    A_scale: torch.Tensor,
    B_scale: torch.Tensor,
    reduce_op: str,
    orig_scatter_dim: int,
    scatter_dim_after_maybe_reshape: int,
    group_name: str,
    output_shape: list[int],
    bias: torch.Tensor | None = None,
    result_scale: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
) -> torch.Tensor:
    """Forward every argument to `torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter`.

    The arguments are passed positionally and in signature order, so the
    parameter order here must stay in step with that op.
    """
    return torch.ops.symm_mem.fused_scaled_matmul_reduce_scatter(
        A,
        B,
        A_scale,
        B_scale,
        reduce_op,
        orig_scatter_dim,
        scatter_dim_after_maybe_reshape,
        group_name,
        output_shape,
        bias,
        result_scale,
        out_dtype,
        use_fast_accum,
    )


249
if supports_custom_op():
    # (op name, real implementation, fake implementation), registered in
    # the order listed.
    _custom_op_specs = (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
        # TODO: Remove this once the pytorch fix
        # (https://github.com/pytorch/pytorch/pull/165086) gets released,
        # in either 2.9.1 or 2.10
        (
            "patched_fused_scaled_matmul_reduce_scatter",
            patched_fused_scaled_matmul_reduce_scatter,
            patched_fused_scaled_matmul_reduce_scatter_fake,
        ),
    )
    for _op_name, _op_func, _op_fake in _custom_op_specs:
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_op_fake,
        )
    # Do not leave the loop temporaries behind as module attributes.
    del _custom_op_specs, _op_name, _op_func, _op_fake

277

278
279
280
281
282
283
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator; only created when use_device_communicator=True
    # and the group has more than one member, otherwise None
    device_communicator: DeviceCommunicatorBase | None
    # shared memory broadcaster; only created when the message queue
    # broadcaster is requested and the group has more than one member
    mq_broadcaster: Any | None
306
307
308

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: str | Backend,
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: str | None = None,
    ):
        """Create the process groups and pick out the one this rank is in.

        Args:
            group_ranks: one list of global ranks per group to create. A
                device group and a gloo CPU group are created for every
                entry; the entry containing this process's global rank
                defines `ranks`, `world_size` and `rank_in_group`.
            local_rank: index used to build the device string for `device`.
            torch_distributed_backend: backend of the device groups.
            use_device_communicator: create the platform's device
                communicator (only done when the group has more than one
                member).
            use_message_queue_broadcaster: create a `MessageQueue`
                broadcaster on the CPU group (only done when the group has
                more than one member).
            group_name: base name for `unique_name`; defaults to
                "anonymous".

        Raises:
            AssertionError: if this process's rank is in none of the
                entries of `group_ranks`.
        """
        self.unique_name = _get_unique_name(group_name or "anonymous")
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        own_device_group = None
        own_cpu_group = None

        # Both groups are created for every entry, including the entries
        # this rank does not belong to.
        for member_ranks in group_ranks:
            device_pg = torch.distributed.new_group(
                member_ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            with suppress_stdout():
                cpu_pg = torch.distributed.new_group(member_ranks, backend="gloo")
            if self.rank not in member_ranks:
                continue
            self.ranks = member_ranks
            self.world_size = len(member_ranks)
            self.rank_in_group = member_ranks.index(self.rank)
            own_device_group = device_pg
            own_cpu_group = cpu_pg

        assert own_cpu_group is not None
        assert own_device_group is not None

        self.cpu_group = own_cpu_group
        self.device_group = own_device_group

        from vllm.platforms import current_platform

        # Pick the device string for this rank; anything unrecognized falls
        # back to the CPU.
        if current_platform.is_cuda_alike():
            device_str = f"cuda:{local_rank}"
        elif current_platform.is_xpu():
            device_str = f"xpu:{local_rank}"
        elif current_platform.is_out_of_tree():
            device_str = f"{current_platform.device_name}:{local_rank}"
        else:
            device_str = "cpu"
        self.device = torch.device(device_str)

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        has_peers = self.world_size > 1
        if use_device_communicator and has_peers:
            communicator_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = communicator_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: MessageQueue | None = None
        if use_message_queue_broadcaster and has_peers:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # Collectives go through the registered `torch.ops.vllm.*` custom ops
        # only on CUDA-like and TPU platforms.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
388

389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
    def create_mq_broadcaster(
        self, writer_rank=0, external_writer_handle=None, blocking=True
    ):
        """Build a `MessageQueue` on this group's CPU process group.

        `writer_rank`, `external_writer_handle` and `blocking` are forwarded
        unchanged to `MessageQueue.create_from_process_group`, whose result
        is returned.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        # NOTE(review): `1 << 22` (4 MiB) and `6` are passed positionally;
        # presumably the chunk size in bytes and the number of chunks —
        # confirm against MessageQueue.
        return MessageQueue.create_from_process_group(
            self.cpu_group,
            1 << 22,
            6,
            writer_rank=writer_rank,
            external_writer_handle=external_writer_handle,
            blocking=blocking,
        )

    def create_single_reader_mq_broadcasters(
        self, reader_rank_in_group=0, blocking=False
    ):
        """Build single-reader message queue(s) on this group's CPU process group.

        `reader_rank_in_group` is an index into `self.ranks`; it is
        translated to the corresponding global rank before being passed on
        as `reader_rank`. Returns whatever
        `MessageQueue.create_from_process_group_single_reader` returns.
        """
        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        # NOTE(review): `1 << 22` (4 MiB) and `6` are passed positionally;
        # presumably the chunk size in bytes and the number of chunks —
        # confirm against MessageQueue.
        return MessageQueue.create_from_process_group_single_reader(
            self.cpu_group,
            1 << 22,
            6,
            reader_rank=self.ranks[reader_rank_in_group],
            blocking=blocking,
        )

416
417
418
419
420
421
422
423
424
425
    @property
    def first_rank(self) -> int:
        """Return the global rank of the first process in the group"""
        # `ranks` holds global ranks, so no translation is needed.
        return self.ranks[0]

    @property
    def last_rank(self) -> int:
        """Return the global rank of the last process in the group"""
        # `ranks` holds global ranks, so no translation is needed.
        return self.ranks[-1]

426
427
428
429
430
431
432
433
434
435
    @property
    def is_first_rank(self) -> bool:
        """Return whether the caller is the first process in the group"""
        # Both sides of the comparison are global ranks.
        return self.rank == self.first_rank

    @property
    def is_last_rank(self) -> bool:
        """Return whether the caller is the last process in the group"""
        # Both sides of the comparison are global ranks.
        return self.rank == self.last_rank

436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None):
        """Context manager that makes the capture stream current.

        When no `graph_capture_context` is given, one is created around a new
        `torch.cuda.Stream`. The context is yielded from inside a
        `torch.cuda.stream(...)` scope and, when the device communicator has
        a `ca_comm`, inside that object's `capture()` context as well.
        """
        if graph_capture_context is None:
            graph_capture_context = GraphCaptureContext(torch.cuda.Stream())
        capture_stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            ca_comm = communicator.ca_comm
            if ca_comm is not None:
                ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        active_stream = torch.cuda.current_stream()
        if active_stream != capture_stream:
            capture_stream.wait_stream(active_stream)

        with torch.cuda.stream(capture_stream), ca_context:
            yield graph_capture_context
479
480
481

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
482
483
484
485
486
487
488
489
490
491
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
492
493
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
494
495
496
497
498
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

499
        if self.use_custom_op_call:
500
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
501
502
        else:
            return self._all_reduce_out_place(input_)
503

504
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
505
506
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
507
        return self.device_communicator.all_reduce(input_)
508
509
510
511
512
513
514

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
515
516
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
517

518
        if self.use_custom_op_call:
519
520
521
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
522
523
524
        else:
            return self._all_gather_out_place(input_, dim)

525
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
526
527
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
528
        return self.device_communicator.all_gather(input_, dim)
529

530
531
    def all_gatherv(
        self,
532
        input_: torch.Tensor | list[torch.Tensor],
533
        dim: int = 0,
534
        sizes: list[int] | None = None,
535
    ):
536
537
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
538
539
        return self.device_communicator.all_gatherv(input_, dim, sizes)

540
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
541
542
543
544
545
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
546
547
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
548

549
        if self.use_custom_op_call:
550
551
552
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
553
554
555
        else:
            return self._reduce_scatter_out_place(input_, dim)

556
    def reduce_scatterv(
557
        self, input_: torch.Tensor, dim: int = -1, sizes: list[int] | None = None
558
    ) -> torch.Tensor:
559
560
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
561
562
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

563
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
564
565
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
566
567
        return self.device_communicator.reduce_scatter(input_, dim)

568
569
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
570
    ) -> torch.Tensor | None:
571
572
573
574
575
576
577
578
579
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
580
581
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
582
        return self.device_communicator.gather(input_, dst, dim)
583
584
585
586
587
588
589
590
591
592
593

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
594
595
596
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
597
598
        return input_

599
    def broadcast_object(self, obj: Any | None = None, src: int = 0):
600
601
602
603
604
605
606
607
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
608
609
610
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
611
        if self.rank_in_group == src:
612
613
614
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
615
616
617
            return obj
        else:
            recv = [None]
618
619
620
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
621
622
            return recv[0]

623
    def broadcast_object_list(
624
        self, obj_list: list[Any], src: int = 0, group: ProcessGroup | None = None
625
    ):
626
627
628
629
630
631
632
633
634
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
635
636
637
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
638
639
        return obj_list

640
641
642
643
644
645
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

646
        assert dst != self.rank_in_group, (
647
            "Invalid destination rank. Destination rank is the same "
648
649
            "as the current rank."
        )
650
651
652
653

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

654
655
656
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
657
658
659

        # Send object size

660
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
661
662

        # Send object
663
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
664
665
666
667
668
669
670
671
672

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

673
        assert src != self.rank_in_group, (
674
675
676
677
678
679
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
680
681
682
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
683
684
685
686
687

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
688
689
            device="cpu",
        )
690

691
692
693
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
694
695

        assert rank_object == rank_size, (
696
697
            "Received object sender rank does not match the size sender rank."
        )
698
699
700
701
702

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

703
704
    def broadcast_tensor_dict(
        self,
705
        tensor_dict: dict[str, torch.Tensor | Any] | None = None,
706
        src: int = 0,
707
708
709
        group: ProcessGroup | None = None,
        metadata_group: ProcessGroup | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
710
711
712
713
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
714
        if not torch.distributed.is_initialized() or self.world_size == 1:
715
716
717
718
719
720
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

721
722
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
723
            metadata_list: list[tuple[Any, Any]] = []
724
725
726
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
727
728
729
730
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
731
            self.broadcast_object(metadata_list, src=src)
732
733
734
735
736
737
738
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
739
740
741
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
742
743
                else:
                    # use group for GPU tensors
744
745
746
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
747
748
749
750
751
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
752
            metadata_list = self.broadcast_object(None, src=src)
753
754
            tensor_dict = {}
            async_handles = []
755
            for key, value in metadata_list:
756
                if isinstance(value, TensorMetadata):
757
758
759
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
760
761
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
762
                        tensor_dict[key] = tensor
763
764
765
766
767
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
768
                            src=self.ranks[src],
769
                            group=metadata_group,
770
771
                            async_op=True,
                        )
772
773
                    else:
                        # use group for GPU tensors
774
                        handle = torch.distributed.broadcast(
775
776
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
777
                    async_handles.append(handle)
778
                    tensor_dict[key] = tensor
779
                else:
780
                    tensor_dict[key] = value
781
782
783
784
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

785
786
    def send_tensor_dict(
        self,
787
788
        tensor_dict: dict[str, torch.Tensor | Any],
        dst: int | None = None,
789
        all_gather_group: Optional["GroupCoordinator"] = None,
790
791
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
792
793
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
809
810
811
812
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
813
814
815
816
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
817

818
819
820
821
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
822
            dst = (self.rank_in_group + 1) % self.world_size
823
824
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

825
        if self.use_cpu_custom_send_recv:
826
827
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
828
            self.device_communicator.send_tensor_dict(  # type: ignore
829
830
                tensor_dict, dst
            )
831
832
            return None

833
        metadata_list: list[tuple[Any, Any]] = []
834
835
836
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
837
838
839
840
841
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
842

843
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
844
845
846
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
847
848
849
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
850
851

            # send-allgather: send only a slice, then do allgather.
852
853
854
855
856
857
858
859
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            use_all_gather = (
                all_gather_tensors.get(key, use_all_gather)
                if all_gather_tensors
                else use_all_gather
            )
860
            if use_all_gather:
861
862
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

863
864
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
865
866
867
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
868
869
            else:
                # use group for GPU tensors
870
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
871
872
873
874
        return None

    def recv_tensor_dict(
        self,
875
        src: int | None = None,
876
        all_gather_group: Optional["GroupCoordinator"] = None,
877
878
        all_gather_tensors: dict[str, bool] | None = None,
    ) -> dict[str, torch.Tensor | Any] | None:
879
880
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
896
897
898
899
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None
900
901
902
903
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
904

905
906
907
908
        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
909
            src = (self.rank_in_group - 1) % self.world_size
910
911
        assert src < self.world_size, f"Invalid src rank ({src})"

912
        if self.use_cpu_custom_send_recv:
913
914
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
915
            return self.device_communicator.recv_tensor_dict(  # type: ignore
916
917
                src
            )
918

919
        recv_metadata_list = self.recv_object(src=src)
920
        tensor_dict: dict[str, Any] = {}
921
922
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
923
                tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
924
925
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
926
                    tensor_dict[key] = tensor
927
                    continue
928
929

                # send-allgather: send only a slice, then do allgather.
930
931
932
933
934
935
936
937
938
                use_all_gather = (
                    all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0
                )
                use_all_gather = (
                    all_gather_tensors.get(key, use_all_gather)
                    if all_gather_tensors
                    else use_all_gather
                )
939
940
941

                if use_all_gather:
                    orig_shape = tensor.shape
942
                    tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
943

944
945
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
946
947
948
                    torch.distributed.recv(
                        tensor, src=self.ranks[src], group=metadata_group
                    )
949
950
                else:
                    # use group for GPU tensors
951
                    torch.distributed.recv(tensor, src=self.ranks[src], group=group)
952
953
954
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
955
956
                        tensor, dim=0
                    )
957
958
                    tensor = tensor.reshape(orig_shape)

959
                tensor_dict[key] = tensor
960
            else:
961
                tensor_dict[key] = value
962
963
        return tensor_dict

964
965
966
967
968
969
970
971
972
    def barrier(self):
        """Synchronize all ranks of this group on a barrier.

        NOTE: this intentionally runs on `cpu_group`, never on
        `device_group`. An NCCL `barrier` is internally a broadcast over
        secretly created GPU tensors, which makes it easy to mess up the
        current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

973
    def send(self, tensor: torch.Tensor, dst: int | None = None) -> None:
974
        """Sends a tensor to the destination rank in a blocking way"""
975
        """NOTE: `dst` is the local rank of the destination rank."""
976
977
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
978
        self.device_communicator.send(tensor, dst)
979

980
    def recv(
981
        self, size: torch.Size, dtype: torch.dtype, src: int | None = None
982
    ) -> torch.Tensor:
983
984
        """Receives a tensor from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""
985
986
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
987
        return self.device_communicator.recv(size, dtype, src)
988

989
    def destroy(self):
        """Release the resources held by this group coordinator.

        Destroys and removes the `device_group` and `cpu_group` process
        groups when present, calls `destroy()` on the device communicator
        if there is one, and drops the message-queue broadcaster reference.
        """
        for group_attr in ("device_group", "cpu_group"):
            if hasattr(self, group_attr):
                torch.distributed.destroy_process_group(getattr(self, group_attr))
                delattr(self, group_attr)

        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()

        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
1000

1001
1002
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward `model` to the device communicator's buffer preparation.

        Does nothing when this group has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
1004
1005

    def dispatch(
1006
1007
1008
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
1009
        is_sequence_parallel: bool = False,
1010
1011
1012
1013
1014
        extra_tensors: list[torch.Tensor] | None = None,
    ) -> (
        tuple[torch.Tensor, torch.Tensor]
        | tuple[torch.Tensor, torch.Tensor, list[torch.Tensor]]
    ):
1015
        if self.device_communicator is not None:
1016
1017
1018
1019
1020
            return self.device_communicator.dispatch(  # type: ignore[call-arg]
                hidden_states,
                router_logits,
                is_sequence_parallel,
                extra_tensors,
1021
            )
1022
1023
        else:
            return hidden_states, router_logits
1024

1025
1026
1027
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine `hidden_states` through the device communicator.

        Without a device communicator the input is returned unchanged.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
1032

1033

1034
# Process-wide distributed state. Each slot stays `None` until
# `init_distributed_environment` assigns it.
_WORLD: GroupCoordinator | None = None
_INNER_DP_WORLD: GroupCoordinator | None = None
_NODE_COUNT: int | None = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator; it must already be initialized."""
    world_group = _WORLD
    assert world_group is not None, "world group is not initialized"
    return world_group


1044
1045
1046
1047
1048
def get_inner_dp_world_group() -> GroupCoordinator:
    """Return the inner-DP world group coordinator; it must be initialized."""
    inner_dp_world = _INNER_DP_WORLD
    assert inner_dp_world is not None, "inner dp world group is not initialized"
    return inner_dp_world


1049
1050
1051
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Create the "world" coordinator: one group made of all `ranks`.

    The coordinator is built without a device communicator.
    """
    world_coordinator = GroupCoordinator(
        group_name="world",
        use_device_communicator=False,
        torch_distributed_backend=backend,
        local_rank=local_rank,
        group_ranks=[ranks],
    )
    return world_coordinator


1061
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: str | None = None,
    use_device_communicator: bool = True,
) -> GroupCoordinator:
    """Create a `GroupCoordinator` for one model-parallel dimension.

    Every argument is passed straight through to the coordinator;
    `backend` is supplied as its `torch_distributed_backend`.
    """
    coordinator = GroupCoordinator(
        group_name=group_name,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        use_device_communicator=use_device_communicator,
        torch_distributed_backend=backend,
        local_rank=local_rank,
        group_ranks=group_ranks,
    )
    return coordinator


1079
# Tensor model parallel group; `None` until `initialize_model_parallel` runs.
_TP: GroupCoordinator | None = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group; it must be initialized."""
    tp_group = _TP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group


1087
# Decode context parallel group; `None` until `initialize_model_parallel`
# runs.
_DCP: GroupCoordinator | None = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode context parallel group; it must be initialized."""
    dcp_group = _DCP
    assert dcp_group is not None, "decode context model parallel group is not initialized"
    return dcp_group


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

1098
# Pipeline model parallel group; `None` until `initialize_model_parallel`
# runs.
_PP: GroupCoordinator | None = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group; it must be initialized."""
    pp_group = _PP
    assert pp_group is not None, "pipeline model parallel group is not initialized"
    return pp_group


1106
# Data parallel group; `None` until `initialize_model_parallel` runs.
_DP: GroupCoordinator | None = None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group; it must be initialized."""
    dp_group = _DP
    assert dp_group is not None, "data parallel group is not initialized"
    return dp_group

1113

1114
# Expert parallel group; stays `None` when `initialize_model_parallel`
# decides no EP group is needed.
_EP: GroupCoordinator | None = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel group; it must be initialized."""
    ep_group = _EP
    assert ep_group is not None, (
        "expert parallel group is not initialized. "
        "EP group is only created for MoE models with num_experts > 0. "
        "This function should only be called for MoE models."
    )
    return ep_group


1126
1127
1128
1129
1130
1131
# Prefill context parallel group; `None` until `initialize_model_parallel`
# runs.
_PCP: GroupCoordinator | None = None


def get_pcp_group() -> GroupCoordinator:
    """Return the prefill context parallel group; it must be initialized."""
    pcp_group = _PCP
    assert pcp_group is not None, "prefill context parallel group is not initialized"
    return pcp_group
1132
1133


1134
@contextmanager
def graph_capture(device: torch.device):
    """
    Context manager that should wrap the code capturing a CUDA graph.

    Its main purpose is to ensure that some operations will be run after the
    graph is captured, before the graph is replayed. It yields a
    `GraphCaptureContext` holding the data needed for the capture; currently
    that is only the stream the capture runs on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. Capturing on a stream
    separate from the default one explicitly distinguishes the kernels to
    capture from other kernels possibly launched on background in the default
    stream.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    # Enter the TP capture context first, then the PP one.
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1153

1154
# Module-level logger, created by `init_logger` for this module's name.
logger = init_logger(__name__)
1155

1156
# Module-level switch for custom all-reduce; enabled by default.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-level custom all-reduce switch."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1162

Zhuohan Li's avatar
Zhuohan Li committed
1163

1164
1165
1166
1167
1168
1169
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: timedelta | None = None,
):
    """Initialize torch.distributed (if needed) and the world group.

    Arguments:
        world_size: number of processes in the job. Replaced by the config's
            `world_size_across_dp` when the rank/world-size adjustment below
            applies.
        rank: rank of this process. Offset by the data parallel rank when the
            adjustment below applies.
        distributed_init_method: init method for the process group. Rebuilt
            from the config's address and port when the adjustment applies.
        local_rank: local rank of this process. `-1` means "derive it":
            `envs.LOCAL_RANK` for the "env://" init method, `rank` otherwise.
        backend: torch distributed backend. Falls back to "gloo" when the
            requested backend is not available.
        timeout: passed through to `torch.distributed.init_process_group`.

    The rank/world-size adjustment applies when a vLLM config is available,
    its executor backend is not "external_launcher", and it spans more than
    one node or more than one data parallel rank.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    if (
        config is not None
        and config.parallel_config.distributed_executor_backend != "external_launcher"
        and (
            config.parallel_config.nnodes > 1
            or config.parallel_config.data_parallel_size > 1
        )
    ):
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp

        # Use appropriate IP and port based on configuration
        if parallel_config.nnodes > 1:
            ip = parallel_config.master_addr
            port = parallel_config.master_port
            distributed_init_method = get_distributed_init_method(ip, port)
        else:
            ip = parallel_config.data_parallel_master_ip
            port = parallel_config.get_next_dp_init_port()
            distributed_init_method = get_distributed_init_method(ip, port)
            logger.debug(
                "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
                world_size,
                rank,
                distributed_init_method,
            )
    if not torch.distributed.is_initialized():
        logger.info(
            "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
            world_size,
            rank,
            local_rank,
            distributed_init_method,
            backend,
        )
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank
    global _WORLD, _NODE_COUNT, _INNER_DP_WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        if config is not None and config.parallel_config.nnodes > 1:
            _NODE_COUNT = config.parallel_config.nnodes
        else:
            _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
    if config is not None and config.parallel_config.nnodes_within_dp > 1:
        # Bind the parallel config here as well: the adjustment branch above
        # is skipped for the "external_launcher" backend, which would leave
        # the name unbound at this point.
        parallel_config = config.parallel_config
        if parallel_config.data_parallel_size > 1:
            world_size_inner_dp = parallel_config.world_size
            group_ranks = [
                [dp_rank * world_size_inner_dp + i for i in range(world_size_inner_dp)]
                for dp_rank in range(parallel_config.data_parallel_size)
            ]
            _INNER_DP_WORLD = init_model_parallel_group(
                group_ranks,
                get_world_group().local_rank,
                backend,
                use_message_queue_broadcaster=True,
                group_name="inner_dp_world",
                use_device_communicator=False,
            )
        else:
            _INNER_DP_WORLD = _WORLD
1280
1281


Zhuohan Li's avatar
Zhuohan Li committed
1282
1283
1284
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        prefill_context_model_parallel_size: number of GPUs used for prefill
            context parallelism.
        decode_context_model_parallel_size: number of GPUs used for decode
            context parallelism; `None` is treated as 1.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    # The signature allows `None` for the DCP size; fall back to the default
    # of 1 so the reshape below does not fail on it.
    if decode_context_model_parallel_size is None:
        decode_context_model_parallel_size = 1

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config_or_none

    config = get_current_vllm_config_or_none()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x PCP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1,
        data_parallel_size,
        pipeline_model_parallel_size,
        prefill_context_model_parallel_size,
        tensor_model_parallel_size,
    )  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # the tensor model parallel group uses the message queue broadcaster
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    global _DCP
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the prefill context parallel groups.
    global _PCP
    assert _PCP is None, "prefill context parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(3, 4)
        .reshape(-1, prefill_context_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PCP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pcp"
    )

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 4).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    # Build the data parallel groups.
    global _DP
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 4).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp"
    )

    global _EP
    assert _EP is None, "expert parallel group is already initialized"
    # Don't create EP group for dense models.
    if config is None or config.model_config is None or config.model_config.is_moe:
        group_ranks = (
            all_ranks.transpose(1, 2)
            .reshape(
                -1,
                data_parallel_size
                * prefill_context_model_parallel_size
                * tensor_model_parallel_size,
            )
            .unbind(0)
        )
        group_ranks = [x.tolist() for x in group_ranks]
        _EP = init_model_parallel_group(
            group_ranks, get_world_group().local_rank, backend, group_name="ep"
        )
    # If no EP group needed, _EP remains None

    logger.info_once(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, PCP rank %s, "
        "TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _PCP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group if _EP is not None else "N/A",
    )
1437

Zhuohan Li's avatar
Zhuohan Li committed
1438

1439
1440
1441
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    prefill_context_model_parallel_size: int = 1,
    decode_context_model_parallel_size: int | None = 1,
    backend: str | None = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    When the groups do not exist yet they are created with the given sizes.
    When they already exist, the tensor-parallel, pipeline-parallel and
    prefill-context-parallel world sizes are asserted to match the
    requested values.
    """
    if not backend:
        backend = torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        # Groups exist already: only check that their sizes are as expected.
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        pcp_world_size = get_pcp_group().world_size
        assert pcp_world_size == prefill_context_model_parallel_size, (
            "prefill context parallel group already initialized, but of unexpected size: "
            f"{pcp_world_size=} vs. "
            f"{prefill_context_model_parallel_size=}"
        )
    else:
        initialize_model_parallel(
            tensor_model_parallel_size,
            pipeline_model_parallel_size,
            prefill_context_model_parallel_size,
            decode_context_model_parallel_size,
            backend,
        )
1478
1479


1480
1481
1482
1483
1484
1485
1486
1487
1488
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost
    model agnostic. However, emerging new communication libraries like
    MoE all2all (DeepEP) usually allocate the communication buffer
    based on the model shape for optimal performance.

    The model is handed to every initialized group among TP, PCP, PP, DP
    and EP, in that order.
    """
    for parallel_group in (_TP, _PCP, _PP, _DP, _EP):
        if parallel_group is not None:
            parallel_group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1499
def model_parallel_is_initialized():
    """Return True once both the tensor and pipeline parallel groups exist."""
    return not (_TP is None or _PP is None)
1502
1503


1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
# True while `patch_tensor_parallel_group` has the TP group swapped out.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group before flipping the flag: `get_tp_group`
    # asserts when no TP group exists, and failing after the flag was set
    # would leave it stuck at True and block every later patch attempt.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1532
1533
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1535
1536
1537
1538


def get_tensor_model_parallel_rank():
    """Return this process's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1540
1541


1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
def get_decode_context_model_parallel_world_size():
    """Return the number of ranks in the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank():
    """Return this process's rank inside the decode context model parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1552
def get_node_count() -> int:
    """Return the total number of nodes in the distributed environment.

    Raises:
        AssertionError: if the module-level node count has not been set,
            i.e. the distributed environment is not initialized.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1558
def destroy_model_parallel():
    """Destroy every model parallel group and reset its global to None.

    The groups are torn down in the order TP, DCP, PCP, PP, DP, EP. Each
    module-level handle is set to None whether or not a group existed.
    """
    global _TP, _DCP, _PCP, _PP, _DP, _EP

    def _teardown(group) -> None:
        # Only a truthy (i.e. existing) group has anything to destroy.
        if group:
            group.destroy()

    _teardown(_TP)
    _TP = None

    _teardown(_DCP)
    _DCP = None

    _teardown(_PCP)
    _PCP = None

    _teardown(_PP)
    _PP = None

    _teardown(_DP)
    _DP = None

    _teardown(_EP)
    _EP = None

1591
1592

def destroy_distributed_environment():
    """Tear down the world group and the torch.distributed process group.

    Destroys the module-level world group if one exists, clears the cached
    world group and node count, and finally destroys the default
    torch.distributed process group when it is initialized.
    """
    global _WORLD, _NODE_COUNT

    world_group = _WORLD
    if world_group:
        world_group.destroy()
    _WORLD = None
    _NODE_COUNT = None

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1600
1601


1602
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Destroy all distributed state and release cached memory.

    Steps, in order: reset the environment variable cache, unfreeze the
    garbage collector, destroy the model parallel groups and the distributed
    environment, optionally shut down Ray, run a garbage collection, and
    finally empty the platform and host memory caches where available.

    Args:
        shutdown_ray: When True, also call ``ray.shutdown()``.
    """
    # Reset environment variable cache
    envs.disable_envs_cache()
    # Ensure all objects are not frozen before cleanup
    gc.unfreeze()

    destroy_model_parallel()
    destroy_distributed_environment()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()

    from vllm.platforms import current_platform

    # The platform may expose no cache-emptying hook at all (None).
    if (empty_cache := current_platform.empty_cache) is not None:
        empty_cache()

    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
1625
1626


1627
def in_the_same_node_as(
    pg: ProcessGroup | StatelessProcessGroup, source_rank: int = 0
) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    The source rank creates a small shared memory segment stamped with a magic
    message and broadcasts the segment name. Every other rank tries to attach
    to that name and looks for the stamp. The per-rank results are then
    summed across the group.

    Args:
        pg: The process group to test. A torch ``ProcessGroup`` must not use
            the NCCL backend.
        source_rank: Rank inside ``pg`` that creates the shared memory segment.

    Returns:
        One bool per rank in the group; True where that rank could see the
        source rank's shared memory segment.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor(
        [0] * world_size, dtype=torch.int32, device="cpu"
    )

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            # create a shared memory segment; a failure here is tolerated
            with contextlib.suppress(OSError):
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
            shm_name = shm.name if shm is not None else None

            # Always take part in the broadcast, even when the segment could
            # not be created. The other ranks are waiting in the matching
            # broadcast; skipping it would leave the group's collectives
            # mismatched with the barrier below.
            if isinstance(pg, ProcessGroup):
                torch.distributed.broadcast_object_list(
                    [shm_name], src=ranks[source_rank], group=pg
                )
            else:
                pg.broadcast_obj(shm_name, src=source_rank)

            if shm is not None:
                is_in_the_same_node[rank] = 1
        else:
            # receive the name of the shared memory segment
            if isinstance(pg, ProcessGroup):
                recv = [None]
                torch.distributed.broadcast_object_list(
                    recv, src=ranks[source_rank], group=pg
                )
                name = recv[0]
            else:
                name = pg.broadcast_obj(None, src=source_rank)

            # A None name means the source rank had no segment to share.
            if name is not None:
                # Attaching raises an OSError subclass (e.g.
                # FileNotFoundError) when the segment is not visible from
                # this process; that simply means "not the same node".
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is
                    # not created by the process. The following patch is a
                    # workaround.
                    with patch(
                        "multiprocessing.resource_tracker.register",
                        lambda *args, **kwargs: None,
                    ):
                        shm = shared_memory.SharedMemory(name=name)
                    if shm.buf[: len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in in_the_same_node_as: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # every rank must be done with the segment before it is unlinked
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    # Each rank only filled in its own slot; summing across the group yields
    # the complete per-rank result on every rank.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
1717
1718


1719
1720
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process)
              or if the check itself raises.
    """
    try:
        # The world group, when present, gives the most accurate answer.
        world = _WORLD
        if world is not None:
            return world.is_first_rank

        # Otherwise ask torch for the global rank, if it is initialized.
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0

        # No distributed setup at all: treat as a single process.
        return True
    except Exception:
        # If anything goes wrong, assume this is the first rank
        return True


1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
def is_local_first_rank() -> bool:
    """
    Check if the current process is the first local rank (rank 0 on its node).

    Returns:
        bool: True if the local rank is 0. Also True when distributed is not
              initialized, or when the check itself raises.
    """
    try:
        # prefer the initialized world group if available
        world = _WORLD
        if world is not None:
            return world.local_rank == 0

        if not torch.distributed.is_initialized():
            return True

        # Use the environment-provided local rank if it can be read.
        # note: envs.LOCAL_RANK is set when using env:// launchers (e.g., torchrun)
        with contextlib.suppress(Exception):
            return int(envs.LOCAL_RANK) == 0  # type: ignore[arg-type]

        # Reading the local rank failed; fall back to torch's global rank.
        return torch.distributed.get_rank() == 0
    except Exception:
        return True


1773
def _node_count(pg: ProcessGroup | StatelessProcessGroup) -> int:
    """
    Count the distinct nodes that host the ranks of a process group.

    Ranks are grouped by repeatedly picking the lowest rank not yet assigned
    to a node and calling `in_the_same_node_as` with it as the source rank.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    if isinstance(pg, ProcessGroup):
        world_size = torch.distributed.get_world_size(group=pg)
    else:
        world_size = pg.world_size

    # A single rank is trivially a single node; no collective needed.
    if world_size == 1:
        return 1

    # node_ids[rank] is the 1-based id of the rank's node; 0 means "unassigned".
    node_ids = [0] * world_size
    num_nodes = 0

    for leader in range(world_size):
        if node_ids[leader]:
            # Already placed on a node discovered earlier.
            continue

        # This rank starts a new node.
        num_nodes += 1
        node_ids[leader] = num_nodes

        # Every still-unassigned rank co-located with the leader joins it.
        for peer, colocated in enumerate(in_the_same_node_as(pg, leader)):
            if colocated and not node_ids[peer]:
                node_ids[peer] = num_nodes

    return num_nodes