"docs/vscode:/vscode.git/clone" did not exist on "f1c78138aa28e58eeaafa4791788fe6ceddf1dd8"
parallel_state.py 44.5 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

Woosuk Kwon's avatar
Woosuk Kwon committed
3
# Copyright 2023 The vLLM team.
4
5
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
6
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
7
8
9
10
11
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
12
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
13
14
15
16
17
18
19
20
21
22
23
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
24
import contextlib
25
import gc
26
import pickle
27
import weakref
28
29
30
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
31
from multiprocessing import shared_memory
32
33
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)
34
from unittest.mock import patch

import torch
37
import torch.distributed
38
from torch.distributed import Backend, ProcessGroup

40
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
41
import vllm.envs as envs
42
43
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
44
from vllm.distributed.utils import StatelessProcessGroup
45
from vllm.logger import init_logger
46
47
from vllm.utils import (direct_register_custom_op, resolve_obj_by_qualname,
                        supports_custom_op)
48

49
50
51
if TYPE_CHECKING:
    from vllm.config import VllmConfig

52

53
54
55
@dataclass
class GraphCaptureContext:
    """State shared by code running inside `GroupCoordinator.graph_capture`.

    It currently carries only the CUDA stream that is made current while the
    graph is being captured.
    """

    # Stream made current for the duration of the capture.
    stream: torch.cuda.Stream
56

57

58
# Placeholder describing a tensor without its data: the device *type* (no
# index), the dtype and the size. It stands in for tensors in metadata lists.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])


