# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
  initialize the model parallel groups.
- run any code that uses the distributed environment.
- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
parallelism, you can skip the model parallel initialization and destruction
steps.
"""
import contextlib
25
import gc
26
import pickle
27
import weakref
28
29
30
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
31
from multiprocessing import shared_memory
32
33
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)
34
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
35

36
import flux
Zhuohan Li's avatar
Zhuohan Li committed
37
import torch
38
import torch.distributed
39
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
40

41
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
42
import vllm.envs as envs
43
from vllm.distributed.utils import StatelessProcessGroup
44
from vllm.logger import init_logger
45
from vllm.utils import direct_register_custom_op, supports_custom_op
46

47
48
49
if TYPE_CHECKING:
    from vllm.config import VllmConfig

50

51
52
53
@dataclass
class GraphCaptureContext:
    """State handed to code running inside `GroupCoordinator.graph_capture`.

    Attributes:
        stream: the CUDA stream that is made current while capturing.
    """

    stream: torch.cuda.Stream


# Metadata sent in place of a tensor so that the receiving side can allocate
# an empty tensor of the right device type, dtype and size before the tensor
# data itself arrives.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")

def _split_tensor_dict(
60
61
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
62
63
64
65
66
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
67
    metadata_list: List[Tuple[str, Any]] = []
68
    tensor_list: List[torch.Tensor] = []
69
70
71
72
73
74
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
75
            device = value.device.type
76
            metadata_list.append(
77
                (key, TensorMetadata(device, value.dtype, value.size())))
78
79
            tensor_list.append(value)
        else:
80
            metadata_list.append((key, value))
81
82
83
    return metadata_list, tensor_list


84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


100
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
101
102
103


def _register_group(group: "GroupCoordinator") -> None:
104
    _groups[group.unique_name] = weakref.ref(group)
105
106


107
108
109
110
111
112
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce on the coordinator registered as ``group_name``.

    This is the function behind the ``all_reduce`` custom op. The group is
    passed by name because Dynamo cannot pass the coordinator object itself
    to a custom op.

    Raises:
        AssertionError: if no group is registered under ``group_name``.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")


def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation registered for the ``all_reduce`` custom op.

    Returns an uninitialized tensor shaped like ``tensor`` without running
    any collective; ``group_name`` is not used.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder


# When custom ops are supported, register `all_reduce` (with its fake
# implementation) as the `all_reduce` op; GroupCoordinator.all_reduce
# dispatches to it through `torch.ops.vllm.all_reduce`.
if supports_custom_op():
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        fake_impl=all_reduce_fake,
        mutates_args=[],  # out-of-place: no argument is modified
    )


class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It can route the communication to
        a specific implementation (e.g. switch allreduce implementation
        based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
159
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
160
161
162
163
164
165
166
167

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups for ``group_ranks`` and set up
        communicators for the group containing this process.

        Args:
            group_ranks: the global ranks of every group. Every process
                calls ``new_group`` for every entry; only the group that
                contains this process's rank is kept.
            local_rank: local rank, used to pick the CUDA device.
            torch_distributed_backend: backend of the device group.
            use_pynccl, use_custom_allreduce, use_tpu_communicator,
            use_hpu_communicator, use_xpu_communicator: whether to create
                the corresponding communicator. Each is created only when
                the group has more than one process; otherwise the
                attribute stays None.
            use_message_queue_broadcaster: create a shared-memory
                ``MessageQueue`` used by ``broadcast_object``.
            group_name: prefix for the unique group name; defaults to
                "anonymous".
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

        # Initialize pynvshmem (through flux) for this group.
        # `self.world_size` equals the size of `self.device_group`.
        if self.world_size > 1:
            flux.init_flux_shm(self.device_group)

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        # The HPU/XPU communicators must default to None like the others:
        # all_reduce / all_gather / gather test `... is not None`, which
        # raises AttributeError if the attribute was only annotated.
        from vllm.distributed.device_communicators.hpu_communicator import (
            HpuCommunicator)
        self.hpu_communicator: Optional[HpuCommunicator] = None
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.xpu_communicator import (
            XpuCommunicator)
        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): 1 << 22 (4 MiB) and 6 presumably are the
            # per-chunk byte size and chunk count; confirm in MessageQueue.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

271
272
273
274
275
276
277
278
279
280
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Prepare this group for CUDA-graph capture.

        Yields a ``GraphCaptureContext`` whose stream is made current for the
        duration of the ``with`` block. When no context is passed, one is
        built around a fresh CUDA stream. If a custom all-reduce
        communicator exists, its ``capture()`` context is entered too.
        """
        if graph_capture_context is not None:
            stream = graph_capture_context.stream
        else:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)

        ca_comm = self.ca_comm
        if ca_comm is None:
            maybe_ca_context = nullcontext()
        else:
            maybe_ca_context = ca_comm.capture()

        # Make the capture stream wait for work already queued on the current
        # stream, so initialization completes before capture starts.
        current = torch.cuda.current_stream()
        if current != stream:
            stream.wait_stream(current)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
