parallel_state.py 56.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
9
10
11
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
13
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
14
15
16
17
18
19
20
21
22
23
24
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25

26
import contextlib
27
import gc
28
import pickle
29
import weakref
30
31
32
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
33
from datetime import timedelta
34
from multiprocessing import shared_memory
35
from typing import Any, Callable, Optional, Union
36
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
37
38

import torch
39
import torch.distributed
40
from torch.distributed import Backend, ProcessGroup
41
from typing_extensions import deprecated
Zhuohan Li's avatar
Zhuohan Li committed
42

43
import vllm.envs as envs
44
from vllm.distributed.device_communicators.base_device_communicator import (
45
46
    DeviceCommunicatorBase,
)
47
from vllm.distributed.utils import StatelessProcessGroup
48
from vllm.logger import init_logger
49
50
51
52
53
54
from vllm.utils import (
    direct_register_custom_op,
    get_distributed_init_method,
    resolve_obj_by_qualname,
    supports_custom_op,
)
55
56


57
58
59
@dataclass
class GraphCaptureContext:
    """Bundle of state shared by the participants of one graph capture."""

    # Stream that `GroupCoordinator.graph_capture` makes current while the
    # body of its `with` block runs.
    stream: torch.cuda.Stream
60

61

62
# Lightweight stand-in for a tensor inside a metadata list: it records what a
# receiver needs to allocate a matching buffer (device type, dtype, size).
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
63

64
65

