parallel_state.py 56.4 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

Woosuk Kwon's avatar
Woosuk Kwon committed
4
# Copyright 2023 The vLLM team.
5
6
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
8
9
10
11
12
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- run any code that uses the distributed groups.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25
import contextlib
26
import gc
27
import pickle
28
import weakref
29
30
31
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
32
from multiprocessing import shared_memory
33
from typing import Any, Callable, Optional, Tuple, Union
34
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
35
36

import torch
37
import torch.distributed
38
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
42
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
43
from vllm.distributed.utils import StatelessProcessGroup
44
from vllm.logger import init_logger
45
46
from vllm.utils import (direct_register_custom_op, get_distributed_init_method,
                        resolve_obj_by_qualname, supports_custom_op)
47
48


49
50
51
@dataclass
class GraphCaptureContext:
    """State handed out by `GroupCoordinator.graph_capture`."""

    # Stream that `graph_capture` makes current (via `torch.cuda.stream`)
    # while the caller captures its graph.
    stream: torch.cuda.Stream
52

53

54
# Lightweight, picklable description of a tensor (device type, dtype, shape)
# that is shipped in place of the tensor itself.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
55

56
57

def _split_tensor_dict(
    tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
       by its metadata (a ``TensorMetadata``).
    2. A list of tensors.

    Both lists follow the iteration order of ``tensor_dict``. The i-th tensor
    in the second list therefore corresponds to the i-th ``TensorMetadata``
    entry in the first list.
    """
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
            device = value.device.type
            metadata_list.append(
                (key, TensorMetadata(device, value.dtype, value.size())))
            tensor_list.append(value)
        else:
            metadata_list.append((key, value))
    return metadata_list, tensor_list


82
_group_name_counter: dict[str, int] = {}
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


98
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
99
100
101


def _register_group(group: "GroupCoordinator") -> None:
102
    _groups[group.unique_name] = weakref.ref(group)
103
104


105
106
107
108
109
110
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Custom-op entry point: all-reduce ``tensor`` on a registered group.

    The group is passed by name because a custom op cannot take the
    coordinator object. The name is resolved through ``_groups``.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")
111
112


113
114
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation of the ``all_reduce`` custom op.

    Returns an uninitialized tensor with the same shape, dtype and device as
    ``tensor``. ``group_name`` is accepted only to match the real op.
    """
    out = torch.empty_like(tensor)
    return out
115

116

117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def all_reduce_rms_quant(input_: torch.Tensor, group_name: str,
                         pa_rms_weight: Optional[torch.Tensor] = None,
                         pa_residual: Optional[torch.Tensor] = None,
                         pa_rms_eps: Optional[float] = 1e-6,
                         pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                         update_input: Optional[bool] = True
                         ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
                                    torch.Tensor]:
    """Custom-op entry point for the fused ``all_reduce_rms_quant`` op.

    Resolves ``group_name`` through ``_groups`` and forwards every argument to
    that coordinator's ``_all_reduce_out_place_m32``.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    fused_kwargs = dict(pa_rms_weight=pa_rms_weight,
                        pa_residual=pa_residual,
                        pa_rms_eps=pa_rms_eps,
                        pa_quant_dtype=pa_quant_dtype,
                        update_input=update_input)
    return coordinator._all_reduce_out_place_m32(input_, **fused_kwargs)


def all_reduce_rms_quant_fake(input_: torch.Tensor, group_name: str,
                              pa_rms_weight: Optional[torch.Tensor] = None,
                              pa_residual: Optional[torch.Tensor] = None,
                              pa_rms_eps: Optional[float] = 1e-6,
                              pa_quant_dtype: Optional[torch.dtype] = torch.int8,
                              update_input: Optional[bool] = True
                              ) -> Tuple[torch.Tensor, torch.Tensor,
                                         torch.Tensor, torch.Tensor]:
    """Fake implementation of ``all_reduce_rms_quant``.

    Returns ``(input_, pa_residual, xq, xs)``:
    - ``input_`` and ``pa_residual`` are passed back unchanged.
    - ``xq`` is a zero tensor shaped like ``input_`` with dtype
      ``pa_quant_dtype``.
    - ``xs`` is a float32 tensor of ones with one row per slice along the
      last dimension, i.e. shape ``(input_.numel() // input_.shape[-1], 1)``.
    """
    xq = torch.zeros_like(input_, dtype=pa_quant_dtype)
    num_rows = input_.numel() // input_.shape[-1]
    xs = torch.ones((num_rows, 1), device=input_.device, dtype=torch.float32)
    return input_, pa_residual, xq, xs


148
149
150
151
152
153
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Custom-op entry point: reduce-scatter ``tensor`` along ``dim``.

    The coordinator is looked up by ``group_name`` in ``_groups``.
    ``world_size`` is unused here. It is part of the signature because the
    fake implementation (``reduce_scatter_fake``) needs it to compute the
    output shape.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._reduce_scatter_out_place(tensor, dim)
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169


def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Fake implementation of the ``reduce_scatter`` custom op.

    Returns an uninitialized tensor like ``tensor`` whose size along ``dim``
    is divided (floor) by ``world_size``. Negative ``dim`` is supported.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return tensor.new_empty(out_shape)


def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Custom-op entry point: all-gather ``tensor`` along ``dim``.

    The coordinator is looked up by ``group_name`` in ``_groups``.
    ``world_size`` is unused here. It is part of the signature because the
    fake implementation (``all_gather_fake``) needs it to compute the output
    shape.

    Raises:
        AssertionError: if ``group_name`` was never registered.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_gather_out_place(tensor, dim)
171
172
173
174
175
176
177
178
179


def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Fake implementation of the ``all_gather`` custom op.

    Returns an uninitialized tensor like ``tensor`` whose size along ``dim``
    is multiplied by ``world_size``. Negative ``dim`` is supported.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return tensor.new_empty(out_shape)


180
if supports_custom_op():
    from vllm.platforms import current_platform

    # Register the group-name based wrappers above as `torch.ops.vllm.*`
    # custom ops. Each op is paired with its fake implementation and the
    # current platform's dispatch key.
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        mutates_args=[],
        fake_impl=all_reduce_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    # The fused op declares that it mutates `input_` and `pa_residual`.
    direct_register_custom_op(
        op_name="all_reduce_rms_quant",
        op_func=all_reduce_rms_quant,
        mutates_args=["input_", "pa_residual"],
        fake_impl=all_reduce_rms_quant_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="reduce_scatter",
        op_func=reduce_scatter,
        mutates_args=[],
        fake_impl=reduce_scatter_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="all_gather",
        op_func=all_gather,
        mutates_args=[],
        fake_impl=all_gather_fake,
        dispatch_key=current_platform.dispatch_key,
    )

214

215
216
217
218
219
220
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
221
222
        the processes in the group. It manages both CPU and device
        communication.
223
224
225
226
    """

    # available attributes:
    rank: int  # global rank
227
    ranks: list[int]  # global ranks in the group
228
229
230
231
232
233
234
235
236
237
238
239
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
240
241
    use_device_communicator: bool  # whether to use device communicator
    device_communicator: DeviceCommunicatorBase  # device communicator
242
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
243
244
245

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the device and CPU process groups for this coordinator.

        Args:
            group_ranks: Partition of global ranks into groups.
                ``torch.distributed.new_group`` is called for every partition
                on every process. The partition that contains this process's
                global rank becomes this coordinator's group.
            local_rank: Used as the device index for ``self.device``.
            torch_distributed_backend: Backend for the device group. The CPU
                group always uses ``gloo``.
            use_device_communicator: Whether to build the platform's device
                communicator (only done when the group has more than one
                rank).
            use_message_queue_broadcaster: Whether to build a shared-memory
                ``MessageQueue`` for ``broadcast_object`` (only done when the
                group has more than one rank).
            group_name: Prefix for ``unique_name``; defaults to "anonymous".
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        # Fails if this process's rank was not in any of `group_ranks`.
        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator

        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

        # Collective methods route through `torch.ops.vllm.*` only on these
        # platforms; elsewhere they call the out-of-place helpers directly.
        self.use_custom_op_call = (current_platform.is_cuda_alike()
                                   or current_platform.is_tpu())
311

312
313
314
315
316
317
318
319
320
321
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

322
323
324
325
326
327
328
329
330
331
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager that sets up this group for CUDA graph capture.

        Yields a ``GraphCaptureContext``. A new ``torch.cuda.Stream`` is
        created when no context is given. While the context is active:
        - that stream is the current CUDA stream;
        - if the device communicator has a custom all-reduce (``ca_comm``),
          its ``capture()`` context is entered as well.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator)
        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context
374
375
376

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
377
378
379
380
381
382
383
384
385
386
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
387
388
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
389
390
391
392
393
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

394
395
396
397
398
        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)