319
320
321
322
323
324
325
326
327
328
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
329
330
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
331
332
333
334
335
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

336
        if input_.is_cpu:
337
338
339
340
341
342
343
344
345
346
347
            try:
                import intel_extension_for_pytorch as ipex
                ipex.distributed.all_reduce(input_, group=self.device_group)
                return input_
            except ImportError:
                """
                Intel IPEX not found. Falling back to PyTorch native 
                all_reduce for CPU
                """
                torch.distributed.all_reduce(input_, group=self.device_group)
                return input_
348

349
350
351
        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
352
            return self.tpu_communicator.all_reduce(input_)
353

354
355
356
357
        if self.hpu_communicator is not None and \
            not self.hpu_communicator.disabled:
            return self.hpu_communicator.all_reduce(input_)

358
359
360
361
        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

362
        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
363

364
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
365
366
        # always try custom allreduce first,
        # and then pynccl.
367
        ca_comm = self.ca_comm
368
369
370
371
372
        if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
373
        pynccl_comm = self.pynccl_comm
374
        assert pynccl_comm is not None
youkaichao's avatar
youkaichao committed
375
        out = pynccl_comm.all_reduce(input_)
376
377
378
379
380
381
382
383
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out
384
385
386
387
388
389
390
391

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
392
393
394
395
396
397

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

398
399
400
401
402
        # For HPUs, use HPU communicator.
        hpu_comm = self.hpu_communicator
        if hpu_comm is not None and not hpu_comm.disabled:
            return hpu_comm.all_gather(input_, dim)

403
404
405
406
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
407
408
409
410
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * world_size, ) + input_size[1:]
411
        # Allocate output tensor.
412
        output_tensor = torch.empty(output_size,
413
414
415
416
417
418
419
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
420
        output_tensor = output_tensor.reshape((world_size, ) + input_size)
421
422
423
424
425
426
427
428
429
430
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
431
               dim: int = -1) -> Optional[torch.Tensor]:
432
433
434
435
436
437
438
439
440
441
442
443
444
445
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
446
447
448
449
        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.gather(input_, self.rank_in_group,
                                                dst, dim)
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

481
482
483
484
485
486
487
488
489
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
490
491
492
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
493
494
495
496
497
498
499
500
501
502
503
504
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

523
524
525
526
527
528
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