def _split_tensor_dict(
    tensor_dict: dict[str, Union[torch.Tensor, Any]],
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Separate the tensors of `tensor_dict` from everything else.

    Returns a pair `(metadata_list, tensor_list)`:
    1. `metadata_list` holds one `(key, value)` pair per entry, in dictionary
         order. Tensor values are replaced by a `TensorMetadata`.
    2. `tensor_list` holds the tensors themselves, in the same order.
    """

    def _describe(value: Any) -> Any:
        if not isinstance(value, torch.Tensor):
            return value
        # Only the device *type* is recorded, not `value.device`: the latter
        # also carries the device index (e.g. "cuda:0"), and the receiving
        # side is the one that decides the index.
        return TensorMetadata(value.device.type, value.dtype, value.size())

    metadata_list: list[tuple[str, Any]] = [
        (key, _describe(value)) for key, value in tensor_dict.items()
    ]
    tensor_list: list[torch.Tensor] = [
        value for value in tensor_dict.values() if isinstance(value, torch.Tensor)
    ]
    return metadata_list, tensor_list


91
# Number of names already handed out for each base name.
_group_name_counter: dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return `name` suffixed with a per-name running index.

    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.get(name, 0)
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


107
# Registry mapping a group's unique name to a weak reference to it. Calling
# the stored value yields the group, or None once it has been collected.
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Record `group` in `_groups` under its `unique_name`.

    Only a weak reference is stored, so the registry does not keep the
    group alive.
    """
    handle = weakref.ref(group)
    _groups[group.unique_name] = handle
112
113


114
115
116
117
118
119
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce on the group registered as `group_name`.

    The group is looked up by name in `_groups`. Raises ValueError when the
    name is registered but the group object no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")
120
121


122
123
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation of `all_reduce`.

    No communication happens; the result is an uninitialized tensor created
    with `torch.empty_like(tensor)`. `group_name` is unused.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
124

125

126
127
128
def reduce_scatter(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place reduce-scatter on the group registered as `group_name`.

    The group is looked up by name in `_groups`. `world_size` is not used
    here; only `tensor` and `dim` are forwarded to the group. Raises
    ValueError when the name is registered but the group no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._reduce_scatter_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
134
135


136
137
138
def reduce_scatter_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation of `reduce_scatter`.

    Returns an uninitialized tensor with the dtype and device of `tensor`
    whose extent along `dim` is floor-divided by `world_size`.
    `group_name` is unused.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


144
145
146
def all_gather(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Out-of-place all-gather on the group registered as `group_name`.

    The group is looked up by name in `_groups`. `world_size` is not used
    here; only `tensor` and `dim` are forwarded to the group. Raises
    ValueError when the name is registered but the group no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_gather_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
152
153


154
155
156
def all_gather_fake(
    tensor: torch.Tensor, dim: int, world_size: int, group_name: str
) -> torch.Tensor:
    """Fake implementation of `all_gather`.

    Returns an uninitialized tensor with the dtype and device of `tensor`
    whose extent along `dim` is multiplied by `world_size`.
    `group_name` is unused.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


162
if supports_custom_op():
    # Register every collective together with its fake implementation. The
    # registered ops are the ones `GroupCoordinator` reaches through
    # `torch.ops.vllm.<op_name>`.
    for _op_name, _op_func, _fake_impl in (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            fake_impl=_fake_impl,
        )
    # Do not leave the loop variables behind as module attributes.
    del _op_name, _op_func, _fake_impl

181

182
183
184
185
186
187
class GroupCoordinator:
    """
    Wrapper around PyTorch ProcessGroups for one group of processes.

    A PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group, and manages both CPU and device
        communication.
    """

    # Attributes available on every instance:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # How `local_rank` differs from `rank_in_group`, shown for a group of
    # size 4 that spans two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    # device communicator (if use_device_communicator=True)
    device_communicator: Optional[DeviceCommunicatorBase]
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
210
211
212

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,  # whether to use device communicator
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups and bind this process to its own group.

        Args:
            group_ranks: one list of global ranks per group. The calling
                process must appear in at least one of them; if it appears
                in several, the last one wins.
            local_rank: index used to pick this process's device.
            torch_distributed_backend: backend of the device process group.
            use_device_communicator: build the platform's device
                communicator (only done when the group has more than one
                member).
            use_message_queue_broadcaster: build a `MessageQueue` on the CPU
                group for object broadcasts (only done when the group has
                more than one member).
            group_name: base name for the group; "anonymous" when omitted.
                A running index is appended to make it unique.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank

        own_device_group = None
        own_cpu_group = None

        for ranks in group_ranks:
            # Both groups are created for every rank list, including lists
            # this process is not a member of; the creation order (device
            # group first, then CPU group) is kept fixed.
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend
            )
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank not in ranks:
                continue
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            own_device_group = device_group
            own_cpu_group = cpu_group

        assert own_cpu_group is not None
        assert own_device_group is not None

        self.cpu_group = own_cpu_group
        self.device_group = own_device_group

        # Imported here rather than at module level; one import serves the
        # whole method.
        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            device_spec = f"cuda:{local_rank}"
        elif current_platform.is_xpu():
            device_spec = f"xpu:{local_rank}"
        elif current_platform.is_out_of_tree():
            device_spec = f"{current_platform.device_name}:{local_rank}"
        else:
            device_spec = "cpu"
        self.device = torch.device(device_spec)

        self.use_device_communicator = use_device_communicator
        self.device_communicator = None
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls()
            )
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import MessageQueue

        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6
            )

        # Whether collectives go through the registered `torch.ops.vllm.*`
        # custom ops instead of calling the out-of-place methods directly.
        self.use_custom_op_call = (
            current_platform.is_cuda_alike() or current_platform.is_tpu()
        )

        # Whether send/recv of tensor dicts is handed to the device
        # communicator (CPU platform with `init_shm_manager` available).
        self.use_cpu_custom_send_recv = current_platform.is_cpu() and hasattr(
            torch.ops._C, "init_shm_manager"
        )
291

292
293
294
295
296
297
298
299
300
301
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

302
303
304
305
306
307
308
309
310
311
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
        self, graph_capture_context: Optional[GraphCaptureContext] = None
    ):
        """Context manager that runs its body on the capture stream.

        When no context is passed, a new `torch.cuda.Stream` is created and
        wrapped in a `GraphCaptureContext`. The (possibly new) context is
        yielded to the caller.
        """
        if graph_capture_context is None:
            graph_capture_context = GraphCaptureContext(torch.cuda.Stream())
        capture_stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator,
        )

        ca_context = nullcontext()
        communicator = self.device_communicator
        if communicator is not None:
            assert isinstance(communicator, CudaCommunicator)
            ca_comm = communicator.ca_comm
            if ca_comm is not None:
                # Enter the communicator's own capture context as well.
                ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        active_stream = torch.cuda.current_stream()
        if active_stream != capture_stream:
            capture_stream.wait_stream(active_stream)

        with torch.cuda.stream(capture_stream), ca_context:
            yield graph_capture_context
357
358
359

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
360
361
362
363
364
365
366
367
368
369
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
370
371
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
372
373
374
375
376
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

377
        if self.use_custom_op_call:
378
            return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
379
380
        else:
            return self._all_reduce_out_place(input_)
381

382
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
383
384
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
385
        return self.device_communicator.all_reduce(input_)
386
387
388
389
390
391
392

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
393
394
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
395

396
        if self.use_custom_op_call:
397
398
399
            return torch.ops.vllm.all_gather(
                input_, dim, world_size, group_name=self.unique_name
            )
400
401
402
        else:
            return self._all_gather_out_place(input_, dim)

403
    def _all_gather_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
404
405
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
406
        return self.device_communicator.all_gather(input_, dim)
407

408
409
410
411
412
413
    def all_gatherv(
        self,
        input_: Union[torch.Tensor, list[torch.Tensor]],
        dim: int = 0,
        sizes: Optional[list[int]] = None,
    ):
414
415
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
416
417
        return self.device_communicator.all_gatherv(input_, dim, sizes)

418
    def reduce_scatter(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
419
420
421
422
423
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
424
425
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
        )
426

427
        if self.use_custom_op_call:
428
429
430
            return torch.ops.vllm.reduce_scatter(
                input_, dim, world_size, group_name=self.unique_name
            )
431
432
433
        else:
            return self._reduce_scatter_out_place(input_, dim)

434
435
436
    def reduce_scatterv(
        self, input_: torch.Tensor, dim: int = -1, sizes: Optional[list[int]] = None
    ) -> torch.Tensor:
437
438
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
439
440
        return self.device_communicator.reduce_scatterv(input_, dim, sizes)

441
    def _reduce_scatter_out_place(self, input_: torch.Tensor, dim: int) -> torch.Tensor:
442
443
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
444
445
        return self.device_communicator.reduce_scatter(input_, dim)

446
447
448
    def gather(
        self, input_: torch.Tensor, dst: int = 0, dim: int = -1
    ) -> Optional[torch.Tensor]:
449
450
451
452
453
454
455
456
457
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
458
459
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
460
        return self.device_communicator.gather(input_, dst, dim)
461
462
463
464
465
466
467
468
469
470
471

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
472
473
474
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
475
476
        return input_

477
478
479
480
481
482
483
484
485
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
486
487
488
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
489
        if self.rank_in_group == src:
490
491
492
            torch.distributed.broadcast_object_list(
                [obj], src=self.ranks[src], group=self.cpu_group
            )
493
494
495
            return obj
        else:
            recv = [None]
496
497
498
            torch.distributed.broadcast_object_list(
                recv, src=self.ranks[src], group=self.cpu_group
            )
499
500
            return recv[0]

501
502
503
    def broadcast_object_list(
        self, obj_list: list[Any], src: int = 0, group: Optional[ProcessGroup] = None
    ):
504
505
506
507
508
509
510
511
512
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
513
514
515
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
516
517
        return obj_list

518
519
520
521
522
523
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

524
        assert dst != self.rank_in_group, (
525
            "Invalid destination rank. Destination rank is the same "
526
527
            "as the current rank."
        )
528
529
530
531

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

532
533
534
        size_tensor = torch.tensor(
            [object_tensor.numel()], dtype=torch.long, device="cpu"
        )
535
536
537

        # Send object size

538
        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)
539
540

        # Send object
541
        torch.distributed.send(object_tensor, dst=self.ranks[dst], group=self.cpu_group)
542
543
544
545
546
547
548
549
550

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

551
        assert src != self.rank_in_group, (
552
553
554
555
556
557
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
558
559
560
        rank_size = torch.distributed.recv(
            size_tensor, src=self.ranks[src], group=self.cpu_group
        )
561
562
563
564
565

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
566
567
            device="cpu",
        )
568

569
570
571
        rank_object = torch.distributed.recv(
            object_tensor, src=self.ranks[src], group=self.cpu_group
        )
572
573

        assert rank_object == rank_size, (
574
575
            "Received object sender rank does not match the size sender rank."
        )
576
577
578
579
580

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

581
582
    def broadcast_tensor_dict(
        self,
583
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
584
585
        src: int = 0,
        group: Optional[ProcessGroup] = None,
586
        metadata_group: Optional[ProcessGroup] = None,
587
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
588
589
590
591
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
592
        if not torch.distributed.is_initialized() or self.world_size == 1:
593
594
595
596
597
598
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

599
600
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
601
            metadata_list: list[tuple[Any, Any]] = []
602
603
604
            assert isinstance(tensor_dict, dict), (
                f"Expecting a dictionary, got {type(tensor_dict)}"
            )
605
606
607
608
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
609
            self.broadcast_object(metadata_list, src=src)
610
611
612
613
614
615
616
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
617
618
619
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=metadata_group, async_op=True
                    )
620
621
                else:
                    # use group for GPU tensors
622
623
624
                    handle = torch.distributed.broadcast(
                        tensor, src=self.ranks[src], group=group, async_op=True
                    )
625
626
627
628
629
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
630
            metadata_list = self.broadcast_object(None, src=src)
631
632
            tensor_dict = {}
            async_handles = []
633
            for key, value in metadata_list:
634
                if isinstance(value, TensorMetadata):
635
636
637
                    tensor = torch.empty(
                        value.size, dtype=value.dtype, device=value.device
                    )
638
639
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
640
                        tensor_dict[key] = tensor
641
642
643
644
645
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
646
                            src=self.ranks[src],
647
                            group=metadata_group,
648
649
                            async_op=True,
                        )
650
651
                    else:
                        # use group for GPU tensors
652
                        handle = torch.distributed.broadcast(
653
654
                            tensor, src=self.ranks[src], group=group, async_op=True
                        )
655
                    async_handles.append(handle)
656
                    tensor_dict[key] = tensor
657
                else:
658
                    tensor_dict[key] = value
659
660
661
662
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

663
664
    def send_tensor_dict(
        self,
665
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
666
667
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
668
        all_gather_tensors: Optional[dict[str, bool]] = None,
669
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
670
671
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686

        all_gather_group: The group for the all-gather operation. If provided,
            an optimization is enabled where each rank in the group sends a
            slice of a tensor and the receiver reconstructs it using an
            all-gather, which can improve performance. This is typically the
            tensor-parallel group.
        all_gather_tensors: A dictionary to specify which tensors should use
            the all-gather optimization, which is only effective when
            `all_gather_group` is provided. By default, this optimization is
            on for any tensor whose size is divisible by the
            `all_gather_group`'s world size. However, it should be disabled
            for tensors that are not fully replicated across the group (e.g.,
            the residual tensor when sequence parallelism is enabled). This
            dictionary allows overriding the default behavior on a per-tensor
            basis.
687
688
689
690
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict
691
692
693
694
        all_gather_size = 1 if all_gather_group is None else all_gather_group.world_size
        all_gather_rank = (
            0 if all_gather_group is None else all_gather_group.rank_in_group
        )
695

696
697
698
699
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
700
            dst = (self.rank_in_group + 1) % self.world_size
701
702
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

703
        if self.use_cpu_custom_send_recv:
704
705
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
706
            self.device_communicator.send_tensor_dict(  # type: ignore
707
708
                tensor_dict, dst
            )
709
710
            return None

711
        metadata_list: list[tuple[Any, Any]] = []
712
713
714
        assert isinstance(tensor_dict, dict), (
            f"Expecting a dictionary, got {type(tensor_dict)}"
        )
715
716
717
718
719
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
720

721
        tensor_keys = [k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)]
722
723
724
        assert len(tensor_keys) == len(tensor_list)

        for key, tensor in zip(tensor_keys, tensor_list):
725
726
727
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
728
729

            # send-allgather: send only a slice, then do allgather.
730
731
732
733
734
735
736
737
            use_all_gather = (
                all_gather_group is not None and tensor.numel() % all_gather_size == 0
            )
            use_all_gather = (
                all_gather_tensors.get(key, use_all_gather)
                if all_gather_tensors
                else use_all_gather
            )
738
            if use_all_gather:
739
740
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

741
742
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
743
744
745
                torch.distributed.send(
                    tensor, dst=self.ranks[dst], group=metadata_group
                )
746
747
            else:
                # use group for GPU tensors
748
                torch.distributed.send(tensor, dst=self.ranks[dst], group=group)
749
750
751
752
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
        all_gather_tensors: Optional[dict[str, bool]] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Receive a tensor dictionary from another rank of this group.

        NOTE: `src` is the local rank of the source rank.

        Args:
            src: Rank (within this group) to receive from. When omitted, the
                previous rank in the group is used, wrapping around.
            all_gather_group: The group for the all-gather operation. If
                provided, an optimization is enabled where each rank in the
                group receives only a slice of a tensor and the full tensor
                is rebuilt with an all-gather, which can improve
                performance. This is typically the tensor-parallel group.
            all_gather_tensors: Per-tensor override of the all-gather
                optimization, only effective when `all_gather_group` is
                provided. By default the optimization is on for any tensor
                whose element count is divisible by the group's world size.
                It should be disabled for tensors that are not fully
                replicated across the group (e.g., the residual tensor when
                sequence parallelism is enabled).

        Returns:
            `None` when torch.distributed is not initialized or this group
            has a single rank; otherwise the received dictionary.
        """
        # Nothing to receive when only one GPU is in use.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        if all_gather_group is None:
            all_gather_size, all_gather_rank = 1, 0
        else:
            all_gather_size = all_gather_group.world_size
            all_gather_rank = all_gather_group.rank_in_group

        device_pg = self.device_group
        cpu_pg = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        if self.use_cpu_custom_send_recv:
            if self.device_communicator is None:
                raise ValueError("No device communicator found")
            return self.device_communicator.recv_tensor_dict(  # type: ignore
                src
            )

        received: dict[str, Any] = {}
        for key, value in self.recv_object(src=src):
            if not isinstance(value, TensorMetadata):
                # Plain Python objects travel inside the metadata itself.
                received[key] = value
                continue

            tensor = torch.empty(value.size, dtype=value.dtype, device=value.device)
            if tensor.numel() == 0:
                # Empty tensors are never transmitted.
                received[key] = tensor
                continue

            # send-allgather: receive only a slice, then do allgather.
            gather = (
                all_gather_group is not None
                and tensor.numel() % all_gather_size == 0
            )
            if all_gather_tensors:
                gather = all_gather_tensors.get(key, gather)

            if gather:
                full_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            # CPU tensors go over the CPU group, everything else over the
            # device group.
            transport = cpu_pg if tensor.is_cpu else device_pg
            torch.distributed.recv(tensor, src=self.ranks[src], group=transport)

            if gather:
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0
                )
                tensor = tensor.reshape(full_shape)

            received[key] = tensor
        return received

842
843
844
845
846
847
848
849
850
    def barrier(self):
        """Synchronize all ranks of this group on a barrier.

        NOTE: the CPU group is used on purpose. A NCCL `barrier` on
        `device_group` is internally a broadcast on secretly created GPU
        tensors and can easily mess up the current device.
        """
        cpu_pg = self.cpu_group
        torch.distributed.barrier(group=cpu_pg)

851
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank in a blocking way.

        NOTE: `dst` is the local rank of the destination rank.

        Args:
            tensor: The tensor to send.
            dst: Destination rank within this group; passed through
                unchanged to the device communicator.

        Raises:
            ValueError: If this coordinator has no device communicator.
        """
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        self.device_communicator.send(tensor, dst)
857

858
859
860
    def recv(
        self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
    ) -> torch.Tensor:
        """Receive a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank.

        Args:
            size: Shape of the tensor to receive.
            dtype: Data type of the tensor to receive.
            src: Source rank within this group; passed through unchanged to
                the device communicator.

        Returns:
            The tensor produced by the device communicator.

        Raises:
            ValueError: If this coordinator has no device communicator.
        """
        if self.device_communicator is None:
            raise ValueError("No device communicator found")
        return self.device_communicator.recv(size, dtype, src)
866

867
    def destroy(self):
        """Release the process groups and communicators held by this object.

        Each process group attribute that is still present is destroyed and
        then removed; the device communicator (if any) is destroyed; the
        message-queue broadcaster reference is dropped.
        """
        for attr in ("device_group", "cpu_group"):
            if hasattr(self, attr):
                torch.distributed.destroy_process_group(getattr(self, attr))
                delattr(self, attr)

        communicator = self.device_communicator
        if communicator is not None:
            communicator.destroy()

        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
878

879
880
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward buffer preparation for `model` to the device communicator.

        Does nothing when this coordinator has no device communicator.
        """
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)
882
883

    def dispatch(
        self,
        hidden_states: torch.Tensor,
        router_logits: torch.Tensor,
        is_sequence_parallel: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Dispatch hidden states and router logits via the communicator.

        Without a device communicator the two inputs are returned unchanged.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states, router_logits
        return communicator.dispatch(
            hidden_states, router_logits, is_sequence_parallel
        )
895

896
897
898
    def combine(
        self, hidden_states, is_sequence_parallel: bool = False
    ) -> torch.Tensor:
        """Combine hidden states via the device communicator.

        Without a device communicator `hidden_states` is returned unchanged.
        """
        communicator = self.device_communicator
        if communicator is None:
            return hidden_states
        return communicator.combine(hidden_states, is_sequence_parallel)
903

904
905

_WORLD: Optional[GroupCoordinator] = None
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator.

    Asserts that the world group has been initialized.
    """
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world


914
915
916
def init_world_group(
    ranks: list[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Create the coordinator for the world group.

    All `ranks` form a single group named "world", created without a device
    communicator.
    """
    world_kwargs: dict[str, Any] = {
        "group_ranks": [ranks],
        "local_rank": local_rank,
        "torch_distributed_backend": backend,
        "use_device_communicator": False,
        "group_name": "world",
    }
    return GroupCoordinator(**world_kwargs)


926
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a coordinator for one kind of model-parallel group.

    The coordinator is always created with a device communicator; the
    message-queue broadcaster is enabled only when requested.
    """
    coordinator_kwargs: dict[str, Any] = {
        "group_ranks": group_ranks,
        "local_rank": local_rank,
        "torch_distributed_backend": backend,
        "use_device_communicator": True,
        "use_message_queue_broadcaster": use_message_queue_broadcaster,
        "group_name": group_name,
    }
    return GroupCoordinator(**coordinator_kwargs)


943
944
945
946
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-model-parallel group coordinator.

    Asserts that the group has been initialized.
    """
    tp_group = _TP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group


951
952
953
954
955
@deprecated(
    "`get_tensor_model_parallel_group` has been replaced with "
    "`get_tp_group` and may be removed after v0.12. Please use "
    "`get_tp_group` instead."
)
def get_tensor_model_parallel_group():
    """Deprecated alias that returns the result of `get_tp_group()`."""
    tp_group = get_tp_group()
    return tp_group

959

960
961
962
963
_DCP: Optional[GroupCoordinator] = None


def get_dcp_group() -> GroupCoordinator:
    """Return the decode-context-parallel group coordinator.

    Asserts that the group has been initialized.
    """
    dcp_group = _DCP
    assert dcp_group is not None, (
        "decode context model parallel group is not initialized"
    )
    return dcp_group


# kept for backward compatibility
get_context_model_parallel_group = get_dcp_group

971
972
_PP: Optional[GroupCoordinator] = None

_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group coordinator.

    Asserts that the group has been initialized.
    """
    dp_group = _DP
    assert dp_group is not None, "data parallel group is not initialized"
    return dp_group

980

981
982
983
984
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group coordinator.

    Asserts that the group has been initialized.
    """
    ep_group = _EP
    assert ep_group is not None, "expert parallel group is not initialized"
    return ep_group


989
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-model-parallel group coordinator.

    Asserts that the group has been initialized.
    """
    pp_group = _PP
    assert pp_group is not None, "pipeline model parallel group is not initialized"
    return pp_group
992
993


994
995
996
997
998
@deprecated(
    "`get_pipeline_model_parallel_group` has been replaced with "
    "`get_pp_group` and may be removed in v0.12. Please use "
    "`get_pp_group` instead."
)
def get_pipeline_model_parallel_group():
    """Deprecated alias that returns the result of `get_pp_group()`."""
    pp_group = get_pp_group()
    return pp_group
1001
1002


1003
@contextmanager
def graph_capture(device: torch.device):
    """Context manager that should wrap the code capturing a CUDA graph.

    Its main purpose is to ensure that some operations run after the graph
    is captured and before it is replayed. It yields a `GraphCaptureContext`
    holding the data needed for the capture; currently that is only the
    stream the capture runs on. A fresh CUDA stream on `device` is created
    here and handed to the graph-capture contexts of the TP and PP groups,
    so the capture runs on a stream separate from the default one. This
    explicitly distinguishes the kernels to capture from other kernels that
    may be launched in the background on the default stream.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    # Nested `with` keeps the original ordering: the TP context is entered
    # before the PP group is looked up and entered.
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

1022

1023
logger = init_logger(__name__)
1024

1025
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool) -> None:
    """Store `enable` in the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
1031

Zhuohan Li's avatar
Zhuohan Li committed
1032

1033
1034
1035
1036
1037
1038
1039
1040
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
    timeout: Optional[timedelta] = None,
):
    """Initialize torch.distributed (if needed) and the world group.

    When the current vLLM config uses data parallelism (and the executor
    backend is not "external_launcher"), `rank`, `world_size` and the init
    method are first adjusted to span all data-parallel replicas. If the
    requested backend is unavailable, gloo is used instead. Finally the
    world `GroupCoordinator` and the node count are created, or, if the
    world group already exists, its size is checked against
    torch.distributed.
    """
    global _WORLD, _NODE_COUNT

    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    spans_data_parallel = (
        config is not None
        and config.parallel_config.data_parallel_size > 1
        and config.parallel_config.distributed_executor_backend != "external_launcher"
    )
    if spans_data_parallel:
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp
        master_ip = parallel_config.data_parallel_master_ip
        master_port = parallel_config.get_next_dp_init_port()
        distributed_init_method = get_distributed_init_method(master_ip, master_port)
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size,
            rank,
            distributed_init_method,
        )

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; falling back to gloo.",
                backend,
            )
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available."
            )
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
            timeout=timeout,
        )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = envs.LOCAL_RANK if distributed_init_method == "env://" else rank

    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size"
        )
        return

    all_world_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_world_ranks, local_rank, backend)
    _NODE_COUNT = _node_count(_WORLD.cpu_group)
    logger.debug("Detected %d nodes in the distributed environment", _NODE_COUNT)
1114
1115


Zhuohan Li's avatar
Zhuohan Li committed
1116
1117
1118
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        decode_context_model_parallel_size: number of GPUs in each decode
            context parallel (DCP) group. DCP does not change the world
            size; it reuses the GPUs of a TP group and splits it into
            tp_size // dcp_size DCP groups, so it must not exceed
            `tensor_model_parallel_size`. `None` is treated as 1.
        backend: name of torch distributed communication backend. Defaults
            to the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _TP, _DCP, _PP, _DP, _EP

    # The signature allows None, but `reshape(-1, None)` below would raise a
    # TypeError; treat None as "no decode context parallelism".
    if decode_context_model_parallel_size is None:
        decode_context_model_parallel_size = 1

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config

    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size, tensor_model_parallel_size
    )  # noqa

    # Build the tensor model-parallel groups.
    assert _TP is None, "tensor model parallel group is already initialized"
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Build the DCP model-parallel groups.
    assert _DCP is None, "decode context model parallel group is already initialized"
    # Note(hc): In the current implementation of decode context parallel,
    # dcp_size must not exceed tp_size, because the world size does not
    # change by DCP, it simply reuses the GPUs of TP group, and split one
    # TP group into tp_size//dcp_size DCP groups.
    group_ranks = all_ranks.reshape(-1, decode_context_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DCP = init_model_parallel_group(
        group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="dcp",
    )

    # Build the pipeline model-parallel groups.
    assert _PP is None, "pipeline model parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(2, 3).reshape(-1, pipeline_model_parallel_size).unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="pp"
    )

    # Build the data-parallel groups.
    assert _DP is None, "data parallel group is already initialized"
    group_ranks = all_ranks.transpose(1, 3).reshape(-1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="dp"
    )

    # Build the expert-parallel groups: each spans the DP and TP dimensions.
    assert _EP is None, "expert parallel group is already initialized"
    group_ranks = (
        all_ranks.transpose(1, 2)
        .reshape(-1, data_parallel_size * tensor_model_parallel_size)
        .unbind(0)
    )
    group_ranks = [x.tolist() for x in group_ranks]
    _EP = init_model_parallel_group(
        group_ranks, get_world_group().local_rank, backend, group_name="ep"
    )

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s",
        rank,
        world_size,
        _DP.rank_in_group,
        _PP.rank_in_group,
        _TP.rank_in_group,
        _EP.rank_in_group,
    )
1244

Zhuohan Li's avatar
Zhuohan Li committed
1245

1246
1247
1248
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    decode_context_model_parallel_size: Optional[int] = 1,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    If the groups are not initialized yet they are created with the given
    sizes. Otherwise the existing tensor-parallel and pipeline-parallel
    world sizes are asserted to equal the requested values.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size. "
            f"got: {get_tensor_model_parallel_world_size()=} vs. "
            f"wanted: {tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size. "
            f"got: {pp_world_size=} vs. "
            f"wanted: {pipeline_model_parallel_size=}"
        )
        return

    initialize_model_parallel(
        tensor_model_parallel_size,
        pipeline_model_parallel_size,
        decode_context_model_parallel_size,
        backend,
    )
1277
1278


1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost model
    agnostic. However, emerging new communication libraries like MoE
    all2all (DeepEP) usually allocate the communication buffer based on the
    model shape for optimal performance.

    The request is forwarded, in order, to the TP, PP, DP and EP groups
    that are currently initialized.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1296
def model_parallel_is_initialized():
    """Return True when both the TP and the PP groups are initialized."""
    return not (_TP is None or _PP is None)
1299
1300


1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Nested patching is not supported: entering while a patch is already
    active trips an assertion.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the current group *before* flipping the flag. get_tp_group()
    # asserts when TP is not initialized; if that happened after the flag
    # was set (and outside the try/finally), the flag would stay True and
    # every later patch attempt would fail.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1329
1330
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1332
1333
1334
1335


def get_tensor_model_parallel_rank():
    """Return this process's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1337
1338


1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
def get_decode_context_model_parallel_world_size():
    """Return the number of ranks in the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.world_size


def get_decode_context_model_parallel_rank():
    """Return this process's rank inside the decode context parallel group."""
    dcp_group = get_dcp_group()
    return dcp_group.rank_in_group


1349
def get_node_count() -> int:
    """Return the total number of nodes in the distributed environment.

    Asserts that the node count has already been recorded.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, "distributed environment is not initialized"
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1355
def destroy_model_parallel():
    """Destroy every model parallel group and reset its global to None.

    Groups are handled in the order TP, PP, DCP, DP, EP.
    """
    global _TP, _PP, _DCP, _DP, _EP

    # Tensor parallel.
    if _TP:
        _TP.destroy()
    _TP = None

    # Pipeline parallel.
    if _PP:
        _PP.destroy()
    _PP = None

    # Decode context parallel.
    if _DCP:
        _DCP.destroy()
    _DCP = None

    # Data parallel.
    if _DP:
        _DP.destroy()
    _DP = None

    # Expert parallel.
    if _EP:
        _EP.destroy()
    _EP = None

1383
1384

def destroy_distributed_environment():
    """Destroy the world group and shut down torch.distributed.

    The world coordinator (if any) is destroyed, the cached world group and
    node count are cleared, and the default process group is destroyed when
    torch.distributed is still initialized.
    """
    global _WORLD, _NODE_COUNT

    world = _WORLD
    if world:
        world.destroy()
    _WORLD = None
    _NODE_COUNT = None

    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1392
1393


1394
1395
1396
1397
1398
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down the distributed state and release cached memory.

    Destroys the model parallel groups and the distributed environment,
    optionally shuts down Ray, runs the garbage collector, and then asks
    the current platform (and, on non-CPU platforms, torch's host
    allocator) to empty its caches.
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()

    from vllm.platforms import current_platform

    release_platform_cache = current_platform.empty_cache
    if release_platform_cache is not None:
        release_platform_cache()

    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning("torch._C._host_emptyCache() only available in Pytorch >=2.5")
1412
1413


1414
1415
1416
def in_the_same_node_as(
    pg: Union[ProcessGroup, StatelessProcessGroup], source_rank: int = 0
) -> list[bool]:
    """Collectively determine which ranks share a node with `source_rank`.

    This is a collective operation: every rank of `pg` must call it. The
    source rank creates a shared memory segment and broadcasts its name;
    every other rank tries to open that segment. Ranks that can read the
    magic message are attached to the same memory system as the source.

    Returns:
        One boolean per rank of the group, True for ranks on the same node
        as the source rank.
    """
    uses_torch_pg = isinstance(pg, ProcessGroup)
    if uses_torch_pg:
        assert torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL, (
            "in_the_same_node_as should be tested with a non-NCCL group."
        )
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)
        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # One slot per rank; each process only ever sets its own slot.
    node_flags = torch.zeros(world_size, dtype=torch.int32)

    probe = b"magic_message"
    probe_len = len(probe)
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # Create the segment, stamp it, and publish its name.
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:probe_len] = probe
                if uses_torch_pg:
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg
                    )
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                node_flags[rank] = 1
            else:
                # Learn the segment name, then try to open it.
                if uses_torch_pg:
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg
                    )
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:probe_len] == probe:
                    node_flags[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Wait for every rank to finish probing before the segment is removed.
    if uses_torch_pg:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    # Merge the per-rank flags so every rank sees the full picture.
    if uses_torch_pg:
        torch.distributed.all_reduce(node_flags, group=pg)
        aggregated = node_flags
    else:
        aggregated = torch.zeros_like(node_flags)
        for sender in range(world_size):
            aggregated += pg.broadcast_obj(node_flags, src=sender)

    return [flag == 1 for flag in aggregated.tolist()]
1502
1503


1504
1505
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process)
              or if the check itself raises.
    """
    try:
        # Prefer the world group when it exists.
        world = _WORLD
        if world is not None:
            return world.is_first_rank

        # Otherwise fall back to torch's global rank when available.
        if torch.distributed.is_initialized():
            return torch.distributed.get_rank() == 0

        # No distributed state at all: single process.
        return True
    except Exception:
        # If anything goes wrong, assume this is the first rank
        return True


1535
1536
1537
1538
1539
1540
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are visited in order; each rank not yet assigned to a node opens
    a new node, and every unassigned rank reported by `in_the_same_node_as`
    as sharing its node joins it.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    world_size = (
        torch.distributed.get_world_size(group=pg)
        if isinstance(pg, ProcessGroup)
        else pg.world_size
    )
    if world_size == 1:
        return 1

    # rank -> node id, with 0 meaning "not assigned yet".
    node_of_rank = [0] * world_size
    node_total = 0

    for rank in range(world_size):
        if node_of_rank[rank] != 0:
            continue  # Already assigned to a node

        # This rank starts a new node.
        node_total += 1
        node_of_rank[rank] = node_total

        # Pull in every still-unassigned rank on the same node.
        for peer, shares_node in enumerate(in_the_same_node_as(pg, rank)):
            if shares_node and node_of_rank[peer] == 0:
                node_of_rank[peer] = node_total

    return node_total