parallel_state.py 48.6 KB
Newer Older
Woosuk Kwon's avatar
Woosuk Kwon committed
1
# Copyright 2023 The vLLM team.
2
3
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
Zhuohan Li's avatar
Zhuohan Li committed
4
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
5
6
7
8
9
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
10
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
11
12
13
14
15
16
17
18
19
20
21
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
22
import contextlib
23
import gc
24
import pickle
25
import weakref
26
27
28
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
29
from multiprocessing import shared_memory
30
31
from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple,
                    Union)
32
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
33
34

import torch
35
import torch.distributed
36
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
37

38
import vllm.distributed.kv_transfer.kv_transfer_agent as kv_transfer
39
import vllm.envs as envs
40
from vllm.distributed.utils import StatelessProcessGroup
41
from vllm.logger import init_logger
42
from vllm.utils import direct_register_custom_op, supports_custom_op
43

44
45
46
if TYPE_CHECKING:
    from vllm.config import VllmConfig

47

48
49
50
@dataclass
class GraphCaptureContext:
    """State shared by one graph-capture session.

    Attributes:
        stream: the CUDA stream on which the graph is captured.
    """

    stream: torch.cuda.Stream
51

52

53
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
54

55
56

def _split_tensor_dict(
57
58
    tensor_dict: Dict[str, Union[torch.Tensor, Any]]
) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
59
60
61
62
63
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
64
    metadata_list: List[Tuple[str, Any]] = []
65
    tensor_list: List[torch.Tensor] = []
66
67
68
69
70
71
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
72
            device = value.device.type
73
            metadata_list.append(
74
                (key, TensorMetadata(device, value.dtype, value.size())))
75
76
            tensor_list.append(value)
        else:
77
            metadata_list.append((key, value))
78
79
80
    return metadata_list, tensor_list


81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
_group_name_counter: Dict[str, int] = {}


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


97
_groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
98
99
100


def _register_group(group: "GroupCoordinator") -> None:
101
    _groups[group.unique_name] = weakref.ref(group)
102
103


104
105
106
107
108
109
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Real implementation behind the ``vllm::all_reduce`` custom op.

    Looks up the coordinator registered as ``group_name`` and delegates to
    its out-of-place all-reduce.

    Raises:
        AssertionError: no group was registered under ``group_name``.
        ValueError: the registered group no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    raise ValueError(f"Group {group_name} is destroyed.")
110
111


112
113
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation registered for the ``all_reduce`` custom op.

    Returns an uninitialized tensor like ``tensor`` without communicating.
    ``group_name`` is unused.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
114

115

116
# Register `all_reduce` (with `all_reduce_fake` as its fake implementation)
# as the "all_reduce" custom op when custom ops are supported.
if supports_custom_op():
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        fake_impl=all_reduce_fake,
        mutates_args=[],
    )

124

