# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25
import contextlib
26
import gc
27
import pickle
28
import weakref
29
30
31
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
32
from multiprocessing import shared_memory
33
from typing import Any, Callable, Optional, Union
34
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
35
36

import torch
37
import torch.distributed
38
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
42
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
43
from vllm.distributed.utils import StatelessProcessGroup
44
from vllm.logger import init_logger
45
46
from vllm.utils import (direct_register_custom_op, get_distributed_init_method,
                        resolve_obj_by_qualname, supports_custom_op)
47
48


49
50
51
@dataclass
class GraphCaptureContext:
    """Holder yielded by `GroupCoordinator.graph_capture`.

    Carries the CUDA stream on which the graph capture body runs.
    """

    stream: torch.cuda.Stream
# Picklable stand-in for a tensor when a dict of tensors is broadcast/sent:
# the device *type* (e.g. "cuda", without an index), the dtype and the size.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
def _split_tensor_dict(
58
59
    tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
60
61
62
63
64
    """Split the tensor dictionary into two parts:
    1. A list of (key, value) pairs. If the value is a tensor, it is replaced
         by its metadata.
    2. A list of tensors.
    """
65
66
    metadata_list: list[tuple[str, Any]] = []
    tensor_list: list[torch.Tensor] = []
67
68
69
70
71
72
    for key, value in tensor_dict.items():
        if isinstance(value, torch.Tensor):
            # Note: we cannot use `value.device` here,
            # because it contains not only the device type but also the device
            # index (e.g. "cuda:0"). We only need the device type.
            # receiving side will set the device index.
73
            device = value.device.type
74
            metadata_list.append(
75
                (key, TensorMetadata(device, value.dtype, value.size())))
76
77
            tensor_list.append(value)
        else:
78
            metadata_list.append((key, value))
79
80
81
    return metadata_list, tensor_list


82
_group_name_counter: dict[str, int] = {}
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


98
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
99
100
101


def _register_group(group: "GroupCoordinator") -> None:
102
    _groups[group.unique_name] = weakref.ref(group)
103
104


105
106
107
108
109
110
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Custom-op entry point: out-of-place all-reduce on a named group.

    Raises AssertionError if `group_name` was never registered and
    ValueError if the registered coordinator is no longer alive.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    if (coordinator := _groups[group_name]()) is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return coordinator._all_reduce_out_place(tensor)
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake implementation of `all_reduce`: an uninitialized tensor with the
    same shape, dtype and device as `tensor`."""
    result = torch.empty_like(tensor)
    return result
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Custom-op entry point: out-of-place reduce-scatter on a named group.

    `world_size` is unused here; it is part of the signature because the
    fake implementation (`reduce_scatter_fake`) needs it to compute the
    output shape.

    Raises AssertionError if `group_name` was never registered and
    ValueError if the registered coordinator is no longer alive.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._reduce_scatter_out_place(tensor, dim)
def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Fake implementation of `reduce_scatter`: an uninitialized tensor whose
    `dim` extent is divided (floor) by `world_size`."""
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Custom-op entry point: out-of-place all-gather on a named group.

    `world_size` is unused here; it is part of the signature because the
    fake implementation (`all_gather_fake`) needs it to compute the output
    shape.

    Raises AssertionError if `group_name` was never registered and
    ValueError if the registered coordinator is no longer alive.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    group = _groups[group_name]()
    if group is None:
        raise ValueError(f"Group {group_name} is destroyed.")
    return group._all_gather_out_place(tensor, dim)
def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Fake implementation of `all_gather`: an uninitialized tensor whose
    `dim` extent is multiplied by `world_size`."""
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)
149
if supports_custom_op():
    from vllm.platforms import current_platform

    # Register each collective as `torch.ops.vllm.<name>`, paired with its
    # fake implementation (which only produces an output of the right
    # shape). None of the ops mutates its arguments.
    for _op_name, _op_func, _op_fake in (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            mutates_args=[],
            fake_impl=_op_fake,
            dispatch_key=current_platform.dispatch_key,
        )
    # Keep the loop variables out of the module namespace.
    del _op_name, _op_func, _op_fake
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.
    """

    # available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_device_communicator: bool  # whether to use device communicator
    device_communicator: DeviceCommunicatorBase  # device communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups and keep the one containing this rank.

        Args:
            group_ranks: one list of global ranks per group. A device group
                and a gloo CPU group are created for every entry on every
                process (torch requires all processes to take part in each
                `new_group` call); this process keeps the pair whose rank
                list contains its global rank.
            local_rank: local rank used to select the device.
            torch_distributed_backend: backend for the device group.
            use_device_communicator: create the platform's device
                communicator (only when the group has more than one rank).
            use_message_queue_broadcaster: create a shared-memory
                `MessageQueue` used by `broadcast_object` (only when the
                group has more than one rank).
            group_name: prefix of the unique group name; "anonymous" if
                not given.
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        # This process must belong to one of the requested groups.
        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator

        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): `1 << 22` and `6` are passed straight through to
            # MessageQueue (presumably buffer size and chunk count) — confirm
            # in shm_broadcast.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

        # Only these platforms go through the `torch.ops.vllm.*` custom ops;
        # elsewhere the out-of-place helpers are called directly.
        self.use_custom_op_call = (current_platform.is_cuda_alike()
                                   or current_platform.is_tpu())

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager for CUDA graph capture.

        Runs the body on `graph_capture_context.stream` (a new
        `torch.cuda.Stream` when no context is given) after making that
        stream wait for the current stream. If the CUDA communicator has a
        custom all-reduce communicator, its `capture()` context is entered
        as well. Yields the `GraphCaptureContext`.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator)
        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
         group name as a string, and then look up the group coordinator from
         the group name, dispatch the all-reduce operation to the group
         coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        """Delegate the all-reduce to the device communicator."""
        return self.device_communicator.all_reduce(input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """All-gather `input_` across the group along `dim`.

        Returns `input_` unchanged for a single-rank group.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            return self._all_gather_out_place(input_, dim)

    def _all_gather_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
        """Delegate the all-gather to the device communicator."""
        return self.device_communicator.all_gather(input_, dim)

    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        """Reduce-scatter `input_` across the group along `dim`.

        Returns `input_` unchanged for a single-rank group.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.reduce_scatter(input_,
                                                 dim,
                                                 world_size,
                                                 group_name=self.unique_name)
        else:
            return self._reduce_scatter_out_place(input_, dim)

    def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                  dim: int) -> torch.Tensor:
        """Delegate the reduce-scatter to the device communicator."""
        return self.device_communicator.reduce_scatter(input_, dim)

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        return self.device_communicator.gather(input_, dst, dim)

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor in place over the device group.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.

        Uses the shared-memory message queue when one was created (then only
        `src=0` is supported); otherwise `broadcast_object_list` over the
        CPU group.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: list[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list in place.
        NOTE: `src` is the local rank of the source rank.
        NOTE: the `group` argument is ignored; the device group is used.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send the input object to the destination rank.
        NOTE: `dst` is the local rank of the destination rank.

        The object is pickled and sent over the CPU group as a size tensor
        followed by a uint8 payload tensor (see `recv_object`).
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well. The pickle is
        # copied into a bytearray because `torch.frombuffer` warns when given
        # a non-writable buffer such as `bytes`.
        object_tensor = torch.frombuffer(bytearray(pickle.dumps(obj)),
                                         dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object sent by `send_object` from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        # NOTE(review): unpickling executes arbitrary code; this assumes every
        # peer in the process group is trusted.
        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.
        NOTE: the `group` and `metadata_group` arguments are ignored; the
        device group carries device tensors and the CPU group carries the
        metadata and CPU tensors.

        On `src`, `tensor_dict` must be a dict. Other ranks pass nothing and
        get back a rebuilt dict with freshly allocated tensors.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        if rank_in_group == src:
            metadata_list: list[tuple[Any, Any]] = []
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory.
            # `broadcast_object_list` has serialization & deserialization,
            # all happening on CPU. Therefore, we can use the CPU group.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank; it defaults to
        the next rank in the group.

        When `all_gather_group` is given, every non-empty tensor whose numel
        is divisible by that group's size is sent only as this rank's slice;
        `recv_tensor_dict` reassembles it with an all-gather. Returns None
        after sending, or `tensor_dict` unchanged when communication is
        bypassed.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        metadata_list: list[tuple[Any, Any]] = []
        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `send_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank; it defaults to the
        previous rank in the group.

        Must mirror the sender's `all_gather_group` so that sliced tensors
        are all-gathered back to their original shape. Returns None when
        communication is bypassed.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors (the sender skips them too).
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: send only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Sends a tensor to the destination rank in a non-blocking way.
        NOTE: `dst` is the local rank of the destination rank.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receives a tensor from the source rank.
        NOTE: `src` is the local rank of the source rank.
        """
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
        """Tear down the process groups, device communicator and message
        queue reference held by this coordinator."""
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
        if self.mq_broadcaster is not None:
            self.mq_broadcaster = None

    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Forward to the device communicator, if there is one."""
        if self.device_communicator is not None:
            self.device_communicator.prepare_communication_buffer_for_model(
                model)

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Forward to the device communicator's `dispatch`; identity when
        there is no device communicator."""
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states,
                                                     router_logits)
        else:
            return hidden_states, router_logits

    def combine(self, hidden_states) -> torch.Tensor:
        """Forward to the device communicator's `combine`; identity when
        there is no device communicator."""
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states)
        else:
            return hidden_states

# Coordinator spanning every process; None until initialized.
_WORLD: Optional[GroupCoordinator] = None
# NOTE(review): presumably the number of nodes; assigned elsewhere — confirm.
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group; fails with AssertionError if uninitialized."""
    world = _WORLD
    assert world is not None, ("world group is not initialized")
    return world
def init_world_group(ranks: list[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Build the coordinator for the "world" group.

    All of `ranks` form one single group, and the device communicator is
    disabled for it (``use_device_communicator=False``).
    """
    single_group = [ranks]
    return GroupCoordinator(
        group_name="world",
        group_ranks=single_group,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
    )


824
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create the `GroupCoordinator` for one model parallel dimension.

    Args:
        group_ranks: one list of global ranks per group of this dimension.
        local_rank: local rank of the calling process, forwarded as-is.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: forwarded to the coordinator (in the
            visible code only the TP group enables it).
        group_name: name given to the coordinator, e.g. "tp" or "pp".

    Unlike the world group, these coordinators always enable the device
    communicator.
    """
    return GroupCoordinator(
        group_name=group_name,
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        use_device_communicator=True,
    )


842
843
844
845
846
847
848
849
850
851
852
853
854
# Tensor model parallel group; created by initialize_model_parallel.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group (asserts it exists)."""
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

# Pipeline model parallel group; created by initialize_model_parallel.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group (asserts it exists)."""
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group

# Data parallel group; created by initialize_model_parallel.
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group (asserts it exists)."""
    dp_group = _DP
    assert dp_group is not None, ("data parallel group is not initialized")
    return dp_group


# Expert parallel group; created by initialize_model_parallel.
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel group (asserts it exists)."""
    ep_group = _EP
    assert ep_group is not None, ("expert parallel group is not initialized")
    return ep_group
879
880


881
@contextmanager
def graph_capture(device: torch.device):
    """
    Context manager to surround the code that captures a CUDA graph.

    It creates a `GraphCaptureContext` around a fresh CUDA stream on
    `device`. It then enters the TP group's `graph_capture` and, inside it,
    the PP group's `graph_capture` with that context, and yields the
    context to the caller. The group-level managers are presumably what make
    the new stream current for the duration of the block (confirm in
    `GroupCoordinator.graph_capture`). Capturing on a dedicated stream
    rather than the default stream explicitly separates the kernels being
    captured from other kernels possibly launched in the background on the
    default stream.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    with get_tp_group().graph_capture(context):
        with get_pp_group().graph_capture(context):
            yield context

901

902
logger = init_logger(__name__)

# Module-level switch for custom all-reduce; enabled by default and changed
# via set_custom_all_reduce().
_ENABLE_CUSTOM_ALL_REDUCE = True
905
906


907
908
909
def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-level `_ENABLE_CUSTOM_ALL_REDUCE` flag."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
910

Zhuohan Li's avatar
Zhuohan Li committed
911

912
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the vLLM world group.

    Args:
        world_size: number of processes. When the current vLLM config has
            ``data_parallel_size > 1``, this value is used as the per-replica
            size to offset `rank`, and is then replaced by
            ``parallel_config.world_size_across_dp``.
        rank: rank of this process. With DP it becomes
            ``data_parallel_rank * world_size + rank``.
        distributed_init_method: init method for
            ``torch.distributed.init_process_group``. With DP it is rebuilt
            from the DP master IP and the next DP init port.
        local_rank: rank on this node. -1 means: use ``envs.LOCAL_RANK`` when
            the init method is "env://", otherwise use `rank`.
        backend: requested backend. If torch.distributed is not initialized
            yet and this backend is unavailable, gloo is used instead.

    If the world group already exists, this only asserts that its size still
    matches ``torch.distributed.get_world_size()``.
    """
    global _WORLD, _NODE_COUNT

    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)

    from vllm.config import get_current_vllm_config
    vllm_config = get_current_vllm_config()
    if (vllm_config is not None
            and vllm_config.parallel_config.data_parallel_size > 1):
        pc = vllm_config.parallel_config
        # offset the rank by the data parallel rank (uses the per-replica
        # world_size, so this must happen before world_size is replaced)
        rank += pc.data_parallel_rank * world_size
        # adjust the world size to take into account data parallelism
        world_size = pc.world_size_across_dp
        distributed_init_method = get_distributed_init_method(
            pc.data_parallel_master_ip, pc.get_next_dp_init_port())
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; "
                "falling back to gloo.", backend)
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available.")
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)

    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # Not set; usually a single-node run where rank can serve as the
        # local rank.
        local_rank = (envs.LOCAL_RANK
                      if distributed_init_method == "env://" else rank)

    if _WORLD is not None:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
        return

    all_world_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_world_ranks, local_rank, backend)
    _NODE_COUNT = _node_count(_WORLD.cpu_group)
    logger.debug("Detected %d nodes in the distributed environment",
                 _NODE_COUNT)
975
976


Zhuohan Li's avatar
Zhuohan Li committed
977
978
979
def _unbind_group_ranks(layout: torch.Tensor,
                        group_size: int) -> list[list[int]]:
    """Split `layout` (read in row-major order) into consecutive runs of
    `group_size` ranks, returning one plain list of global ranks per group.
    """
    groups = layout.reshape(-1, group_size).unbind(0)
    return [group.tolist() for group in groups]


def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups; defaults to
            the backend of the world group's device group.

    Besides the TP and PP groups this also builds the data parallel (DP)
    group, sized by the current vLLM config's `data_parallel_size` (1 when
    there is no config), and the expert parallel (EP) group, which spans
    DP x TP ranks. Each group must not have been initialized before.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    global _TP, _PP, _DP, _EP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    data_parallel_size = (1 if config is None else
                          config.parallel_config.data_parallel_size)

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # To get group_ranks for a dimension, move that dimension last, flatten
    # the rest, and take each row as one group.
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    # Build the tensor model-parallel groups (TP is already the last dim).
    assert _TP is None, ("tensor model parallel group is already initialized")
    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        _unbind_group_ranks(all_ranks, tensor_model_parallel_size),
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp")

    # Build the pipeline model-parallel groups.
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    _PP = init_model_parallel_group(
        _unbind_group_ranks(all_ranks.transpose(2, 3),
                            pipeline_model_parallel_size),
        get_world_group().local_rank,
        backend,
        group_name="pp")

    # Build the data parallel groups.
    assert _DP is None, ("data parallel group is already initialized")
    _DP = init_model_parallel_group(
        _unbind_group_ranks(all_ranks.transpose(1, 3), data_parallel_size),
        get_world_group().local_rank,
        backend,
        group_name="dp")

    # Build the expert parallel groups (spanning DP x TP).
    assert _EP is None, ("expert parallel group is already initialized")
    _EP = init_model_parallel_group(
        _unbind_group_ranks(all_ranks.transpose(1, 2),
                            data_parallel_size * tensor_model_parallel_size),
        get_world_group().local_rank,
        backend,
        group_name="ep")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)
1081

Zhuohan Li's avatar
Zhuohan Li committed
1082

1083
1084
1085
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    If the groups are not initialized yet, they are created with
    `initialize_model_parallel`. Otherwise the existing TP and PP world sizes
    must equal the requested sizes, and an AssertionError is raised if not.
    `backend` defaults to the backend of the world group's device group.
    """
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    if not model_parallel_is_initialized():
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)
        return

    assert (get_tensor_model_parallel_world_size() ==
            tensor_model_parallel_size), (
                "tensor parallel group already initialized, but of unexpected "
                "size: "
                f"{get_tensor_model_parallel_world_size()=} vs. "
                f"{tensor_model_parallel_size=}")

    pp_world_size = get_pp_group().world_size
    assert pp_world_size == pipeline_model_parallel_size, (
        "pipeline parallel group already initialized, but of unexpected size: "
        f"{pp_world_size=} vs. "
        f"{pipeline_model_parallel_size=}")


1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost
    model agnostic. However, emerging new communication libraries like
    MoE all2all (DeepEP) usually allocate the communication buffer
    based on the model shape for optimal performance.

    Every initialized group (TP, PP, DP, EP, in that order) is given `model`;
    groups that do not exist yet are skipped.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1128
def model_parallel_is_initialized():
    """Return True once both the TP and the PP groups have been created."""
    return all(group is not None for group in (_TP, _PP))
1131
1132


1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
# True while patch_tensor_parallel_group() has swapped in a temporary TP group.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or if the TP group has
            not been initialized (raised by get_tp_group()).
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Look up the group being replaced *before* setting the flag. If the TP
    # group is not initialized, get_tp_group() raises. Setting the flag first
    # would leave it stuck at True and make every later patch fail.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1161
1162
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1164
1165
1166
1167


def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1169
1170


1171
1172
1173
1174
1175
1176
1177
def get_node_count() -> int:
    """Return the total number of nodes in the distributed environment.

    The value is computed by init_distributed_environment; calling this
    before that asserts.
    """
    node_count = _NODE_COUNT
    assert node_count is not None, (
        "distributed environment is not initialized")
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1178
def destroy_model_parallel():
    """Set the groups to none and destroy them.

    Each of the TP, PP, DP and EP coordinators that exists is destroyed (in
    that order), and every module-level slot is reset to None. Calling this
    again is a no-op.
    """
    global _TP, _PP, _DP, _EP

    # Compare against None explicitly rather than relying on truthiness, so a
    # coordinator that happens to be falsy is still destroyed before its
    # slot is cleared.
    if _TP is not None:
        _TP.destroy()
    _TP = None

    if _PP is not None:
        _PP.destroy()
    _PP = None

    if _DP is not None:
        _DP.destroy()
    _DP = None

    if _EP is not None:
        _EP.destroy()
    _EP = None

1201
1202

def destroy_distributed_environment():
    """Destroy the world group, forget the node count, and destroy the
    default torch.distributed process group if one is initialized.
    """
    global _WORLD, _NODE_COUNT
    # Explicit None check (not truthiness) so the world coordinator is always
    # destroyed before being dropped.
    if _WORLD is not None:
        _WORLD.destroy()
    _WORLD = None
    _NODE_COUNT = None
    if torch.distributed.is_initialized():
        torch.distributed.destroy_process_group()
1210
1211


1212
1213
1214
1215
1216
1217
1218
1219
1220
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all parallel groups and the distributed environment, then
    release cached memory.

    Args:
        shutdown_ray: if True, also call ``ray.shutdown()``.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Extra destroy of the default group; an AssertionError from it
    # (presumably "no default group left") is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()

    gc.collect()

    from vllm.platforms import current_platform
    platform_empty_cache = current_platform.empty_cache
    if platform_empty_cache is not None:
        platform_empty_cache()

    try:
        if not current_platform.is_cpu():
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")
1231
1232


1233
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> list[bool]:
    """
    Collectively determine which ranks are on the same node as `source_rank`.

    This is a collective operation: every rank in `pg` must call it with the
    same `source_rank`. The source rank creates a small shared-memory
    segment, writes a magic message into it and broadcasts the segment name.
    Every other rank tries to open the segment and checks the message. A
    rank that can read the message is attached to the same memory system
    (shared access to shared memory) as the source rank.

    Args:
        pg: a non-NCCL torch `ProcessGroup` or a `StatelessProcessGroup`.
        source_rank: rank *within* `pg` (not a global rank) to compare with.

    Returns:
        A list of length `world_size` where entry `i` is True when rank `i`
        was detected on the same node as `source_rank`.
    """
    is_torch_pg = isinstance(pg, ProcessGroup)
    if is_torch_pg:
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)
        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    def _broadcast_shm_name(name: Optional[str]) -> Optional[str]:
        # Collective: every rank calls this exactly once and receives the
        # value sent by `source_rank`.
        if is_torch_pg:
            holder = [name]
            torch.distributed.broadcast_object_list(
                holder, src=ranks[source_rank], group=pg)
            return holder[0]
        return pg.broadcast_obj(name, src=source_rank)

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        if rank == source_rank:
            shm_name = None
            with contextlib.suppress(OSError):
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                shm_name = shm.name
            # Broadcast unconditionally (None if creation failed). Every other
            # rank waits in the matching broadcast, so skipping it here would
            # hang the whole group.
            _broadcast_shm_name(shm_name)
            is_in_the_same_node[rank] = 1
        else:
            name = _broadcast_shm_name(None)
            if name is not None:
                # An OSError (e.g. segment not found) means "different node".
                with contextlib.suppress(OSError):
                    # fix to https://stackoverflow.com/q/62748654/9191338
                    # Python incorrectly tracks shared memory even if it is
                    # not created by the process. The following patch is a
                    # workaround.
                    with patch("multiprocessing.resource_tracker.register",
                               lambda *args, **kwargs: None):
                        shm = shared_memory.SharedMemory(name=name)
                    if shm.buf[:len(magic_message)] == magic_message:
                        is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # Make sure every rank has finished reading before the segment is
    # unlinked.
    if is_torch_pg:
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    if is_torch_pg:
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354


def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Returns the total number of nodes in the process group.

    All ranks must call this together. For each rank not yet assigned to a
    node (in increasing rank order), it runs the collective
    `in_the_same_node_as` with that rank as the source. Because that result
    is aggregated across ranks, every rank builds the same assignment.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    world_size = (torch.distributed.get_world_size(group=pg)
                  if isinstance(pg, ProcessGroup) else pg.world_size)
    if world_size == 1:
        return 1

    # node id per rank; 0 means "not assigned yet", real ids start at 1
    node_of = [0] * world_size
    num_nodes = 0

    for leader in range(world_size):
        if node_of[leader]:
            continue  # already placed on a node
        num_nodes += 1
        node_of[leader] = num_nodes
        # Every rank that shares memory with `leader` joins its node.
        flags = in_the_same_node_as(pg, leader)
        for peer, shares_node in enumerate(flags):
            if shares_node and not node_of[peer]:
                node_of[peer] = num_nodes

    return num_nodes