399

400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
    def all_reduce_crq_m32(self, input_: torch.Tensor,
                           pa_rms_weight: torch.Tensor,
                           pa_residual: torch.Tensor,
                           pa_rms_eps: float,
                           pa_quant_dtype: torch.dtype,
                           update_input: Optional[bool] = True
                           ) -> Tuple[torch.Tensor, torch.Tensor,
                                      torch.Tensor, torch.Tensor]:
        """User-facing entry point for the fused all-reduce/RMS/quant op.

        Judging by the op and argument names, this fuses an all-reduce with
        an RMS norm (weight ``pa_rms_weight``, epsilon ``pa_rms_eps``, using
        ``pa_residual``) and quantization to ``pa_quant_dtype``. The
        arithmetic lives in ``device_communicator.all_reduce_rms_quant_m32``,
        which is not visible here.

        Returns the 4-tuple produced by ``_all_reduce_out_place_m32``.

        Like ``all_reduce``, this goes through the registered custom op only
        when ``use_custom_op_call`` is set. Otherwise it calls the
        out-of-place implementation directly, because
        ``torch.ops.vllm.all_reduce_rms_quant`` may not be registered.
        """
        # Unlike `all_reduce`, there is no single-rank bypass: the fused
        # kernel also produces the normalized/quantized outputs.
        assert self.world_size > 1, (
            "all_reduce_crq_m32 requires a group with more than one rank")
        assert envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT, (
            "USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT must be enabled")
        assert pa_rms_weight is not None and pa_residual is not None, (
            "pa_rms_weight and pa_residual are required")
        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce_rms_quant(
                input_,
                group_name=self.unique_name,
                pa_rms_weight=pa_rms_weight,
                pa_residual=pa_residual,
                pa_rms_eps=pa_rms_eps,
                pa_quant_dtype=pa_quant_dtype,
                update_input=update_input)
        return self._all_reduce_out_place_m32(input_,
                                              pa_rms_weight=pa_rms_weight,
                                              pa_residual=pa_residual,
                                              pa_rms_eps=pa_rms_eps,
                                              pa_quant_dtype=pa_quant_dtype,
                                              update_input=update_input)