125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It can route the communication to
        a specific implementation (e.g. switch allreduce implementation
        based on the tensor size and cuda graph mode).
    """

    # available attributes:
    rank: int  # global rank
    ranks: List[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_pynccl: bool  # a hint of whether to use PyNccl
    use_custom_allreduce: bool  # a hint of whether to use CustomAllreduce
    # communicators are only created for world size > 1
    pynccl_comm: Optional[Any]  # PyNccl communicator
    ca_comm: Optional[Any]  # Custom allreduce communicator
156
    mq_broadcaster: Optional[Any]  # shared memory broadcaster
157
158
159
160
161
162
163
164

    def __init__(
        self,
        group_ranks: List[List[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_pynccl: bool,
        use_custom_allreduce: bool,
        use_tpu_communicator: bool,
        use_hpu_communicator: bool,
        use_xpu_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create this rank's process groups and communicators.

        Args:
            group_ranks: every group to create, each as a list of global
                ranks. `torch.distributed.new_group` is called for every
                entry (twice: device and gloo backends), so all ranks must
                pass the same `group_ranks`. The entry containing the
                caller's global rank becomes this coordinator's group.
            local_rank: local rank, used to pick the CUDA device.
            torch_distributed_backend: backend for the device group.
            use_pynccl, use_custom_allreduce, use_tpu_communicator,
            use_hpu_communicator, use_xpu_communicator: hints to create the
                corresponding communicator. Each one is only created when the
                group has more than one rank; otherwise the attribute is
                None.
            use_message_queue_broadcaster: create a shared-memory
                MessageQueue used by `broadcast_object`.
            group_name: prefix for this coordinator's registered unique name
                (defaults to "anonymous").
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform
        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_pynccl = use_pynccl
        self.use_custom_allreduce = use_custom_allreduce
        self.use_tpu_communicator = use_tpu_communicator
        self.use_hpu_communicator = use_hpu_communicator
        self.use_xpu_communicator = use_xpu_communicator

        # lazy import to avoid documentation build error
        from vllm.distributed.device_communicators.custom_all_reduce import (
            CustomAllreduce)
        from vllm.distributed.device_communicators.pynccl import (
            PyNcclCommunicator)

        self.pynccl_comm: Optional[PyNcclCommunicator] = None
        if use_pynccl and self.world_size > 1:
            self.pynccl_comm = PyNcclCommunicator(
                group=self.cpu_group,
                device=self.device,
            )

        self.ca_comm: Optional[CustomAllreduce] = None
        if use_custom_allreduce and self.world_size > 1:
            # Initialize a custom fast all-reduce implementation.
            self.ca_comm = CustomAllreduce(
                group=self.cpu_group,
                device=self.device,
            )

        from vllm.distributed.device_communicators.tpu_communicator import (
            TpuCommunicator)
        self.tpu_communicator: Optional[TpuCommunicator] = None
        if use_tpu_communicator and self.world_size > 1:
            self.tpu_communicator = TpuCommunicator(group=self.cpu_group)

        from vllm.distributed.device_communicators.hpu_communicator import (
            HpuCommunicator)
        # Must default to None: a bare annotation does not create the
        # attribute, and the communication methods read it unconditionally.
        self.hpu_communicator: Optional[HpuCommunicator] = None
        if use_hpu_communicator and self.world_size > 1:
            self.hpu_communicator = HpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.xpu_communicator import (
            XpuCommunicator)
        # Same as above: always define the attribute.
        self.xpu_communicator: Optional[XpuCommunicator] = None
        if use_xpu_communicator and self.world_size > 1:
            self.xpu_communicator = XpuCommunicator(group=self.device_group)

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): 1 << 22 (4 MiB) and 6 are presumably the chunk
            # size in bytes and the number of chunks; confirm in MessageQueue.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)
253

254
255
256
257
258
259
260
261
262
263
    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

264
265
266
267
268
269
270
271
272
273
    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager for capturing a CUDA graph on a side stream.

        Uses the stream from `graph_capture_context`, or creates a new
        stream (and context) when none is given. While active, that stream
        is current, the custom all-reduce capture context (if any) is
        entered, and the PyNccl communicator's `change_state()` context (if
        any) is entered. Yields the `GraphCaptureContext`.
        """
        if graph_capture_context is not None:
            capture_stream = graph_capture_context.stream
        else:
            capture_stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(capture_stream)

        custom_ar = self.ca_comm
        ca_ctx = (nullcontext()
                  if custom_ar is None else custom_ar.capture())

        # Make the capture stream wait for work already queued on the current
        # stream, so initialization finishes before capture begins.
        current = torch.cuda.current_stream()
        if current != capture_stream:
            capture_stream.wait_stream(current)

        with torch.cuda.stream(capture_stream), ca_ctx:
            nccl = self.pynccl_comm
            pynccl_ctx: Any = (nccl.change_state()
                               if nccl else nullcontext())
            with pynccl_ctx:
                yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
319
320
321
322
323
324
325
326
327
328
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
329
330
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
331
332
333
334
335
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

336
337
338
339
340
        if input_.is_cpu:
            import intel_extension_for_pytorch as ipex
            ipex.distributed.all_reduce(input_, group=self.device_group)
            return input_

341
342
343
        if self.tpu_communicator is not None and \
            not self.tpu_communicator.disabled:
            # TPU handles Dynamo with its own logic.
344
            return self.tpu_communicator.all_reduce(input_)
345

346
347
348
349
        if self.hpu_communicator is not None and \
            not self.hpu_communicator.disabled:
            return self.hpu_communicator.all_reduce(input_)

350
351
352
353
        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.all_reduce(input_)

354
        return torch.ops.vllm.all_reduce(input_, group_name=self.unique_name)
355

356
    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
357
358
        # always try custom allreduce first,
        # and then pynccl.
359
        ca_comm = self.ca_comm
360
361
362
363
364
        if ca_comm is not None and not ca_comm.disabled and \
            ca_comm.should_custom_ar(input_):
            out = ca_comm.custom_all_reduce(input_)
            assert out is not None
            return out