def _split_tensor_dict(
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.

    Both lists preserve the iteration order of `tensor_dict`, so the i-th
    tensor in the second list corresponds to the i-th `TensorMetadata` entry
    in the first list.
    """
    metadata_list: List[Tuple[str, Any]] = []
    tensor_list: List[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Next suffix to hand out for each base group name.
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Return `name` with a per-name counter suffix appended.

    Each call with the same `name` yields the next index, starting at 0:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.setdefault(name, 0)
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


102
# Maps a coordinator's `unique_name` to a weak reference to it, so a
# coordinator can be looked up from a plain string (see `all_reduce`)
# without this registry keeping it alive.
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}


def _register_group(group: "GroupCoordinator") -> None:
    """Store a weak reference to `group` under its `unique_name`."""
    _groups[group.unique_name] = weakref.ref(group)
107
108


109
110
111
112
113
114
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce of `tensor` over the group named `group_name`.

    This function backs the `all_reduce` custom op. The op receives the
    group's name as a string rather than the coordinator object, so the
    coordinator is resolved here through the weak-reference registry.

    Raises:
        ValueError: if the group was registered but is no longer alive.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator_ref = _groups[group_name]
    coordinator = coordinator_ref()
    if coordinator is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_reduce_out_place(tensor)
115
116


117
118
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation registered for the `all_reduce` custom op.

    It returns an uninitialized tensor created with `torch.empty_like(tensor)`
    and performs no communication. `group_name` is unused.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
119

120

121
# When custom ops are supported, expose `all_reduce` as
# `torch.ops.vllm.all_reduce`, with `all_reduce_fake` as its fake
# implementation and the current platform's dispatch key.
if supports_custom_op():
    from vllm.platforms import current_platform
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        mutates_args=[],
        fake_impl=all_reduce_fake,
        dispatch_key=current_platform.dispatch_key,
    )

131

132
133
134
135
136
137
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_device_communicator: bool  # whether to use device communicator
    device_communicator: DeviceCommunicatorBase  # device communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups and communicators for this process.

        Args:
            group_ranks: every group of global ranks to create. All processes
                must pass the same list, because `torch.distributed.new_group`
                is called for every entry. This process keeps the group that
                contains its own rank.
            local_rank: local rank, used to pick the device.
            torch_distributed_backend: backend for the device group.
            use_device_communicator: build a platform device communicator
                (only when the group has more than one process).
            use_message_queue_broadcaster: build a shared-memory
                `MessageQueue` for `broadcast_object` (only when the group
                has more than one process).
            group_name: base name used to derive `unique_name`.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        # This process must belong to exactly one of `group_ranks`.
        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator

        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): the numeric arguments presumably size the shared
            # memory buffer (1 << 22 bytes, 6 chunks); confirm against
            # `MessageQueue.create_from_process_group`.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

        # Route all_reduce through the registered custom op on these
        # platforms; otherwise call the out-of-place implementation directly.
        self.use_custom_op_call = (current_platform.is_cuda_alike()
                                   or current_platform.is_tpu())

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager that sets up this group for CUDA graph capture.

        Makes the context's stream current (a new `torch.cuda.Stream` when no
        context is given) after it waits on the previously current stream.
        If the device communicator has a custom all-reduce communicator
        (`ca_comm`), its `capture()` context is entered as well. Yields the
        `GraphCaptureContext` in use.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator)
        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        """Delegate the all-reduce to the device communicator."""
        return self.device_communicator.all_reduce(input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """All-gather `input_` along `dim` through the device communicator.

        Returns `input_` unchanged when the group has a single process.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        return self.device_communicator.all_gather(input_, dim)

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        return self.device_communicator.gather(input_, dst, dim)

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor in place over the device group.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.

        Uses the shared-memory message queue when one was created (which
        only supports `src=0`), otherwise the CPU (gloo) group.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list in place over the device group.
        NOTE: `src` is the local rank of the source rank.
        NOTE: the `group` argument is currently ignored; `device_group` is
        always used.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.

        The object is pickled and sent over the CPU group as a size tensor
        followed by the payload bytes; pair it with `recv_object`.
        NOTE: `dst` is the local rank of the destination rank.
        """

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size

        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object sent with `send_object` from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """

        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        # NOTE(security): unpickling can execute arbitrary code; this relies
        # on every peer process in the group being trusted.
        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.

        The source broadcasts the metadata (non-tensor values and tensor
        `TensorMetadata`) with `broadcast_object`. Each non-empty tensor is
        then broadcast asynchronously: CPU tensors over the CPU group, other
        tensors over the device group. Receivers rebuild the dictionary from
        the metadata.
        NOTE: `src` is the local rank of the source rank.
        NOTE: the `group` and `metadata_group` arguments are currently
        ignored; `device_group` and `cpu_group` are always used.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.

        The metadata goes out with `send_object`, then each non-empty tensor
        is sent with `torch.distributed.send`. When `all_gather_group` is
        given and a tensor's element count divides evenly by that group's
        size, only this rank's slice is sent and the receiver is expected to
        all-gather it (see `recv_tensor_dict`).
        NOTE: `dst` is the local rank of the destination rank; it defaults
        to the next rank in the group.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.

        Counterpart of `send_tensor_dict`. It must be called with a matching
        `all_gather_group` so that the same tensors are received as slices and
        re-assembled with `all_gather`. Returns None when the distributed
        environment is not initialized or the group has a single process.
        NOTE: `src` is the local rank of the source rank; it defaults to the
        previous rank in the group.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors; the sender skips them too.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send a tensor to the destination rank via the device communicator.
        NOTE: `dst` is the local rank of the destination rank. When it is
        None, the choice of destination is left to the device communicator.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from the source rank via the device communicator.
        NOTE: `src` is the local rank of the source rank. When it is None,
        the choice of source is left to the device communicator.
        """
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
        """Tear down the process groups and communicators of this group.

        Each resource is released at most once and then cleared, so calling
        `destroy` again is a no-op.
        """
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
            # Clear the reference so a second `destroy` doesn't re-destroy it.
            self.device_communicator = None  # type: ignore
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
706
707
708
709
710
711
712
713
714
715


# Coordinator spanning every process.
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator.

    Asserts that the world group has been initialized.
    """
    assert _WORLD is not None, "world group is not initialized"
    return _WORLD


716
717
718
719
720
721
def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Create the coordinator for the world group.

    Args:
        ranks: global ranks forming the single world group.
        local_rank: local rank of this process, used to pick its device.
        backend: torch.distributed backend for the device group.

    Returns:
        A `GroupCoordinator` named "world" that has no device communicator.
    """
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
    )


727
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Build a coordinator for one model-parallel dimension.

    These groups always get a device communicator.

    Args:
        group_ranks: One list of ranks per group of this dimension.
        local_rank: Local rank of the calling process.
        backend: Forwarded as ``torch_distributed_backend``.
        use_message_queue_broadcaster: Whether the coordinator should set up
            a message-queue broadcaster.
        group_name: Optional name identifying the group.
    """
    coordinator_kwargs = dict(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return GroupCoordinator(**coordinator_kwargs)


745
746
747
748
749
750
751
752
753
754
755
756
757
# Tensor-parallel group coordinator; created by initialize_model_parallel().
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group coordinator.

    Raises:
        AssertionError: if the tensor-parallel group is not initialized.
    """
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

# Pipeline-parallel group coordinator; assigned by initialize_model_parallel()
# and reset to None by destroy_model_parallel().
_PP: Optional[GroupCoordinator] = None

758
759
760
761
762
763
764
# Data-parallel group coordinator; created by initialize_model_parallel().
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group coordinator.

    Raises:
        AssertionError: if the data-parallel group is not initialized.
    """
    dp_group = _DP
    assert dp_group is not None, ("data parallel group is not initialized")
    return dp_group

765
766
767
768
769

def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group coordinator.

    Raises:
        AssertionError: if the pipeline-parallel group is not initialized.
    """
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group

775
776
777
778
779
780
781
782
# Disaggregated KV-cache transfer agent; created lazily by
# ensure_kv_transfer_initialized().
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None


def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
    """Return the KV-cache transfer agent.

    Raises:
        AssertionError: if the agent has not been created yet.
    """
    agent = _KV_TRANSFER
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent

783

784
@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager that should surround the code that
    captures a CUDA graph. Its main purpose is to ensure that some operations
    run after the graph is captured and before the graph is replayed.

    It yields a `GraphCaptureContext` holding the data needed for the capture;
    currently that is only a fresh `torch.cuda.Stream` created on `device`.
    Capturing on a separate stream keeps the captured kernels distinct from
    kernels that may be launched in the background on the default stream.

    The tensor-parallel group's `graph_capture` is entered first, then the
    pipeline-parallel group's, both with the same context. Any stream
    switching is presumably done inside those group-level managers, which
    are not shown here.
    """
    capture_context = GraphCaptureContext(torch.cuda.Stream(device=device))
    with get_tp_group().graph_capture(capture_context):
        with get_pp_group().graph_capture(capture_context):
            yield capture_context

804

805
logger = init_logger(__name__)

# Module-wide switch for custom all-reduce; changed via set_custom_all_reduce().
_ENABLE_CUSTOM_ALL_REDUCE = True


810
811
812
def set_custom_all_reduce(enable: bool):
    """Store ``enable`` in the module-level custom all-reduce switch."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable

Zhuohan Li's avatar
Zhuohan Li committed
814

815
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the vLLM world group.

    When the current vLLM config has ``data_parallel_size > 1``:
    - ``rank`` is offset by ``data_parallel_rank * world_size``;
    - ``world_size`` becomes ``world_size_across_dp``;
    - ``distributed_init_method`` is replaced by a ``tcp://`` address built
      from the parallel config.

    Args:
        world_size: Number of processes (per DP replica when DP is enabled).
        rank: Rank of this process (within its DP replica when DP is enabled).
        distributed_init_method: torch.distributed init method URL.
        local_rank: Local rank. ``-1`` means derive it: ``envs.LOCAL_RANK``
            for ``env://``, otherwise ``rank``.
        backend: Backend for the default torch.distributed group and the
            world group.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    dp_enabled = (vllm_config is not None
                  and vllm_config.parallel_config.data_parallel_size > 1)
    if dp_enabled:
        parallel_cfg = vllm_config.parallel_config
        # Offset the rank by the DP rank and widen the world to every replica.
        rank = parallel_cfg.data_parallel_rank * world_size + rank
        world_size = parallel_cfg.world_size_across_dp
        master_ip = parallel_cfg.data_parallel_master_ip
        init_port = parallel_cfg.get_next_dp_init_port()
        distributed_init_method = f"tcp://{master_ip}:{init_port}"  # noqa
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # Typically the single-node case, where rank can stand in for it.
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)

    global _WORLD
    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
        return
    every_rank = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(every_rank, local_rank, backend)


Zhuohan Li's avatar
Zhuohan Li committed
870
871
872
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the world group's device group.

    The data parallel size is read from the current vLLM config (1 when no
    config is set).

    Raises:
        RuntimeError: if the world size is not a multiple of
            data_parallel_size * pipeline_model_parallel_size *
            tensor_model_parallel_size.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # Fail with a readable message instead of an opaque reshape error when
    # the world cannot be tiled by DP x PP x TP. (torch.reshape would raise
    # RuntimeError here too, so the exception type is unchanged.)
    group_size = (data_parallel_size * pipeline_model_parallel_size *
                  tensor_model_parallel_size)
    if group_size <= 0 or world_size % group_size != 0:
        raise RuntimeError(
            f"world size {world_size} must be a multiple of "
            f"data_parallel_size ({data_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size}) x "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) "
            f"= {group_size}")

    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = all_ranks.transpose(2, 3).reshape(
        -1, pipeline_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="pp")

    # Build the data-parallel groups.
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = all_ranks.transpose(1,
                                      3).reshape(-1,
                                                 data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group)

Zhuohan Li's avatar
Zhuohan Li committed
964

965
966
967
968
969
970
971
972
973
974
975
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize KV cache transfer parallel group.

    The module-level KV transfer agent is created at most once, and only when
    ``vllm_config.kv_transfer_config`` is set and its
    ``is_kv_transfer_instance`` is truthy. Otherwise, or if the agent already
    exists, this is a no-op.
    """

    global _KV_TRANSFER

    kv_transfer_config = vllm_config.kv_transfer_config
    if kv_transfer_config is None:
        return

    # A short-circuiting `and` rather than `all([...])`, which built a list
    # and evaluated both conditions eagerly.
    if kv_transfer_config.is_kv_transfer_instance and _KV_TRANSFER is None:
        _KV_TRANSFER = kv_transfer.KVTransferAgent(
            rank=get_world_group().rank,
            local_rank=get_world_group().local_rank,
            config=vllm_config)


985
986
987
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Create the model parallel groups, or check the existing ones.

    If the groups do not exist yet, they are built with
    `initialize_model_parallel`. Otherwise the existing tensor-parallel and
    pipeline-parallel world sizes must equal the requested sizes.

    Raises:
        AssertionError: if existing groups have an unexpected TP or PP size.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if model_parallel_is_initialized():
        # Groups already exist: they must match the requested layout.
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)


Zhuohan Li's avatar
Zhuohan Li committed
1013
def model_parallel_is_initialized():
    """Return True when both the TP and PP group coordinators exist.

    The data-parallel group is not consulted.
    """
    return not (_TP is None or _PP is None)


1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
# True while patch_tensor_parallel_group() has swapped in a different _TP.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    The current tp group is looked up before the patched flag is set. If no
    tp group is initialized, the AssertionError from `get_tp_group()`
    propagates and leaves the module state unchanged, so later calls are not
    blocked by a stale "already patched" flag.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Resolve the group to restore *before* flipping the flag: get_tp_group()
    # may raise, and the flag must not be left set when it does.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1046
1047
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group


def destroy_model_parallel():
    """Destroy the TP, PP and DP group coordinators and reset them to None.

    Groups are destroyed in TP, PP, DP order. Groups that are already None
    are skipped, so calling this again is a no-op.
    """
    global _TP, _PP, _DP

    # Explicit None checks: a coordinator object must be destroyed even if it
    # happened to define a falsy __len__/__bool__.
    if _TP is not None:
        _TP.destroy()
    _TP = None

    if _PP is not None:
        _PP.destroy()
    _PP = None

    if _DP is not None:
        _DP.destroy()
    _DP = None

1073
1074
1075
1076
1077
1078
1079
1080

def destroy_distributed_environment():
    """Destroy the world group and, if initialized, the default torch group.

    The world group coordinator is reset to None afterwards, so repeated calls
    do not destroy it twice.
    """
    global _WORLD
    # Explicit None check rather than relying on the object's truthiness.
    if _WORLD is not None:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


1083
1084
1085
1086
1087
1088
1089
1090
1091
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down distributed state and release cached memory.

    The steps run in this order:
    1. Destroy the model-parallel groups and the distributed environment.
    2. Destroy any remaining default process group (AssertionError
       suppressed).
    3. Optionally shut down Ray.
    4. Run the garbage collector.
    5. Call ``torch.cuda.empty_cache()`` unless the platform is CPU.
    6. Try ``torch._C._host_emptyCache()``, logging a warning if it is
       unavailable.

    Args:
        shutdown_ray: If True, also call ``ray.shutdown()``.
    """
    # Distributed state: vLLM groups, world group, then any default group.
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    # Memory: Python objects first, then device and host caches.
    gc.collect()
    from vllm.platforms import current_platform
    on_cpu = current_platform.is_cpu()
    if not on_cpu:
        torch.cuda.empty_cache()
    try:
        torch._C._host_emptyCache()
    except AttributeError:
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")


1102
1103
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    How it works:
    - The source rank creates a shared-memory segment, writes a magic message
      into it and broadcasts the segment name.
    - Every other rank tries to open that segment and checks for the message.
    - Errors are logged or suppressed and count as "not on the same node".

    If the source rank cannot create the segment, it still broadcasts (the
    name ``None``), so the other ranks are not left waiting on a broadcast
    that never comes.

    Args:
        pg: A non-NCCL torch ProcessGroup or a StatelessProcessGroup. Every
            member of the group must call this function.
        source_rank: Rank within ``pg`` that creates the segment.

    Returns:
        One bool per rank of ``pg``. It is True if that rank could read the
        source's segment; the source itself counts only if it created one.
    """
    is_torch_group = isinstance(pg, ProcessGroup)
    if is_torch_group:
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    def broadcast_from_source(obj: Any) -> Any:
        # Broadcast `obj` from `source_rank` and return the value received.
        if is_torch_group:
            holder = [obj]
            torch.distributed.broadcast_object_list(holder,
                                                    src=ranks[source_rank],
                                                    group=pg)
            return holder[0]
        return pg.broadcast_obj(obj, src=source_rank)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            shm_name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                shm_name = shm.name
            # Always take part in the broadcast, even on failure; skipping it
            # would leave the other ranks blocked waiting for the name.
            broadcast_from_source(shm_name)
            if shm_name is not None:
                is_in_the_same_node[rank] = 1
        else:
            shm_name = broadcast_from_source(None)
            if shm_name is not None:
                # try to open the shared memory segment
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is
                    # not created by the process. The following patch is a
                    # workaround.
                    with patch("multiprocessing.resource_tracker.register",
                               lambda *args, **kwargs: None):
                        shm = shared_memory.SharedMemory(name=shm_name)
                    if shm.buf[:len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    if is_torch_group:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    if is_torch_group:
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]