# SPDX-License-Identifier: Apache-2.0

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- run any code that needs the distributed environment.

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
import contextlib
import gc
import pickle
import weakref
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from multiprocessing import shared_memory
from typing import Any, Callable, Optional, Union
from unittest.mock import patch

import torch
import torch.distributed
from torch.distributed import Backend, ProcessGroup

import vllm.envs as envs
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
from vllm.distributed.utils import StatelessProcessGroup
from vllm.logger import init_logger
from vllm.utils import (direct_register_custom_op, get_distributed_init_method,
                        resolve_obj_by_qualname, supports_custom_op)


@dataclass
class GraphCaptureContext:
    """State shared across one graph-capture session.

    Attributes:
        stream: The CUDA stream that `GroupCoordinator.graph_capture`
            switches to while capturing.
    """
    stream: torch.cuda.Stream

# Picklable stand-in for a tensor when exchanging tensor dictionaries:
# `device` is the device *type* string (e.g. "cuda", no index), `dtype` is the
# tensor dtype and `size` the tensor's `size()`. Receivers allocate an empty
# tensor from these fields and then receive the payload into it.
TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

55
56

def _split_tensor_dict(
57
58
    tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
59
60
61
62
63
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
64
65
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
66
67
68
69
70
71
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
72
            device = value.device.type
73
            metadata_list.append(
74
                (key, TensorMetadata(device, value.dtype, value.size())))
75
76
            tensor_list.append(value)
        else:
77
            metadata_list.append((key, value))
78
79
80
    return metadata_list, tensor_list


81
_group_name_counter: dict[str, int] = {}
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96


def _get_unique_name(name: str) -> str:
    """Return `name` suffixed with a per-name running index.

    Successive calls with the same base name yield increasing indices:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    index = _group_name_counter.setdefault(name, 0)
    _group_name_counter[name] = index + 1
    return f"{name}:{index}"


97
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
98
99
100


def _register_group(group: "GroupCoordinator") -> None:
    """Record a weak reference to `group` under its `unique_name`.

    A later registration under the same name replaces the earlier entry.
    """
    _groups[group.unique_name] = weakref.ref(group)


def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Out-of-place all-reduce over the coordinator registered as `group_name`.

    Real implementation behind the `all_reduce` custom op. The coordinator is
    looked up by name because the op cannot take the coordinator object.

    Raises:
        AssertionError: if no group was registered under `group_name`.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_reduce_out_place(tensor)


def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation of the `all_reduce` custom op.

    Only the output's shape, dtype and device matter, so an uninitialized
    tensor like `tensor` is returned and no communication happens.
    """
    return torch.empty_like(tensor)

def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Out-of-place reduce-scatter along `dim` over the group `group_name`.

    Real implementation behind the `reduce_scatter` custom op. `world_size`
    is not used here; it is part of the signature so the fake implementation
    can compute the output shape.

    Raises:
        AssertionError: if no group was registered under `group_name`.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._reduce_scatter_out_place(tensor, dim)


def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Fake implementation of `reduce_scatter`.

    Returns an uninitialized tensor whose size along `dim` is the input size
    floor-divided by `world_size`, with the input's dtype and device.
    """
    out_shape = [*tensor.shape]
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Out-of-place all-gather along `dim` over the group `group_name`.

    Real implementation behind the `all_gather` custom op. `world_size` is
    not used here; it is part of the signature so the fake implementation
    can compute the output shape.

    Raises:
        AssertionError: if no group was registered under `group_name`.
        ValueError: if the registered coordinator no longer exists.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_gather_out_place(tensor, dim)


def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Fake implementation of `all_gather`.

    Returns an uninitialized tensor that is `world_size` times larger along
    `dim`, with the input's dtype and device.
    """
    out_shape = [*tensor.shape]
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


# Register the collectives as custom ops so `GroupCoordinator` can invoke them
# as `torch.ops.vllm.<op_name>`, passing the group by name. Each op gets a fake
# implementation that only computes the output's shape/dtype/device.
if supports_custom_op():
    from vllm.platforms import current_platform
    direct_register_custom_op(
        op_name="all_reduce",
        op_func=all_reduce,
        mutates_args=[],
        fake_impl=all_reduce_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="reduce_scatter",
        op_func=reduce_scatter,
        mutates_args=[],
        fake_impl=reduce_scatter_fake,
        dispatch_key=current_platform.dispatch_key,
    )

    direct_register_custom_op(
        op_name="all_gather",
        op_func=all_gather,
        mutates_args=[],
        fake_impl=all_gather_fake,
        dispatch_key=current_platform.dispatch_key,
    )

class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.

    Each member process holds two process groups over the same ranks:
    `device_group`, created with the requested backend, and `cpu_group`,
    created with the `gloo` backend. The CPU group is used for CPU-side
    coordination: object broadcast, point-to-point object transfer,
    tensor-dict metadata, and barriers.
    """

    # available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_device_communicator: bool  # whether to use device communicator
    device_communicator: DeviceCommunicatorBase  # device communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create this rank's process groups and communicators.

        Args:
            group_ranks: Lists of global ranks. A (device, CPU) pair of
                process groups is created for every list. This process keeps
                the pair whose list contains its global rank.
            local_rank: Local rank, used as the device index.
            torch_distributed_backend: Backend used for the device group.
            use_device_communicator: Whether to build the platform's device
                communicator. It is only built when the group has more than
                one rank.
            use_message_queue_broadcaster: Whether to create a message-queue
                broadcaster used by `broadcast_object`. It is only created
                when the group has more than one rank.
            group_name: Base name for `unique_name`; defaults to "anonymous".
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        # Make this coordinator reachable by name from the module-level
        # custom-op implementations.
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        # Every process creates every group, including groups it is not a
        # member of, because `torch.distributed.new_group` must be entered by
        # all processes. Only the groups containing this rank are kept.
        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        # This rank must appear in at least one entry of `group_ranks`.
        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator

        # Stays None for single-rank groups and when
        # `use_device_communicator` is False.
        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): the positional args are presumably the maximum
            # chunk size in bytes (4 MiB) and the number of chunks; confirm
            # against `MessageQueue.create_from_process_group`.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

        # On these platforms the public collectives dispatch through the
        # registered `torch.ops.vllm.*` custom ops; otherwise the out-of-place
        # implementations are called directly.
        self.use_custom_op_call = (current_platform.is_cuda_alike()
                                   or current_platform.is_tpu())

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager that sets up CUDA graph capture for this group.

        Switches to `graph_capture_context.stream`, or to a fresh CUDA stream
        when no context is given. If the device communicator has a custom
        all-reduce communicator (`ca_comm`), its `capture()` context is
        entered as well. Yields the `GraphCaptureContext` in use.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator)
        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, and then look up the group coordinator from
        the group name, dispatch the all-reduce operation to the group
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        """Out-of-place all-reduce via the device communicator."""
        return self.device_communicator.all_reduce(input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """All-gather `input_` along `dim` across the group.

        With a single-rank group the input is returned unchanged.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            return self._all_gather_out_place(input_, dim)

    def _all_gather_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
        """Out-of-place all-gather via the device communicator."""
        return self.device_communicator.all_gather(input_, dim)

    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        """Reduce-scatter `input_` along `dim` across the group.

        With a single-rank group the input is returned unchanged.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.reduce_scatter(input_,
                                                 dim,
                                                 world_size,
                                                 group_name=self.unique_name)
        else:
            return self._reduce_scatter_out_place(input_, dim)

    def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                  dim: int) -> torch.Tensor:
        """Out-of-place reduce-scatter via the device communicator."""
        return self.device_communicator.reduce_scatter(input_, dim)

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.

        Gathers `input_` along `dim` onto rank `dst` via the device
        communicator. With a single-rank group the input is returned
        unchanged. The result on non-destination ranks is determined by the
        device communicator (the return type allows None).
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        return self.device_communicator.gather(input_, dst, dim)

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor in place over the device group.
        NOTE: `src` is the local rank of the source rank.

        Returns `input_`, which on non-source ranks now holds the source's data.
        """
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.

        The source passes `obj`; every other rank passes nothing and gets the
        broadcast object back. Uses the message-queue broadcaster when one
        exists (which only supports src=0), otherwise
        `torch.distributed.broadcast_object_list` over the CPU group.
        """
        # A negative `src` would never equal `rank_in_group`, so every rank
        # would wait to receive and the broadcast would hang.
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: list[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list in place.
        NOTE: `src` is the local rank of the source rank.
        NOTE: the `group` argument is ignored; the device group is always used.
        """
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.
        NOTE: `dst` is the local rank of the destination rank.

        The object is pickled and sent over the CPU group as two messages:
        its byte length, then the bytes. Pair with `recv_object`.
        """
        # A negative `dst` would index `ranks` from the end and could bypass
        # the self-send check below.
        assert 0 <= dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well
        object_tensor = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object from the source rank.
        NOTE: `src` is the local rank of the source rank.

        Counterpart of `send_object`: receives the byte length, then the
        pickled bytes, over the CPU group, and unpickles them.
        """
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        # NOTE: `pickle.loads` can execute arbitrary code embedded in the
        # payload. This assumes every peer rank is a trusted process of the
        # same job; never route untrusted data through here.
        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.

        The source passes `tensor_dict`; every other rank passes nothing and
        receives the reconstructed dictionary.
        - Non-tensor values and tensor metadata are sent with
          `broadcast_object`.
        - Tensor payloads are then broadcast asynchronously: CPU tensors over
          the CPU group, all others over the device group.
        - Empty tensors are not transmitted; receivers allocate them from
          their metadata.

        NOTE: the `group` and `metadata_group` arguments are ignored; this
        coordinator's device and CPU groups are always used.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        # A negative `src` would never equal `rank_in_group`, so every rank
        # would take the receive path and hang.
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object` serializes & deserializes it on the CPU.
            # Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                # use metadata_group for CPU tensors, group for GPU tensors
                handle = torch.distributed.broadcast(
                    tensor,
                    src=self.ranks[src],
                    group=metadata_group if tensor.is_cpu else group,
                    async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    # use metadata_group for CPU tensors, group for GPU tensors
                    handle = torch.distributed.broadcast(
                        tensor,
                        src=self.ranks[src],
                        group=metadata_group if tensor.is_cpu else group,
                        async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank.

        `dst` defaults to the next rank in the group. Metadata is sent with
        `send_object`; each non-empty tensor is then sent over the CPU group
        (CPU tensors) or the device group (others).

        If `all_gather_group` is given, a tensor whose element count is
        divisible by that group's size is viewed as equal flat slices, and
        only this rank's slice is sent. `recv_tensor_dict` reassembles it with
        an all-gather over the receiver's `all_gather_group`.

        Returns `tensor_dict` unchanged when there is nothing to do
        (distributed not initialized or a single-rank group), else None.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert 0 <= dst < self.world_size, f"Invalid dst rank ({dst})"

        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object` serializes & deserializes it on the CPU.
        # Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            # use metadata_group for CPU tensors, group for GPU tensors
            torch.distributed.send(
                tensor,
                dst=self.ranks[dst],
                group=metadata_group if tensor.is_cpu else group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.

        Counterpart of `send_tensor_dict`. `src` defaults to the previous rank
        in the group. When `all_gather_group` is given, tensors whose element
        count is divisible by its size are received as this rank's slice and
        then all-gathered over `all_gather_group` and reshaped back.

        Returns None when distributed is not initialized or the group has a
        single rank.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert 0 <= src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors (the sender skips them too).
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                # use metadata_group for CPU tensors, group for GPU tensors
                torch.distributed.recv(
                    tensor,
                    src=self.ranks[src],
                    group=metadata_group if tensor.is_cpu else group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.

        Delegates to the device communicator.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.
        NOTE: `src` is the local rank of the source rank.

        Delegates to the device communicator.
        """
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
        """Tear down the process groups, device communicator and broadcaster.

        The process-group attributes are reset to None. The device
        communicator's own `destroy()` is called, but the attribute is kept.
        """
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None

    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let the device communicator (if any) prepare buffers for `model`."""
        if self.device_communicator is not None:
            self.device_communicator.prepare_communication_buffer_for_model(
                model)

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the device communicator's `dispatch`.

        Without a device communicator, the inputs are returned unchanged.
        """
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states,
                                                     router_logits)
        else:
            return hidden_states, router_logits

    def combine(self, hidden_states) -> torch.Tensor:
        """Delegate to the device communicator's `combine`.

        Without a device communicator, `hidden_states` is returned unchanged.
        """
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states)
        else:
            return hidden_states

_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator.

    Raises AssertionError if the world group has not been initialized.
    """
    world = _WORLD
    assert world is not None, ("world group is not initialized")
    return world


811
def init_world_group(ranks: list[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Build the coordinator for the single group spanning `ranks`.

    The group is named "world" and is created without a device communicator.

    Args:
        ranks: all ranks that belong to the world group.
        local_rank: local rank of the calling process.
        backend: torch.distributed backend for the group.
    """
    single_group = [ranks]
    return GroupCoordinator(
        group_name="world",
        group_ranks=single_group,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
    )


822
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a coordinator for one family of model parallel groups.

    The coordinator is always built with `use_device_communicator=True`.

    Args:
        group_ranks: list of rank lists, one per group in this family.
        local_rank: local rank of the calling process.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: forwarded to GroupCoordinator.
        group_name: name given to the coordinator (e.g. "tp", "pp").
    """
    coordinator_kwargs: dict[str, Any] = dict(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return GroupCoordinator(**coordinator_kwargs)


840
841
842
843
844
845
846
847
848
849
850
851
852
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group coordinator.

    Raises:
        AssertionError: if the tensor model parallel group has not been
            initialized (see `initialize_model_parallel`).
    """
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

_PP: Optional[GroupCoordinator] = None

853
854
855
856
857
858
859
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group coordinator.

    Raises:
        AssertionError: if the data parallel group has not been initialized
            (see `initialize_model_parallel`).
    """
    dp_group = _DP
    assert dp_group is not None, ("data parallel group is not initialized")
    return dp_group

860

861
862
863
864
865
866
867
868
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel group coordinator.

    Raises:
        AssertionError: if the expert parallel group has not been
            initialized (see `initialize_model_parallel`).
    """
    ep_group = _EP
    assert ep_group is not None, ("expert parallel group is not initialized")
    return ep_group


869
870
871
872
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group coordinator.

    Raises:
        AssertionError: if the pipeline model parallel group has not been
            initialized (see `initialize_model_parallel`).
    """
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
877
878


879
@contextmanager
def graph_capture(device: torch.device):
    """
    Context manager that should surround the code capturing a CUDA graph.

    It creates a new `torch.cuda.Stream` on `device`, wraps it in a
    `GraphCaptureContext`, and enters the tensor parallel group's
    `graph_capture` context and then the pipeline parallel group's, both
    with that context. It then yields the context, which currently holds
    only the capture stream.

    Purpose: ensure that some operations run after the graph is captured
    and before it is replayed. It also keeps the capture on a separate
    stream from the default stream, so the kernels being captured can be
    distinguished from kernels launched in the background on the default
    stream.

    NOTE(review): switching to and from the capture stream happens inside
    the groups' `graph_capture` implementations, not in this function.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    # Enter TP first, then PP; this matches the left-to-right order (and
    # lazy evaluation of the PP lookup) of a combined `with` statement.
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

899

900
# Module-level logger, used e.g. by init_distributed_environment below.
logger = init_logger(__name__)
901

902
# Module-level flag, enabled by default; overwritten by set_custom_all_reduce.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Record whether custom all-reduce should be enabled.

    Args:
        enable: new value stored in the module-level
            `_ENABLE_CUSTOM_ALL_REDUCE` flag.
    """
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
908

Zhuohan Li's avatar
Zhuohan Li committed
909

910
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the vLLM world group.

    Args:
        world_size: number of processes, passed to
            `torch.distributed.init_process_group`.
        rank: rank of this process, passed to `init_process_group`.
        distributed_init_method: init method URL passed to
            `init_process_group`.
        local_rank: local rank stored on the world group. When -1, it is
            taken from `envs.LOCAL_RANK` for "env://" and otherwise set to
            `rank`.
        backend: torch.distributed backend for the default (WORLD) group.

    When the current vLLM config has data_parallel_size > 1, the arguments
    are rewritten so that all data parallel replicas join one process group:
    - `rank` becomes `data_parallel_rank * world_size + rank`;
    - `world_size` becomes `world_size_across_dp`;
    - the init method points at the DP master ip with the next DP init port.

    NOTE(review): the DP offset multiplies by the incoming `world_size`, so
    callers presumably pass a real size (not -1) when DP > 1. Confirm.

    Raises:
        AssertionError: if the world group already exists with a different
            world size.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    dp_enabled = (vllm_config is not None
                  and vllm_config.parallel_config.data_parallel_size > 1)
    if dp_enabled:
        pc = vllm_config.parallel_config
        # Offset this process's rank by its DP replica index and widen the
        # world to cover every DP replica.
        rank += pc.data_parallel_rank * world_size
        world_size = pc.world_size_across_dp
        distributed_init_method = get_distributed_init_method(
            pc.data_parallel_master_ip, pc.get_next_dp_init_port())
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)

    global _WORLD
    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
        return
    all_world_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_world_ranks, local_rank, backend)
963
964


Zhuohan Li's avatar
Zhuohan Li committed
965
966
967
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. Defaults to
            the backend of the world group's device group.

    Besides the tensor (TP) and pipeline (PP) model-parallel groups, this
    also builds the data parallel (DP) and expert parallel (EP) groups:
    - The DP size is read from the current vLLM config (1 when there is no
      config).
    - Each EP group contains the DP x TP ranks that share a pipeline stage.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7, a data parallel
    size of 1, and we use 2 GPUs to parallelize the model tensor and 4 GPUs
    to parallelize the model pipeline. The present function will create
    4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not divisible by
            data_parallel_size * pipeline_model_parallel_size *
            tensor_model_parallel_size.
        AssertionError: if torch.distributed is not initialized or any of
            the groups already exists.
    """
    global _TP, _PP, _DP, _EP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # Fail early with a readable message instead of torch's opaque
    # "shape ... is invalid for input of size ..." reshape error (also a
    # RuntimeError, so callers see the same exception type).
    ranks_per_replica = (data_parallel_size * pipeline_model_parallel_size *
                         tensor_model_parallel_size)
    if ranks_per_replica <= 0 or world_size % ranks_per_replica != 0:
        raise RuntimeError(
            f"world_size={world_size} is not divisible by "
            f"data_parallel_size={data_parallel_size} * "
            f"pipeline_model_parallel_size={pipeline_model_parallel_size} * "
            f"tensor_model_parallel_size={tensor_model_parallel_size}")

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D so that each row is one group.
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    def _rank_groups(layout: torch.Tensor,
                     group_size: int) -> list[list[int]]:
        # Split `layout` into consecutive runs of `group_size` ranks.
        return [
            x.tolist() for x in layout.reshape(-1, group_size).unbind(0)
        ]

    local_rank = get_world_group().local_rank

    # Build the tensor model-parallel groups.
    assert _TP is None, ("tensor model parallel group is already initialized")
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        _rank_groups(all_ranks, tensor_model_parallel_size),
        local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the pipeline model-parallel groups.
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    _PP = init_model_parallel_group(
        _rank_groups(all_ranks.transpose(2, 3), pipeline_model_parallel_size),
        local_rank,
        backend,
        group_name="pp")

    # Build the data parallel groups.
    assert _DP is None, ("data parallel group is already initialized")
    _DP = init_model_parallel_group(
        _rank_groups(all_ranks.transpose(1, 3), data_parallel_size),
        local_rank,
        backend,
        group_name="dp")

    # Build the expert parallel groups: DP x TP ranks per pipeline stage.
    assert _EP is None, ("expert parallel group is already initialized")
    _EP = init_model_parallel_group(
        _rank_groups(all_ranks.transpose(1, 2),
                     data_parallel_size * tensor_model_parallel_size),
        local_rank,
        backend,
        group_name="ep")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)
1069

Zhuohan Li's avatar
Zhuohan Li committed
1070

1071
1072
1073
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Create the model parallel groups, or check the existing ones.

    If the TP/PP groups do not exist yet, this calls
    `initialize_model_parallel` with the given sizes. Otherwise it asserts
    that the existing tensor-parallel and pipeline-parallel world sizes
    equal the requested ones.

    Args:
        tensor_model_parallel_size: expected TP world size.
        pipeline_model_parallel_size: expected PP world size.
        backend: backend for newly created groups. Defaults to the world
            group's device group backend.

    Raises:
        AssertionError: if the groups already exist with different sizes.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if model_parallel_is_initialized():
        assert (get_tensor_model_parallel_world_size() ==
                tensor_model_parallel_size), (
                    "tensor parallel group already initialized, but of "
                    "unexpected size: "
                    f"{get_tensor_model_parallel_world_size()=} vs. "
                    f"{tensor_model_parallel_size=}")

        pp_world_size = get_pp_group().world_size
        assert (pp_world_size == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected "
            "size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)


1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost model
    agnostic. However, emerging new communication libraries like MoE
    all2all (DeepEP) usually allocate the communication buffer based on
    the model shape for optimal performance.

    Each initialized group (TP, PP, DP, EP, in that order) is asked to
    prepare its buffers for `model`; groups that are None are skipped.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is None:
            continue
        group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1116
def model_parallel_is_initialized():
    """Return True iff both the TP and PP groups have been created.

    Only `_TP` and `_PP` are checked; the DP and EP groups are not.
    """
    return all(group is not None for group in (_TP, _PP))
1119
1120


1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers. Patches do
    not nest: a second call while one is active raises.

    The current TP group is looked up before the patched flag is set. A
    failed lookup (TP group not initialized) therefore leaves
    `_TP_STATE_PATCHED` False instead of stuck at True, which would make
    every later call fail with "already patched".

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or if the TP group
            has not been initialized.
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Resolve the group being replaced first: get_tp_group() asserts when the
    # TP group is missing, and that must not leave the patched flag set.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1149
1150
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in this process's tensor parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1152
1153
1154
1155


def get_tensor_model_parallel_rank():
    """Return this process's rank within its tensor parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1157
1158
1159


def destroy_model_parallel():
    """Destroy the TP, PP, DP and EP groups and reset them to None.

    Groups are handled in that order. Each one is destroyed (if set) and
    then cleared before the next one is touched.
    """
    global _TP, _PP, _DP, _EP

    def _teardown(group):
        # Destroy a coordinator if one is present (truthiness check).
        if group:
            group.destroy()

    _teardown(_TP)
    _TP = None
    _teardown(_PP)
    _PP = None
    _teardown(_DP)
    _DP = None
    _teardown(_EP)
    _EP = None

1182
1183
1184
1185
1186
1187
1188
1189

def destroy_distributed_environment():
    """Destroy the world group, then torch's default process group.

    The world group (if set) is destroyed and reset to None. After that,
    torch.distributed's default group is destroyed when it is initialized.
    """
    global _WORLD
    world_group = _WORLD
    if world_group:
        world_group.destroy()
    _WORLD = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
1190
1191


1192
1193
1194
1195
1196
1197
1198
1199
1200
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM's distributed state and release cached memory.

    Steps, in order:
    - destroy the model parallel groups and the distributed environment;
    - optionally shut down Ray;
    - run the garbage collector;
    - empty the current platform's cache (when it provides one);
    - empty torch's host-side cache.

    Args:
        shutdown_ray: also call `ray.shutdown()` when True.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Extra best-effort teardown of the default group; an AssertionError
    # (e.g. when no default group is left) is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()

    from vllm.platforms import current_platform
    platform_empty_cache = current_platform.empty_cache
    if platform_empty_cache is not None:
        platform_empty_cache()

    try:
        torch._C._host_emptyCache()
    except AttributeError:
        # Older torch builds do not expose this hook.
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")
1210
1211


1212
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> list[bool]:
    """
    Collective operation: report, for every rank in `pg`, whether it runs on
    the same node as `source_rank`.

    The source rank creates a shared memory segment holding a magic message
    and broadcasts the segment's name. Every other rank attaches to that
    segment by name and checks for the message. A rank counts as being on
    the same node when it can read the message, i.e. when it shares the
    source's memory system. Every rank in `pg` must call this.

    Args:
        pg: a non-NCCL torch ProcessGroup, or a StatelessProcessGroup.
        source_rank: group-local rank that creates the segment.

    Returns:
        One bool per group-local rank.
    """
    use_torch_pg = isinstance(pg, ProcessGroup)
    if use_torch_pg:
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)
        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # One slot per rank; this process only ever sets its own slot to 1.
    same_node_flags = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    segment = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment and publish its name
                segment = shared_memory.SharedMemory(create=True, size=128)
                segment.buf[:len(magic_message)] = magic_message
                if use_torch_pg:
                    torch.distributed.broadcast_object_list(
                        [segment.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(segment.name, src=source_rank)
                same_node_flags[rank] = 1
            else:
                # receive the segment name, then try to attach to it
                if use_torch_pg:
                    received = [None]
                    torch.distributed.broadcast_object_list(
                        received, src=ranks[source_rank], group=pg)
                    segment_name = received[0]
                else:
                    segment_name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    segment = shared_memory.SharedMemory(name=segment_name)
                if segment.buf[:len(magic_message)] == magic_message:
                    same_node_flags[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if segment:
            segment.close()

    # Wait for every rank to finish with the segment before unlinking it.
    if use_torch_pg:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and segment:
            segment.unlink()

    if use_torch_pg:
        torch.distributed.all_reduce(same_node_flags, group=pg)
        totals = same_node_flags
    else:
        totals = torch.zeros_like(same_node_flags)
        for src in range(world_size):
            totals += pg.broadcast_obj(same_node_flags, src=src)

    return [flag == 1 for flag in totals.tolist()]