365
        pynccl_comm = self.pynccl_comm
366
367
368
369
370
371
372
373
374
375
376
377
378
        assert pynccl_comm is not None
        # TODO: pynccl should not use `stream=`
        # it can just always use the current stream.
        out = pynccl_comm.all_reduce(input_,
                                     stream=torch.cuda.current_stream())
        if out is None:
            # fall back to the default all-reduce using PyTorch.
            # this usually happens during testing.
            # when we run the model, allreduce only happens for the TP
            # group, where we always have either custom allreduce or pynccl.
            out = input_.clone()
            torch.distributed.all_reduce(out, group=self.device_group)
        return out
379
380
381
382
383
384
385
386

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
387
388
389
390
391
392

        # For TPUs, use TPU communicator.
        tpu_comm = self.tpu_communicator
        if tpu_comm is not None and not tpu_comm.disabled:
            return tpu_comm.all_gather(input_, dim)

393
394
395
396
397
        # For HPUs, use HPU communicator.
        hpu_comm = self.hpu_communicator
        if hpu_comm is not None and not hpu_comm.disabled:
            return hpu_comm.all_gather(input_, dim)

398
399
400
401
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
        input_size = input_.size()
402
403
404
405
        # NOTE: we have to use concat-style all-gather here,
        # stack-style all-gather has compatibility issues with
        # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
        output_size = (input_size[0] * world_size, ) + input_size[1:]
406
        # Allocate output tensor.
407
        output_tensor = torch.empty(output_size,
408
409
410
411
412
413
414
                                    dtype=input_.dtype,
                                    device=input_.device)
        # All-gather.
        torch.distributed.all_gather_into_tensor(output_tensor,
                                                 input_,
                                                 group=self.device_group)
        # Reshape
415
        output_tensor = output_tensor.reshape((world_size, ) + input_size)
416
417
418
419
420
421
422
423
424
425
        output_tensor = output_tensor.movedim(0, dim)
        output_tensor = output_tensor.reshape(input_size[:dim] +
                                              (world_size *
                                               input_size[dim], ) +
                                              input_size[dim + 1:])
        return output_tensor

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
426
               dim: int = -1) -> Optional[torch.Tensor]:
427
428
429
430
431
432
433
434
435
436
437
438
439
440
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")
        if dim < 0:
            # Convert negative dim to positive.
            dim += input_.dim()
441
442
443
444
        if self.xpu_communicator is not None and \
                not self.xpu_communicator.disabled:
            return self.xpu_communicator.gather(input_, self.rank_in_group,
                                                dst, dim)
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
        # Allocate output tensor.
        if self.rank_in_group == dst:
            gather_list = [torch.empty_like(input_) for _ in range(world_size)]
        else:
            gather_list = None
        # Gather.
        torch.distributed.gather(input_,
                                 gather_list,
                                 dst=self.ranks[dst],
                                 group=self.device_group)
        if self.rank_in_group == dst:
            output_tensor = torch.cat(gather_list, dim=dim)
        else:
            output_tensor = None
        return output_tensor

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

476
477
478
479
480
481
482
483
484
    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