419
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
420
        return self.device_communicator.all_reduce(input_)
421

422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
    def _all_reduce_out_place_m32(self, input_: torch.Tensor,
                                  pa_rms_weight: torch.Tensor,
                                  pa_residual: torch.Tensor,
                                  pa_rms_eps: float,
                                  pa_quant_dtype: torch.dtype,
                                  update_input: Optional[bool] = True
                                  ) -> Tuple[torch.Tensor, torch.Tensor,
                                             torch.Tensor, torch.Tensor]:
        """Run the fused all-reduce/RMS/quant on the device communicator.

        Returns the 4 values produced by
        ``device_communicator.all_reduce_rms_quant_m32`` as a tuple.
        """
        assert (envs.USE_FUSED_CUSTOM_ALL_REDUCE_RMS_QUANT
                and pa_rms_weight is not None
                and pa_residual is not None)
        comm = self.device_communicator
        out, residual, quantized, scales = comm.all_reduce_rms_quant_m32(
            input_,
            pa_rms_weight=pa_rms_weight,
            pa_residual=pa_residual,
            pa_rms_eps=pa_rms_eps,
            pa_quant_dtype=pa_quant_dtype,
            update_input=update_input)
        return out, residual, quantized, scales

438
439
440
441
442
443
444
    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
445