529
        assert dst != self.rank_in_group, (
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size

        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

559
        assert src != self.rank_in_group, (
560
561
562
563
564
565
566
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
567
                                           src=self.ranks[src],
568
569
570
571
572
573
574
575
576
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
577
                                             src=self.ranks[src],
578
579
580
581
582
583
584
585
586
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

587
588
    def broadcast_tensor_dict(
        self,
589
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
590
591
592
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
593
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
594
595
596
597
598
599
600
601
602
603
604
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

605
606
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
607
608
609
610
611
612
613
614
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
615
            self.broadcast_object(metadata_list, src=src)
616
617
618
619
620
621
622
623
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
624
                                                         src=self.ranks[src],
625
626
627
628
629
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
630
                                                         src=self.ranks[src],
631
632
633
634
635
636
637
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
638
            metadata_list = self.broadcast_object(None, src=src)
639
640
            tensor_dict = {}
            async_handles = []
641
            for key, value in metadata_list:
642
643
644
645
646
647
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
648
                        tensor_dict[key] = tensor
649
650
651
652
653
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
654
                            src=self.ranks[src],
655
656
657
658
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
659
660
661
662
663
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
664
                    async_handles.append(handle)
665
                    tensor_dict[key] = tensor
666
                else:
667
                    tensor_dict[key] = value
668
669
670
671
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

672
673
    def send_tensor_dict(
        self,
674
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
675
676
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
677
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
678
679
680
681
682
683
684
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

685
686
687
688
689
        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

690
691
692
693
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
694
            dst = (self.rank_in_group + 1) % self.world_size
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
710
711
712
713
714
715

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

716
717
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
718
719
720
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
721
722
            else:
                # use group for GPU tensors
723
724
725
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
726
727
728
729
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Receive a dictionary of tensors and plain values from `src`.

        The metadata list (pairs of key and either `TensorMetadata` or a plain
        value) arrives first through `recv_object`. Then every non-empty
        tensor is received point-to-point, in the order the metadata lists
        them.

        NOTE: `src` is the local rank of the source rank.

        Args:
            src: local rank (within this group) to receive from. Defaults to
                the previous rank in the group, wrapping around.
            all_gather_group: optional group used for send-allgather. When it
                is given and a tensor's element count is divisible by that
                group's world size, only this rank's slice is received. The
                full tensor is then rebuilt with an all-gather over the group.

        Returns:
            The reconstructed dictionary, or None when torch.distributed is
            not initialized or this group has a single rank.
        """
        # Single-rank (or non-distributed) case: nothing to receive.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        if all_gather_group is None:
            all_gather_size, all_gather_rank = 1, 0
        else:
            all_gather_size = all_gather_group.world_size
            all_gather_rank = all_gather_group.rank_in_group

        device_group = self.device_group
        cpu_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        metadata = self.recv_object(src=src)
        received: Dict[str, Any] = {}
        for key, value in metadata:
            if not isinstance(value, TensorMetadata):
                # Plain (non-tensor) values travel inside the metadata.
                received[key] = value
                continue

            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # The sender skips empty tensors, so there is nothing to
                # receive; keep the freshly allocated empty tensor.
                received[key] = tensor
                continue

            # send-allgather: only this rank's slice is sent point-to-point.
            gather = (all_gather_group is not None
                      and tensor.numel() % all_gather_size == 0)
            full_shape = tensor.shape
            target = (tensor.reshape(all_gather_size, -1)[all_gather_rank]
                      if gather else tensor)

            # CPU tensors go over the CPU (metadata) group, GPU tensors over
            # the device group.
            torch.distributed.recv(
                target,
                src=self.ranks[src],
                group=cpu_group if target.is_cpu else device_group)

            if gather:
                # Reassemble the full tensor from every rank's slice.
                target = all_gather_group.all_gather(  # type: ignore
                    target, dim=0).reshape(full_shape)
            received[key] = target

        return received

794
795
796
797
798
799
800
801
802
    def barrier(self):
        """Block until every rank in this group reaches the barrier.

        The CPU group is used on purpose instead of `device_group`. An NCCL
        barrier is internally a broadcast over GPU tensors that it creates
        itself, which can easily mess up the current device.
        """
        cpu_group = self.cpu_group
        torch.distributed.barrier(group=cpu_group)

803
804
805
806
    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send `tensor` to rank `dst` of this group.

        NOTE: `dst` is the local rank of the destination rank. When omitted,
        the next rank in the group (wrapping around) is used.

        The pynccl communicator is used when present and not disabled.
        Otherwise the tensor goes through `torch.distributed.send` on
        `device_group`, addressed by the destination's global rank.
        """
        target = dst
        if target is None:
            target = (self.rank_in_group + 1) % self.world_size

        comm = self.pynccl_comm
        if comm is None or comm.disabled:
            torch.distributed.send(tensor, self.ranks[target],
                                   self.device_group)
            return
        comm.send(tensor, target)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor of shape `size` and type `dtype` from rank `src`.

        NOTE: `src` is the local rank of the source rank. When omitted, the
        previous rank in the group (wrapping around) is used.

        The buffer is allocated on `self.device`. The pynccl communicator is
        used when present and not disabled. Otherwise
        `torch.distributed.recv` on `device_group` fills the buffer.

        Returns:
            The newly allocated buffer holding the received data.
        """
        source = src
        if source is None:
            source = (self.rank_in_group - 1) % self.world_size

        buffer = torch.empty(size, dtype=dtype, device=self.device)
        comm = self.pynccl_comm
        if comm is None or comm.disabled:
            torch.distributed.recv(buffer, self.ranks[source],
                                   self.device_group)
        else:
            comm.recv(buffer, source)
        return buffer

832
833
834
835
836
837
838
839
840
841
842
    def destroy(self):
        """Release this coordinator's process groups and helper objects.

        The device group and then the CPU group are destroyed through
        `torch.distributed.destroy_process_group`. The pynccl communicator,
        custom-allreduce communicator and message-queue broadcaster are
        dropped by clearing their references. Every released attribute ends
        up as None.
        """
        for group_attr in ("device_group", "cpu_group"):
            group = getattr(self, group_attr)
            if group is not None:
                torch.distributed.destroy_process_group(group)
                setattr(self, group_attr, None)

        for helper_attr in ("pynccl_comm", "ca_comm", "mq_broadcaster"):
            if getattr(self, helper_attr) is not None:
                setattr(self, helper_attr, None)
845
846
847
848
849
850
851
852
853
854


# Coordinator for the group that spans every rank. It is populated by
# `init_distributed_environment` and cleared by
# `destroy_distributed_environment`.
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator; it must already be initialized."""
    world = _WORLD
    assert world is not None, ("world group is not initialized")
    return world


855
856
857
858
859
860
861
862
def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Create the coordinator for the world group.

    The world group is one group containing all of `ranks` and is named
    "world". Every optional communicator (pynccl, custom allreduce and the
    TPU/HPU/XPU communicators) is switched off for it.
    """
    disabled_communicators = dict(
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
    )
    return GroupCoordinator(group_ranks=[ranks],
                            local_rank=local_rank,
                            torch_distributed_backend=backend,
                            group_name="world",
                            **disabled_communicators)


870
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create the coordinator for a model-parallel (e.g. TP or PP) group.

    Args:
        group_ranks: one list of global ranks per group to build.
        local_rank: local rank of this process.
        backend: torch.distributed backend name for the groups.
        use_custom_allreduce: whether to enable custom allreduce. None means
            "use the module default set via `set_custom_all_reduce`".
        use_message_queue_broadcaster: whether to request a message-queue
            broadcaster for this group.
        group_name: optional name for the group.

    pynccl and custom allreduce are only requested on CUDA-alike platforms.
    The TPU/HPU/XPU communicators are always requested.
    """
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    from vllm.platforms import current_platform
    cuda_alike = current_platform.is_cuda_alike()

    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=cuda_alike,
        use_custom_allreduce=cuda_alike and use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
# Tensor-parallel group coordinator. It is set by `initialize_model_parallel`,
# cleared by `destroy_model_parallel` and may be swapped temporarily by
# `patch_tensor_parallel_group`.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-parallel group coordinator; it must be initialized."""
    tp = _TP
    assert tp is not None, ("tensor model parallel group is not initialized")
    return tp


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

# Pipeline-parallel group coordinator. It is set by
# `initialize_model_parallel` and cleared by `destroy_model_parallel`.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-parallel group coordinator; it must be initialized."""
    pp = _PP
    assert pp is not None, (
        "pipeline model parallel group is not initialized")
    return pp


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
918

919
920
921
922
923
924
925
926
# Disaggregated KV-cache transfer agent. It is created lazily by
# `ensure_kv_transfer_initialized` when the configuration requires it.
_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None


def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
    """Return the KV-cache transfer agent; it must already be initialized."""
    agent = _KV_TRANSFER
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent

927

928
@contextmanager
def graph_capture(device: torch.device):
    """Context manager to wrap around code that captures a CUDA graph.

    A new `torch.cuda.Stream` on `device` is wrapped in a
    `GraphCaptureContext`. The tensor-parallel group's `graph_capture`
    context is entered with it first, then the pipeline-parallel group's.
    The context object is yielded to the caller.

    Capturing on a dedicated stream keeps the kernels being captured apart
    from other kernels that may be launched in the background on the default
    stream. Switching to that stream, and any work that must run after
    capture but before replay, is presumably handled by the per-group
    `graph_capture` contexts.
    """
    capture_ctx = GraphCaptureContext(torch.cuda.Stream(device=device))
    with get_tp_group().graph_capture(capture_ctx):
        with get_pp_group().graph_capture(capture_ctx):
            yield capture_ctx

948

949
# Module-level logger for distributed-state messages.
logger = init_logger(__name__)

# Module default for custom allreduce. It is toggled by
# `set_custom_all_reduce` and read by `init_model_parallel_group` whenever
# that function's `use_custom_allreduce` argument is None.
_ENABLE_CUSTOM_ALL_REDUCE = True
952
953


954
955
956
def set_custom_all_reduce(enable: bool):
    """Set the module-wide default for custom allreduce.

    The value is stored in `_ENABLE_CUSTOM_ALL_REDUCE`.
    `init_model_parallel_group` uses it for groups created afterwards when
    the caller does not pass `use_custom_allreduce` explicitly.
    """
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
957

Zhuohan Li's avatar
Zhuohan Li committed
958

959
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed if needed, then the vLLM world group.

    Args:
        world_size: total number of processes, forwarded to
            `torch.distributed.init_process_group`.
        rank: global rank of this process, forwarded likewise.
        distributed_init_method: init-method URL. It must not be None when
            torch.distributed still needs to be initialized.
        local_rank: local rank of this process. When -1, it is taken from
            `envs.LOCAL_RANK` for the "env://" method and from `rank`
            otherwise.
        backend: backend for the default process group and the world group.

    If the world group already exists, this only checks that its size matches
    the torch.distributed world size.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # torch ProcessGroup does not expose a local rank, so derive one when the
    # caller did not provide it (see
    # https://github.com/pytorch/pytorch/issues/122816). In single-node
    # setups the global rank can serve as the local rank.
    if local_rank == -1:
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)

    global _WORLD
    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
        return

    all_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_ranks, local_rank, backend)
997
998


Zhuohan Li's avatar
Zhuohan Li committed
999
1000
1001
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7. We use 2 GPUs to
    parallelize the model tensor and 4 GPUs to parallelize the model pipeline.
    This function then creates 4 tensor model-parallel groups and 2 pipeline
    model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example, if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, ranks 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not the product of the tensor and
            pipeline model-parallel sizes.
    """
    global _TP, _PP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size
            != tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Tensor model-parallel groups: blocks of consecutive ranks.
    assert _TP is None, ("tensor model parallel group is already initialized")
    tp_size = tensor_model_parallel_size
    tp_group_ranks = [
        list(range(i * tp_size, (i + 1) * tp_size))
        for i in range(world_size // tp_size)
    ]
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(tp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Pipeline model-parallel groups: ranks strided by the number of groups.
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    num_pp_groups: int = world_size // pipeline_model_parallel_size
    pp_group_ranks = [
        list(range(first, world_size, num_pp_groups))
        for first in range(num_pp_groups)
    ]
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(pp_group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False,
                                    group_name="pp")
1074

Zhuohan Li's avatar
Zhuohan Li committed
1075

1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize KV cache transfer parallel group.

    Does nothing when `vllm_config.kv_transfer_config` is None. Otherwise the
    global KV transfer agent is created, but only when the transfer config
    reports `need_kv_parallel_group` and no agent exists yet. An existing
    agent is left untouched.
    """
    global _KV_TRANSFER

    transfer_config = vllm_config.kv_transfer_config
    if transfer_config is None:
        return

    if not transfer_config.need_kv_parallel_group or _KV_TRANSFER is not None:
        return

    world = get_world_group()
    _KV_TRANSFER = kv_transfer.KVTransferAgent(rank=world.rank,
                                               local_rank=world.local_rank,
                                               config=vllm_config)


1096
1097
1098
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Create the model parallel groups, or validate the existing ones.

    If the TP/PP groups do not exist yet, they are built via
    `initialize_model_parallel`. Otherwise their sizes are checked against
    `tensor_model_parallel_size` and `pipeline_model_parallel_size`, and an
    AssertionError is raised on a mismatch.

    `backend` defaults to the backend of the world group's device group. It
    only matters when the groups still have to be created.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        # First call: build the groups and stop here.
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    # Groups already exist: their sizes must match the requested ones.
    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert pp_world_size == pipeline_model_parallel_size, (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


Zhuohan Li's avatar
Zhuohan Li committed
1124
def model_parallel_is_initialized():
    """Return True only when both the TP and PP group coordinators exist."""
    return not (_TP is None or _PP is None)
1127
1128


1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    _TP_STATE_PATCHED = True
    old_tp_group = get_tp_group()
    global _TP
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1157
1158
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1160
1161
1162
1163


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1165
1166
1167


def destroy_model_parallel():
    """Destroy the TP and PP group coordinators, then reset both to None.

    TP is handled completely (destroyed, then cleared) before PP.
    """
    global _TP, _PP

    if _TP:
        _TP.destroy()
    _TP = None

    if _PP:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    """Destroy the world group coordinator and the default process group.

    The world coordinator, if present, is destroyed and reset to None. The
    default torch.distributed process group is then torn down, but only
    when torch.distributed is initialized.
    """
    global _WORLD
    if _WORLD:
        _WORLD.destroy()
    _WORLD = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
1187
1188


1189
1190
1191
1192
1193
1194
1195
1196
1197
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all distributed state and free cached memory.

    Steps, in order:
    1. Destroy the model-parallel groups and the distributed environment.
    2. Make one more best-effort attempt at destroying the default process
       group.
    3. Shut down Ray, if requested.
    4. Run the garbage collector.
    5. Empty the CUDA cache, except on CPU platforms.
    6. Call `torch._C._host_emptyCache()`. If the installed PyTorch lacks it,
       log a warning instead.

    Args:
        shutdown_ray: also call `ray.shutdown()` when True.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Best effort: the default group may already be gone at this point, so
    # an AssertionError from this second teardown is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    gc.collect()

    from vllm.platforms import current_platform
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()

    try:
        torch._C._host_emptyCache()
    except AttributeError:
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")
1206
1207


1208
1209
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> List[bool]:
    """
    Collective: for each rank of `pg`, report whether it is on the same node
    as `source_rank`.

    Detection uses shared memory. The source rank creates a small
    shared-memory segment, writes a magic message into it and broadcasts the
    segment's name. Every other rank tries to attach to that segment and
    checks for the message. A rank that can read the message shares the
    source's memory system.

    OSErrors during this probe are silently ignored and other exceptions are
    logged; either way the rank's flag stays 0 ("not on the same node").
    `source_rank` is a rank within `pg`. For a torch ProcessGroup, the group
    must not use the NCCL backend.

    Returns:
        One entry per rank of `pg`; entry i is True when rank i shares a node
        with the source rank.
    """
    is_torch_group = isinstance(pg, ProcessGroup)
    if is_torch_group:
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # rank of this process inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)
        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # One slot per rank; each rank sets only its own slot to 1 on success.
    same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # Create the segment, stamp it, and publish its name.
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if is_torch_group:
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                same_node[rank] = 1
            else:
                # Learn the segment's name, then try to attach to it.
                if is_torch_group:
                    received = [None]
                    torch.distributed.broadcast_object_list(
                        received, src=ranks[source_rank], group=pg)
                    segment_name = received[0]
                else:
                    segment_name = pg.broadcast_obj(None, src=source_rank)

                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=segment_name)
                if shm.buf[:len(magic_message)] == magic_message:
                    same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # Make sure every rank is done with the segment before it is unlinked.
    if is_torch_group:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # Only the creator removes the shared memory segment.
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()

    # Combine every rank's slot into one vector.
    if is_torch_group:
        torch.distributed.all_reduce(same_node, group=pg)
        aggregated = same_node
    else:
        aggregated = torch.zeros_like(same_node)
        for src_rank in range(world_size):
            aggregated += pg.broadcast_obj(same_node, src=src_rank)

    return [flag == 1 for flag in aggregated.tolist()]