485
486
487
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
488
489
490
491
492
493
494
495
496
497
498
499
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
    def broadcast_object_list(self,
                              obj_list: List[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

518
519
520
521
522
523
    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object list to the destination rank."""
        """NOTE: `dst` is the local rank of the destination rank."""

        assert dst < self.world_size, f"Invalid dst rank ({dst})"

524
        assert dst != self.rank_in_group, (
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size

        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive the input object list from the source rank."""
        """NOTE: `src` is the local rank of the source rank."""

        assert src < self.world_size, f"Invalid src rank ({src})"

554
        assert src != self.rank_in_group, (
555
556
557
558
559
560
561
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
562
                                           src=self.ranks[src],
563
564
565
566
567
568
569
570
571
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
572
                                             src=self.ranks[src],
573
574
575
576
577
578
579
580
581
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

582
583
    def broadcast_tensor_dict(
        self,
584
        tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
585
586
587
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
588
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
589
590
591
592
593
594
595
596
597
598
599
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

600
601
        rank_in_group = self.rank_in_group
        if rank_in_group == src:
602
603
604
605
606
607
608
609
            metadata_list: List[Tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
610
            self.broadcast_object(metadata_list, src=src)
611
612
613
614
615
616
617
618
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
619
                                                         src=self.ranks[src],
620
621
622
623
624
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
625
                                                         src=self.ranks[src],
626
627
628
629
630
631
632
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
633
            metadata_list = self.broadcast_object(None, src=src)
634
635
            tensor_dict = {}
            async_handles = []
636
            for key, value in metadata_list:
637
638
639
640
641
642
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
643
                        tensor_dict[key] = tensor
644
645
646
647
648
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
649
                            src=self.ranks[src],
650
651
652
653
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
654
655
656
657
658
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
659
                    async_handles.append(handle)
660
                    tensor_dict[key] = tensor
661
                else:
662
                    tensor_dict[key] = value
663
664
665
666
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

667
668
    def send_tensor_dict(
        self,
669
        tensor_dict: Dict[str, Union[torch.Tensor, Any]],
670
671
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
672
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
673
674
675
676
677
678
679
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the source rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

680
681
682
683
684
        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

685
686
687
688
        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
689
            dst = (self.rank_in_group + 1) % self.world_size
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: List[Tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue
705
706
707
708
709
710

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

711
712
            if tensor.is_cpu:
                # use metadata_group for CPU tensors
713
714
715
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
716
717
            else:
                # use group for GPU tensors
718
719
720
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
721
722
723
724
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.

        First the metadata list (pairs of key and either a plain value or a
        `TensorMetadata` descriptor) is received via `recv_object`; then each
        non-empty tensor payload is received point-to-point. CPU tensors use
        `cpu_group`; all other tensors use `device_group`.

        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in this group (wrapping around).

        Args:
            src: local rank (index within this group) of the sender.
            all_gather_group: optional group for "send-allgather". When set
                and a tensor's element count is divisible by that group's
                world size, only this rank's slice is received and the full
                tensor is rebuilt with `all_gather_group.all_gather`.

        Returns:
            The received dictionary, or None when torch.distributed is not
            initialized or this group has a single rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: Dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if not isinstance(value, TensorMetadata):
                # Non-tensor values are carried inside the metadata itself.
                tensor_dict[key] = value
                continue

            tensor = torch.empty(value.size,
                                 dtype=value.dtype,
                                 device=value.device)
            if tensor.numel() == 0:
                # Empty tensors carry no payload; nothing to receive.
                tensor_dict[key] = tensor
                continue

            # send-allgather: receive only a slice, then do allgather.
            use_all_gather = (all_gather_group is not None
                              and tensor.numel() % all_gather_size == 0)

            if use_all_gather:
                orig_shape = tensor.shape
                tensor = tensor.reshape(all_gather_size,
                                        -1)[all_gather_rank]

            # CPU tensors go over the CPU (metadata) group, all others over
            # the device group.
            comm_group = metadata_group if tensor.is_cpu else group
            torch.distributed.recv(tensor,
                                   src=self.ranks[src],
                                   group=comm_group)

            if use_all_gather:
                # do the allgather
                tensor = all_gather_group.all_gather(  # type: ignore
                    tensor, dim=0)
                tensor = tensor.reshape(orig_shape)

            tensor_dict[key] = tensor

        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.

        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send `tensor` to another rank of this group.

        NOTE: `dst` is the local rank of the destination rank. It defaults
        to the next rank in the group (wrapping around).

        Uses the pynccl communicator when it exists and is not disabled;
        otherwise falls back to `torch.distributed.send` on `device_group`.
        """
        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size

        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.send(tensor, dst)
        else:
            # `self.ranks` maps the group-local `dst` to its global rank.
            torch.distributed.send(tensor, self.ranks[dst], self.device_group)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor from another rank of this group.

        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in the group (wrapping around).

        A buffer of `size`/`dtype` is allocated on `self.device` and filled
        via pynccl when it exists and is not disabled, otherwise via
        `torch.distributed.recv` on `device_group`.
        """
        if src is None:
            src = (self.rank_in_group - 1) % self.world_size

        tensor = torch.empty(size, dtype=dtype, device=self.device)
        pynccl_comm = self.pynccl_comm
        if pynccl_comm is not None and not pynccl_comm.disabled:
            pynccl_comm.recv(tensor, src)
        else:
            # `self.ranks` maps the group-local `src` to its global rank.
            torch.distributed.recv(tensor, self.ranks[src], self.device_group)
        return tensor

    def destroy(self):
        """Destroy this group's torch process groups and drop communicators.

        The device and CPU process groups are destroyed if still present.
        Afterwards every handle (process groups, pynccl communicator, custom
        all-reduce communicator, message-queue broadcaster) is None.
        """
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        # Communicators are only dereferenced here, not explicitly shut down.
        self.pynccl_comm = None
        self.ca_comm = None
        self.mq_broadcaster = None


_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world `GroupCoordinator`.

    Raises AssertionError while no world group is set, i.e. before
    `init_distributed_environment` creates it or after
    `destroy_distributed_environment` clears it.
    """
    world_group = _WORLD
    assert world_group is not None, ("world group is not initialized")
    return world_group


def init_world_group(ranks: List[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Create the coordinator for the single "world" group spanning `ranks`.

    pynccl, custom all-reduce and the TPU/HPU/XPU communicators are all
    disabled for this group.
    """
    return GroupCoordinator(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_tpu_communicator=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        group_name="world",
    )


def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a `GroupCoordinator` for a model-parallel group kind.

    Args:
        group_ranks: one list of global ranks per group of this kind.
        local_rank: local rank of the calling process.
        backend: torch.distributed backend for the groups.
        use_custom_allreduce: whether to enable custom all-reduce; None
            means "use the module-wide default" set by
            `set_custom_all_reduce`.
        use_message_queue_broadcaster: whether to create a message-queue
            broadcaster for the group.
        group_name: optional name for the group (e.g. "tp", "pp").
    """
    if use_custom_allreduce is None:
        use_custom_allreduce = _ENABLE_CUSTOM_ALL_REDUCE
    return GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=use_custom_allreduce,
        use_tpu_communicator=True,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )


_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor-model-parallel `GroupCoordinator`.

    Raises AssertionError while no TP group is set (for example before
    `initialize_model_parallel` has run).
    """
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline-model-parallel `GroupCoordinator`.

    Raises AssertionError while no PP group is set (for example before
    `initialize_model_parallel` has run).
    """
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group


_KV_TRANSFER: Optional[kv_transfer.KVTransferAgent] = None


def get_kv_transfer_group() -> kv_transfer.KVTransferAgent:
    """Return the disaggregated KV-cache transfer agent.

    Raises AssertionError until `ensure_kv_transfer_initialized` has
    created the agent.
    """
    agent = _KV_TRANSFER
    assert agent is not None, (
        "disaggregated KV cache transfer parallel group is not initialized")
    return agent


@contextmanager
def graph_capture(device: torch.device):
    """Context manager to wrap code that captures a CUDA graph.

    Creates a `GraphCaptureContext` holding a new `torch.cuda.Stream` on
    `device`, enters the tensor-parallel and pipeline-parallel groups' own
    `graph_capture` contexts with it, and yields the context. The intent is
    for the capture to run on this dedicated stream rather than the default
    stream, so the captured kernels are kept apart from other kernels that
    may be launched in the background on the default stream.

    NOTE(review): any switching to/from the capture stream happens inside
    `GroupCoordinator.graph_capture`, which is not part of this function.
    """
    context = GraphCaptureContext(torch.cuda.Stream(device=device))
    # TP context is entered first and exited last; both share `context`.
    with get_tp_group().graph_capture(context), get_pp_group().graph_capture(
            context):
        yield context


logger = init_logger(__name__)


# Module-wide default for `use_custom_allreduce` when
# `init_model_parallel_group` is called with None; see `set_custom_all_reduce`.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the default custom all-reduce setting for later groups.

    Only groups created afterwards by `init_model_parallel_group` with
    `use_custom_allreduce=None` read this flag; existing groups are
    unaffected.
    """
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable


def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the vLLM world group.

    Args:
        world_size: total number of processes, forwarded to
            `torch.distributed.init_process_group`.
        rank: global rank of this process, forwarded likewise.
        distributed_init_method: init method URL ("env://" by default).
        local_rank: rank of this process within its node. When -1 it is
            read from `envs.LOCAL_RANK` for "env://" initialization and
            taken from `rank` otherwise.
        backend: torch.distributed backend ("nccl" by default); also used
            for the world group.

    If torch.distributed is already initialized, it is not re-initialized,
    so `world_size`, `rank` and `distributed_init_method` are not applied
    to it.

    Raises:
        AssertionError: if the world group already exists with a world size
            different from torch.distributed's.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups; defaults to
            the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not equal to
            tensor_model_parallel_size * pipeline_model_parallel_size.
        AssertionError: if torch.distributed is not initialized, or if the
            TP/PP groups already exist.
    """
    # Get world size and backend. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if (world_size !=
            tensor_model_parallel_size * pipeline_model_parallel_size):
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})")

    # Build the tensor model-parallel groups: blocks of consecutive ranks.
    num_tensor_model_parallel_groups: int = (world_size //
                                             tensor_model_parallel_size)
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = []
    for i in range(num_tensor_model_parallel_groups):
        ranks = list(
            range(i * tensor_model_parallel_size,
                  (i + 1) * tensor_model_parallel_size))
        group_ranks.append(ranks)

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups: ranks strided by the number
    # of pipeline groups.
    num_pipeline_model_parallel_groups: int = (world_size //
                                               pipeline_model_parallel_size)
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = []
    for i in range(num_pipeline_model_parallel_groups):
        ranks = list(range(i, world_size, num_pipeline_model_parallel_groups))
        group_ranks.append(ranks)
    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_custom_allreduce=False,
                                    group_name="pp")


def ensure_kv_transfer_initialized(vllm_config: "VllmConfig") -> None:
    """
    Initialize KV cache transfer parallel group.

    No-op when `vllm_config.kv_transfer_config` is None, when that config's
    `need_kv_parallel_group` is falsy, or when the agent already exists.
    """

    global _KV_TRANSFER

    kv_config = vllm_config.kv_transfer_config
    if kv_config is None:
        return

    # Plain `and` short-circuits instead of building a list for `all()`.
    if kv_config.need_kv_parallel_group and _KV_TRANSFER is None:
        _KV_TRANSFER = kv_transfer.KVTransferAgent(
            rank=get_world_group().rank,
            local_rank=get_world_group().local_rank,
            config=vllm_config)


def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Helper to initialize model parallel groups if they are not initialized,
    or ensure tensor-parallel and pipeline-parallel sizes are equal to expected
    values if the model parallel groups are initialized.

    `backend` defaults to the backend of the world group's device group.
    Size mismatches are reported with AssertionError.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)
    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (
        get_tensor_model_parallel_world_size() == tensor_model_parallel_size
    ), ("tensor parallel group already initialized, but of unexpected size: "
        f"{get_tensor_model_parallel_world_size()=} vs. "
        f"{tensor_model_parallel_size=}")
    pp_world_size = get_pp_group().world_size
    assert (pp_world_size == pipeline_model_parallel_size), (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


def model_parallel_is_initialized():
    """Check if tensor and pipeline parallel groups are initialized."""
    return (_TP is not None and _PP is not None)


_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until the `with` block exits.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or if there is no tp
            group to restore afterwards.
    """
    global _TP_STATE_PATCHED
    global _TP
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group *before* marking the state as patched: if
    # `get_tp_group()` raises, the flag must not be left set, otherwise every
    # later call would fail with "already patched".
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


def get_tensor_model_parallel_world_size():
    """Return world size for the tensor model parallel group."""
    return get_tp_group().world_size


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    return get_tp_group().rank_in_group


def destroy_model_parallel():
    """Destroy the TP and PP groups (if present) and reset them to None."""
    global _TP
    if _TP is not None:
        _TP.destroy()
    _TP = None

    global _PP
    if _PP is not None:
        _PP.destroy()
    _PP = None


def destroy_distributed_environment():
    """Destroy the world group (if present) and reset it to None, then
    destroy torch's default process group if it is initialized."""
    global _WORLD
    if _WORLD is not None:
        _WORLD.destroy()
    _WORLD = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down distributed state and release cached memory.

    Destroys the model-parallel groups and the distributed environment,
    optionally shuts down Ray, runs the garbage collector and, on non-CPU
    platforms, calls `torch.cuda.empty_cache()`.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Defensive second destroy of the default process group. AssertionError
    # is ignored; presumably that is what torch raises when no default group
    # exists (TODO confirm for the installed torch version).
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()
    from vllm.platforms import current_platform
    if not current_platform.is_cpu():
        torch.cuda.empty_cache()


def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    The source rank creates a shared-memory segment holding a magic message
    and broadcasts its name; every other rank tries to open that segment and
    compare the message. The per-rank 0/1 flags are then summed across ranks.
    Errors during the probe are logged and the rank is reported as False.

    Args:
        pg: a non-NCCL torch ProcessGroup or a StatelessProcessGroup.
        source_rank: rank (within `pg`) that creates the segment.

    Returns:
        One bool per rank of `pg`: True where that rank read the magic
        message from the source rank's segment.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # Wait for all ranks before the source rank unlinks the segment below.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    # Sum every rank's 0/1 flags so all ranks see the complete result.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]