446
447
448
449
450
451
452
453
454
455
        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            return self._all_gather_out_place(input_, dim)

    def _all_gather_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
456
        return self.device_communicator.all_gather(input_, dim)
457

458
459
460
461
462
463
464
465
466
467
    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

468
469
470
471
472
473
474
475
476
477
        if self.use_custom_op_call:
            return torch.ops.vllm.reduce_scatter(input_,
                                                 dim,
                                                 world_size,
                                                 group_name=self.unique_name)
        else:
            return self._reduce_scatter_out_place(input_, dim)

    def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                  dim: int) -> torch.Tensor:
478
479
        return self.device_communicator.reduce_scatter(input_, dim)

480
481
482
    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
483
               dim: int = -1) -> Optional[torch.Tensor]:
484
485
486
487
488
489
490
491
492
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
493
        return self.device_communicator.gather(input_, dst, dim)
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

510
511
512
513
514
515
516
517
518
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
519
520
521
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
522
523
524
525
526
527
528
529
530
531
532
533
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

534
    def broadcast_object_list(self,
535
                              obj_list: list[Any],
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

552
553
554
555
556
557
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

558
        assert dst != self.rank_in_group, (
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size

        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

588
        assert src != self.rank_in_group, (
589
590
591
592
593
594
595
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
596
                                           src=self.ranks[src],
597
598
599
600
601
602
603
604
605
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
606
                                             src=self.ranks[src],
607
608
609
610
611
612
613
614
615
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

616
617
    def broadcast_tensor_dict(
        self,
618
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
619
620
621
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
622
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
623
624
625
626
627
628
629
630
631
632
633
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

634
635
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
636
            metadata_list: list[tuple[Any, Any]] = []
637
638
639
640
641
642
643
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
644
            self.broadcast_object(metadata_list, src=src)
645
646
647
648
649
650
651
652
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
653
                                                         src=self.ranks[src],
654
655
656
657
658
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
659
                                                         src=self.ranks[src],
660
661
662
663
664
665
666
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
667
            metadata_list = self.broadcast_object(None, src=src)
668
669
            tensor_dict = {}
            async_handles = []
670
            for key, value in metadata_list:
671
672
673
674
675
676
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
677
                        tensor_dict[key] = tensor
678
679
680
681
682
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
683
                            src=self.ranks[src],
684
685
686
687
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
688
689
690
691
692
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
693
                    async_handles.append(handle)
694
                    tensor_dict[key] = tensor
695
                else:
696
                    tensor_dict[key] = value
697
698
699
700
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

701
702
    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary to another rank of this group.

        The dictionary is split into picklable metadata (sent with
        `send_object`) and a list of tensors. Each non-empty tensor is sent
        point-to-point with `torch.distributed.send`: CPU tensors go over
        `cpu_group`, all others over `device_group`.

        Args:
            tensor_dict: the dictionary to send.
            dst: the local rank (index within this group) of the
                destination rank. Defaults to the next rank, wrapping around.
            all_gather_group: when given, a tensor whose element count is
                divisible by this group's world size is cut into that many
                equal flat slices. Only this rank's slice is sent;
                `recv_tensor_dict` reassembles it with an all-gather.

        Returns:
            `tensor_dict` itself when torch.distributed is not initialized or
            this group has a single rank; otherwise None.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert 0 <= dst < self.world_size, f"Invalid dst rank ({dst})"

        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors; the receiver allocates them
                # locally from the metadata.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Receive a tensor dictionary sent with `send_tensor_dict`.

        The metadata list is received with `recv_object`. Every tensor it
        describes is then allocated locally and, unless empty, received
        point-to-point: CPU tensors over `cpu_group`, all others over
        `device_group`.

        Args:
            src: the local rank (index within this group) of the source rank.
                Defaults to the previous rank, wrapping around.
            all_gather_group: must match what the sender passed. A tensor
                whose element count is divisible by this group's world size
                is received as one flat slice and reassembled with an
                all-gather.

        Returns:
            None when torch.distributed is not initialized or this group has
            a single rank; otherwise the received dictionary.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors (the sender skips them).
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: receive only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if envs.VLLM_USE_PP_SYNC:
                    # NOTE(review): presumably an opt-in full CUDA sync after
                    # each receive; it also runs for CPU tensors -- confirm.
                    torch.cuda.synchronize()

                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict
    def barrier(self):
        """Block until every rank of this group has reached this call.

        The barrier runs on `cpu_group`, never on `device_group`: an NCCL
        barrier is internally a broadcast over secretly created GPU tensors,
        which makes it easy to mess up the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.

        NOTE: `dst` is the local rank of the destination rank. The work is
        delegated to `self.device_communicator.send`.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.

        NOTE: `src` is the local rank of the source rank. The work is
        delegated to `self.device_communicator.recv`, which is given the
        expected `size` and `dtype`.
        """
        return self.device_communicator.recv(size, dtype, src)
    def destroy(self):
        """Release the process groups and communicators held by this group.

        Every resource is reset to None once released, so calling this a
        second time is a no-op instead of a double teardown.
        """
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
            # Drop the reference, mirroring the group handling above, so a
            # repeated destroy() does not tear the communicator down twice.
            self.device_communicator = None
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None
    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let this group's device communicator, if any, prepare its buffers
        for `model`."""
        communicator = self.device_communicator
        if communicator is None:
            return
        communicator.prepare_communication_buffer_for_model(model)

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward `hidden_states` and `router_logits` to the device
        communicator's `dispatch`.

        Returns the inputs unchanged when this group has no device
        communicator.
        """
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states,
                                                     router_logits)
        return hidden_states, router_logits

    def combine(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Forward `hidden_states` to the device communicator's `combine`.

        Returns the input unchanged when this group has no device
        communicator.
        """
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states)
        return hidden_states

879
880

_WORLD: Optional[GroupCoordinator] = None
# Number of distinct nodes in the world group. Set by
# `init_distributed_environment` and reset to None by
# `destroy_distributed_environment`.
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator; assert that it exists."""
    assert _WORLD is not None, ("world group is not initialized")
    return _WORLD


def init_world_group(ranks: list[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Create the coordinator for the single "world" group over `ranks`.

    Args:
        ranks: global ranks that make up the world group.
        local_rank: local rank of this process.
        backend: torch.distributed backend for the group.

    The world group is created without a device communicator.
    """
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
        group_name="world",
    )


def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create the coordinator for one family of model-parallel groups.

    Args:
        group_ranks: every group of this kind, each given as a list of
            global ranks.
        local_rank: local rank of this process.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: forwarded to `GroupCoordinator`.
        group_name: name for the groups (e.g. "tp", "pp").

    Unlike the world group, these groups use a device communicator.
    """
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group, which must already be initialized."""
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# Legacy name, retained for backward compatibility.
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None

931
932
933
934
935
936
937
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data-parallel group, which must already be initialized."""
    dp_group = _DP
    assert dp_group is not None, ("data parallel group is not initialized")
    return dp_group

938

939
940
941
942
943
944
945
946
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert-parallel group, which must already be initialized."""
    ep_group = _EP
    assert ep_group is not None, ("expert parallel group is not initialized")
    return ep_group


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group, which must already be initialized."""
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# Legacy name, retained for backward compatibility.
get_pipeline_model_parallel_group = get_pp_group


@contextmanager
def graph_capture(device: torch.device):
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It yields a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on: a new `torch.cuda.Stream`
    created on `device`. The TP and PP groups' own `graph_capture` context
    managers are entered with this context, so that the graph capture runs
    on a separate stream from the default stream. This explicitly
    distinguishes the kernels to capture from other kernels possibly
    launched in the background on the default stream.
    """
    context = GraphCaptureContext(torch.cuda.Stream(device=device))
    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
            context):
        yield context

977

978
logger = init_logger(__name__)

# Module-level switch for custom all-reduce; updated via
# `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Overwrite the module-level `_ENABLE_CUSTOM_ALL_REDUCE` switch."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable

def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the module world group.

    When the current vLLM config has `data_parallel_size > 1`, `rank` is
    offset by the data-parallel rank. `world_size` becomes
    `world_size_across_dp`, and the rendezvous address is taken from the
    parallel config.

    Args:
        world_size: number of processes (per data-parallel replica when
            data parallelism is enabled).
        rank: rank of this process (within its data-parallel replica when
            data parallelism is enabled).
        distributed_init_method: torch.distributed init method URL.
        local_rank: rank of this process on its node; -1 means "derive it".
        backend: preferred backend; gloo is used if it is unavailable.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None and config.parallel_config.data_parallel_size > 1:
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        rank = parallel_config.data_parallel_rank * world_size + rank
        # Re-derive the device index from the DP-adjusted global rank.
        # NOTE(review): assumes ranks are packed onto nodes in order and
        # every node sees the same number of CUDA devices -- confirm.
        # Skipped when no CUDA device is visible, because the modulo would
        # otherwise raise ZeroDivisionError.
        num_cuda_devices = torch.cuda.device_count()
        if num_cuda_devices > 0:
            local_rank = rank % num_cuda_devices
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp
        ip = parallel_config.data_parallel_master_ip
        port = parallel_config.get_next_dp_init_port()
        distributed_init_method = get_distributed_init_method(ip, port)
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; "
                "falling back to gloo.", backend)
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available.")
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD, _NODE_COUNT
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment",
                     _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    The data-parallel (DP) and expert-parallel (EP) groups are built as
    well. The DP size is read from the current vLLM config (1 if there is
    none), and each EP group spans DP x TP ranks.

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = all_ranks.transpose(2, 3).reshape(
        -1, pipeline_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="pp")

    # Build the data-parallel groups (DP axis moved last).
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = all_ranks.transpose(1, 3).reshape(
        -1, data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

    # Build the expert-parallel groups: swap the DP and PP axes so DP and TP
    # are the two trailing axes, then merge them into one group dimension.
    global _EP
    assert _EP is None, ("expert parallel group is already initialized")
    group_ranks = all_ranks.transpose(1, 2).reshape(
        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _EP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)

def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.

    `backend` defaults to the backend of the world group's device group.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Give every initialized parallel group a chance to size its buffers.

    Classic communication libraries such as NCCL are essentially model
    agnostic. Newer ones, like the MoE all2all library DeepEP, usually
    allocate their communication buffer from the model shape for optimal
    performance. Groups are visited in TP, PP, DP, EP order.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized.

    Only the TP and PP groups are checked; the DP and EP groups are not.
    """
    return (_TP is not None and _PP is not None)


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Nested patching is rejected. The current tp group is looked up before the
    patched flag is set, so a failed lookup leaves the module state untouched.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the group to restore *before* flipping the flag: get_tp_group()
    # asserts when TP is uninitialized. Setting the flag first would leave it
    # stuck at True, so every later call would fail the assertion above.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return my rank for the tensor model parallel group."""
    return get_tp_group().rank_in_group


def get_node_count() -> int:
    """Return how many nodes take part in the distributed environment.

    The value is only available after the distributed environment has been
    initialized.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, (
        "distributed environment is not initialized")
    return node_count


def destroy_model_parallel():
    """Destroy the TP, PP, DP and EP groups (when set) and reset them to None.

    The world group is left untouched.
    """
    global _TP
    if _TP:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP:
        _PP.destroy()
    _PP = None

    global _DP
    if _DP:
        _DP.destroy()
    _DP = None

    global _EP
    if _EP:
        _EP.destroy()
    _EP = None

1278
1279

def destroy_distributed_environment():
    """Destroy the world group, forget the node count, and tear down the
    default torch.distributed process group if one is initialized."""
    global _WORLD, _NODE_COUNT
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None
    _NODE_COUNT = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all parallel groups and torch.distributed, then free memory.

    Args:
        shutdown_ray: when True, also call `ray.shutdown()`.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # The default group may already be gone at this point; an AssertionError
    # from this second destroy attempt is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()
    from vllm.platforms import current_platform
    empty_cache = current_platform.empty_cache
    if empty_cache is not None:
        empty_cache()
    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")


def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> list[bool]:
    """
    Collectively determine which ranks share a node with `source_rank`.

    This is a collective operation: every rank of `pg` must call it. The
    source rank creates a shared-memory segment and writes a magic message
    into it; every other rank opens the segment by name and checks for the
    message. This tests whether processes are attached to the same memory
    system (shared access to shared memory).

    Args:
        pg: a non-NCCL torch ProcessGroup, or a StatelessProcessGroup.
        source_rank: rank of the reference process, relative to `pg`.

    Returns:
        One bool per rank of `pg`, True when that rank shares a node with
        `source_rank`. Errors while probing are logged and leave the
        failing rank marked as not on the same node.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        # NOTE(review): if creating the segment raises OSError on the source
        # rank, it skips its broadcast while the other ranks wait in theirs
        # -- confirm this cannot hang.
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Every rank must be done reading before the source unlinks the segment.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]


def is_global_first_rank() -> bool:
    """
    Report whether this process is global rank 0 across every parallelism
    dimension (PP, TP, DP, EP, etc.).

    Group-local checks such as `get_tensor_model_parallel_rank() == 0` or
    `get_pp_group().is_first_rank` consider a single dimension only; this
    function considers the rank in the whole world.

    Returns:
        bool: True for the global first rank (rank 0), False otherwise.
              Also True when torch.distributed is not initialized (single
              process) or when the lookup raises for any reason.
    """
    try:
        # The world group coordinator, when present, is the authority.
        world = _WORLD
        if world is not None:
            return world.is_first_rank
        # Without an initialized process group we are a lone process.
        if not torch.distributed.is_initialized():
            return True
        # Otherwise defer to torch's own global rank.
        return torch.distributed.get_rank() == 0
    except Exception:
        # Deliberate best-effort: any failure is treated as "first rank".
        return True


def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Count the distinct nodes spanned by the process group.

    Ranks are grouped greedily. Each rank not yet placed starts a new node
    and claims every rank that `in_the_same_node_as` reports as sharing
    memory with it. Because that helper is collective, every rank of `pg`
    must call this function.

    Args:
        pg: The process group to analyze.

    Returns:
        int: The total number of nodes.
    """
    world_size = (torch.distributed.get_world_size(group=pg)
                  if isinstance(pg, ProcessGroup) else pg.world_size)
    if world_size == 1:
        return 1

    # node_of[r] == 0 means rank r has not been placed on a node yet.
    node_of = [0] * world_size
    num_nodes = 0
    for leader in range(world_size):
        if node_of[leader]:
            continue
        num_nodes += 1
        node_of[leader] = num_nodes
        flags = in_the_same_node_as(pg, leader)
        for peer, shares_node in enumerate(flags):
            if shares_node and not node_of[peer]:
                node_of[peer] = num_nodes
    return num_nodes