# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

# Copyright 2023 The vLLM team.
# Adapted from
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""vLLM distributed state.
It takes over the control of the distributed environment from PyTorch.
The typical workflow is:

- call `init_distributed_environment` to initialize the distributed environment.
13
- call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
14
15
16
17
18
19
20
21
22
23
24
 initialize the model parallel groups.

- any code dealing with the distributed stuff

- call `destroy_model_parallel` to destroy the model parallel groups.
- call `destroy_distributed_environment` to destroy the distributed environment.

If you only need to use the distributed environment without model/pipeline
 parallelism, you can skip the model parallel initialization and destruction
 steps.
"""
25
import contextlib
26
import gc
27
import pickle
28
import weakref
29
30
31
from collections import namedtuple
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
32
from multiprocessing import shared_memory
33
from typing import Any, Callable, Optional, Union
34
from unittest.mock import patch
Zhuohan Li's avatar
Zhuohan Li committed
35
36

import torch
37
import torch.distributed
38
from torch.distributed import Backend, ProcessGroup
Zhuohan Li's avatar
Zhuohan Li committed
39

40
import vllm.envs as envs
41
42
from vllm.distributed.device_communicators.base_device_communicator import (
    DeviceCommunicatorBase)
43
from vllm.distributed.utils import StatelessProcessGroup
44
from vllm.logger import init_logger
45
46
from vllm.utils import (direct_register_custom_op, get_distributed_init_method,
                        resolve_obj_by_qualname, supports_custom_op)
47
48


49
50
51
@dataclass
class GraphCaptureContext:
    """Per-capture state: the stream on which graph capture runs."""
    stream: torch.cuda.Stream


# Lightweight description of a tensor that is sent ahead of its data: the
# device *type* (e.g. "cuda"), the dtype and the size. This is enough for a
# receiver to allocate a matching buffer.
TensorMetadata = namedtuple("TensorMetadata", "device dtype size")
55

56
57

def _split_tensor_dict(
    tensor_dict: dict[str, Union[torch.Tensor, Any]]
) -> tuple[list[tuple[str, Any]], list[torch.Tensor]]:
    """Separate the tensors of `tensor_dict` from its other values.

    Returns ``(metadata_list, tensor_list)``:

    1. ``metadata_list`` holds every ``(key, value)`` pair in iteration
       order. Each tensor value is replaced by a
       ``TensorMetadata(device_type, dtype, size)`` describing it.
    2. ``tensor_list`` holds the tensor values themselves, in the same
       relative order.
    """
    metadata: list[tuple[str, Any]] = []
    tensors: list[torch.Tensor] = []
    for name, item in tensor_dict.items():
        if not isinstance(item, torch.Tensor):
            metadata.append((name, item))
            continue
        # Record only the device *type* (e.g. "cuda"), not the full device
        # such as "cuda:0". The receiving side sets the device index.
        described = TensorMetadata(item.device.type, item.dtype, item.size())
        metadata.append((name, described))
        tensors.append(item)
    return metadata, tensors


82
_group_name_counter: dict[str, int] = {}
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97


def _get_unique_name(name: str) -> str:
    """Get a unique name for the group.
    Example:
    _get_unique_name("tp") -> "tp:0"
    _get_unique_name("tp") -> "tp:1"
    """
    if name not in _group_name_counter:
        _group_name_counter[name] = 0
    newname = f"{name}:{_group_name_counter[name]}"
    _group_name_counter[name] += 1
    return newname


98
_groups: dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
99
100
101


def _register_group(group: "GroupCoordinator") -> None:
102
    _groups[group.unique_name] = weakref.ref(group)
103
104


105
106
107
108
109
110
def all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Eager implementation of the ``all_reduce`` custom op.

    Resolves `group_name` through the weak-reference registry and runs the
    coordinator's out-of-place all-reduce on `tensor`.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_reduce_out_place(tensor)
    # The weak reference is dead: the coordinator no longer exists.
    raise ValueError(f"Group {group_name} is destroyed.")
111
112


113
114
def all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
    """Fake impl of ``all_reduce``: an uninitialised tensor like `tensor`.

    `group_name` is unused. It is kept so the signature matches
    `all_reduce`.
    """
    placeholder = torch.empty_like(tensor)
    return placeholder
115

116

117
118
119
120
121
122
def reduce_scatter(tensor: torch.Tensor, dim: int, world_size: int,
                   group_name: str) -> torch.Tensor:
    """Eager implementation of the ``reduce_scatter`` custom op.

    Resolves `group_name` and runs the coordinator's out-of-place
    reduce-scatter along `dim`. `world_size` is not used here. It is part
    of the signature so that `reduce_scatter_fake` can compute the output
    shape.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._reduce_scatter_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138


def reduce_scatter_fake(tensor: torch.Tensor, dim: int, world_size: int,
                        group_name: str) -> torch.Tensor:
    """Fake impl of ``reduce_scatter``, which only infers the output shape.

    The result has `tensor`'s dtype and device. Its shape equals `tensor`'s
    except that dimension `dim` is floor-divided by `world_size`.
    `group_name` is unused.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] //= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


def all_gather(tensor: torch.Tensor, dim: int, world_size: int,
               group_name: str) -> torch.Tensor:
    """Eager implementation of the ``all_gather`` custom op.

    Resolves `group_name` and runs the coordinator's out-of-place
    all-gather along `dim`. `world_size` is only consumed by
    `all_gather_fake`.
    """
    assert group_name in _groups, f"Group {group_name} is not found."
    coordinator = _groups[group_name]()
    if coordinator is not None:
        return coordinator._all_gather_out_place(tensor, dim)
    raise ValueError(f"Group {group_name} is destroyed.")
140
141
142
143
144
145
146
147
148


def all_gather_fake(tensor: torch.Tensor, dim: int, world_size: int,
                    group_name: str) -> torch.Tensor:
    """Fake impl of ``all_gather``, which only infers the output shape.

    The result has `tensor`'s dtype and device. Its shape equals `tensor`'s
    except that dimension `dim` is multiplied by `world_size`. `group_name`
    is unused.
    """
    out_shape = list(tensor.shape)
    out_shape[dim] *= world_size
    return torch.empty(out_shape, dtype=tensor.dtype, device=tensor.device)


149
if supports_custom_op():
    from vllm.platforms import current_platform

    # Expose the out-of-place collectives as `torch.ops.vllm.<name>` so that
    # Dynamo-traced code can call them with a group *name* instead of the
    # coordinator object. None of them mutates its inputs. Each fake impl
    # only computes the output's shape, dtype and device.
    for _op_name, _op_func, _op_fake in (
        ("all_reduce", all_reduce, all_reduce_fake),
        ("reduce_scatter", reduce_scatter, reduce_scatter_fake),
        ("all_gather", all_gather, all_gather_fake),
    ):
        direct_register_custom_op(
            op_name=_op_name,
            op_func=_op_func,
            mutates_args=[],
            fake_impl=_op_fake,
            dispatch_key=current_platform.dispatch_key,
        )
    # Do not leak the loop variables into the module namespace.
    del _op_name, _op_func, _op_fake

175

176
177
178
179
180
181
class GroupCoordinator:
    """
    PyTorch ProcessGroup wrapper for a group of processes.
    PyTorch ProcessGroup is bound to one specific communication backend,
        e.g. NCCL, Gloo, MPI, etc.
    GroupCoordinator takes charge of all the communication operations among
        the processes in the group. It manages both CPU and device
        communication.

    Each coordinator holds two process groups over the same ranks:
    `device_group`, created with the requested backend, and `cpu_group`,
    created with the `gloo` backend. The CPU group carries CPU-side
    coordination such as object broadcast, barriers and CPU tensors.
    """

    # available attributes:
    rank: int  # global rank
    ranks: list[int]  # global ranks in the group
    world_size: int  # size of the group
    # difference between `local_rank` and `rank_in_group`:
    # if we have a group of size 4 across two nodes:
    # Process | Node | Rank | Local Rank | Rank in Group
    #   0     |   0  |  0   |     0      |       0
    #   1     |   0  |  1   |     1      |       1
    #   2     |   1  |  2   |     0      |       2
    #   3     |   1  |  3   |     1      |       3
    local_rank: int  # local rank used to assign devices
    rank_in_group: int  # rank inside the group
    cpu_group: ProcessGroup  # group for CPU communication
    device_group: ProcessGroup  # group for device communication
    use_device_communicator: bool  # whether to use device communicator
    device_communicator: DeviceCommunicatorBase  # device communicator
    mq_broadcaster: Optional[Any]  # shared memory broadcaster

    def __init__(
        self,
        group_ranks: list[list[int]],
        local_rank: int,
        torch_distributed_backend: Union[str, Backend],
        use_device_communicator: bool,
        use_message_queue_broadcaster: bool = False,
        group_name: Optional[str] = None,
    ):
        """Create the process groups in `group_ranks` and keep our own.

        Args:
            group_ranks: one list of global ranks per sub-group. A device
                group and a gloo CPU group are created for every sub-group.
                This process keeps the sub-group that contains its rank.
                NOTE(review): `torch.distributed.new_group` is collective,
                so presumably every process must pass identical
                `group_ranks`. Confirm this with the callers.
            local_rank: local rank; selects this process's device index.
            torch_distributed_backend: backend of the device group.
            use_device_communicator: build the platform's device
                communicator. It is only built when the group has more
                than one process.
            use_message_queue_broadcaster: build a shared-memory
                `MessageQueue` that `broadcast_object` will use. It is only
                built when the group has more than one process.
            group_name: prefix of the unique name that custom ops use to
                look this coordinator up. Defaults to "anonymous".
        """
        group_name = group_name or "anonymous"
        self.unique_name = _get_unique_name(group_name)
        # Registered by weak reference so that the custom ops can resolve
        # this coordinator from its name.
        _register_group(self)

        self.rank = torch.distributed.get_rank()
        self.local_rank = local_rank
        self.device_group = None
        self.cpu_group = None

        for ranks in group_ranks:
            device_group = torch.distributed.new_group(
                ranks, backend=torch_distributed_backend)
            # a group with `gloo` backend, to allow direct coordination between
            # processes through the CPU.
            cpu_group = torch.distributed.new_group(ranks, backend="gloo")
            if self.rank in ranks:
                self.ranks = ranks
                self.world_size = len(ranks)
                self.rank_in_group = ranks.index(self.rank)
                self.device_group = device_group
                self.cpu_group = cpu_group

        assert self.cpu_group is not None
        assert self.device_group is not None

        from vllm.platforms import current_platform

        if current_platform.is_cuda_alike():
            self.device = torch.device(f"cuda:{local_rank}")
        elif current_platform.is_xpu():
            self.device = torch.device(f"xpu:{local_rank}")
        elif current_platform.is_out_of_tree():
            self.device = torch.device(
                f"{current_platform.device_name}:{local_rank}")
        else:
            self.device = torch.device("cpu")

        self.use_device_communicator = use_device_communicator

        self.device_communicator: DeviceCommunicatorBase = None  # type: ignore
        if use_device_communicator and self.world_size > 1:
            device_comm_cls = resolve_obj_by_qualname(
                current_platform.get_device_communicator_cls())
            self.device_communicator = device_comm_cls(
                cpu_group=self.cpu_group,
                device=self.device,
                device_group=self.device_group,
                unique_name=self.unique_name,
            )

        from vllm.distributed.device_communicators.shm_broadcast import (
            MessageQueue)
        self.mq_broadcaster: Optional[MessageQueue] = None
        if use_message_queue_broadcaster and self.world_size > 1:
            # NOTE(review): the positional constants are presumably the
            # per-chunk byte size (4 MiB) and chunk count. Confirm them
            # against MessageQueue.create_from_process_group.
            self.mq_broadcaster = MessageQueue.create_from_process_group(
                self.cpu_group, 1 << 22, 6)

        # On these platforms the collectives are dispatched through the
        # registered `torch.ops.vllm.*` custom ops. On all others the
        # out-of-place methods are called directly.
        self.use_custom_op_call = (current_platform.is_cuda_alike()
                                   or current_platform.is_tpu())

    @property
    def first_rank(self):
        """Return the global rank of the first process in the group"""
        return self.ranks[0]

    @property
    def last_rank(self):
        """Return the global rank of the last process in the group"""
        return self.ranks[-1]

    @property
    def is_first_rank(self):
        """Return whether the caller is the first process in the group"""
        return self.rank == self.first_rank

    @property
    def is_last_rank(self):
        """Return whether the caller is the last process in the group"""
        return self.rank == self.last_rank

    @property
    def next_rank(self):
        """Return the global rank of the process that follows the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group + 1) % world_size]

    @property
    def prev_rank(self):
        """Return the global rank of the process that precedes the caller"""
        rank_in_group = self.rank_in_group
        world_size = self.world_size
        return self.ranks[(rank_in_group - 1) % world_size]

    @contextmanager
    def graph_capture(
            self, graph_capture_context: Optional[GraphCaptureContext] = None):
        """Context manager for CUDA graph capture.

        Runs the body on the context's stream, or on a fresh
        `torch.cuda.Stream` when no context is given. That stream first
        waits on the current stream. If the CUDA device communicator has a
        custom all-reduce communicator, the body also runs inside its
        `capture()` context. Yields the GraphCaptureContext in use.
        """
        if graph_capture_context is None:
            stream = torch.cuda.Stream()
            graph_capture_context = GraphCaptureContext(stream)
        else:
            stream = graph_capture_context.stream

        # only cuda uses this function,
        # so we don't abstract it into the base class
        maybe_ca_context = nullcontext()
        from vllm.distributed.device_communicators.cuda_communicator import (
            CudaCommunicator)
        if self.device_communicator is not None:
            assert isinstance(self.device_communicator, CudaCommunicator)
            ca_comm = self.device_communicator.ca_comm
            if ca_comm is not None:
                maybe_ca_context = ca_comm.capture()  # type: ignore

        # ensure all initialization operations complete before attempting to
        # capture the graph on another stream
        curr_stream = torch.cuda.current_stream()
        if curr_stream != stream:
            stream.wait_stream(curr_stream)

        with torch.cuda.stream(stream), maybe_ca_context:
            yield graph_capture_context

    def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
        """
        User-facing all-reduce function before we actually call the
        all-reduce operation.

        We need this because Dynamo does not support passing an arbitrary
        object (`self` in this case) to a custom op. We need to pass the
        group name as a string, and then look up the group coordinator from
        the group name, dispatch the all-reduce operation to the group
        coordinator.

        In addition, PyTorch custom ops do not support mutation or returning
        a new tensor in the same op. So we always make the all-reduce operation
        out-of-place.
        """
        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_

        if self.use_custom_op_call:
            return torch.ops.vllm.all_reduce(input_,
                                             group_name=self.unique_name)
        else:
            return self._all_reduce_out_place(input_)

    def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
        """Out-of-place all-reduce through the device communicator."""
        return self.device_communicator.all_reduce(input_)

    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
        """All-gather `input_` along `dim` across the group.

        Returns `input_` unchanged for a single-process group. Otherwise the
        call is routed through the custom op or directly to the device
        communicator, in the same way as `all_reduce`.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.all_gather(input_,
                                             dim,
                                             world_size,
                                             group_name=self.unique_name)
        else:
            return self._all_gather_out_place(input_, dim)

    def _all_gather_out_place(self, input_: torch.Tensor,
                              dim: int) -> torch.Tensor:
        """Out-of-place all-gather through the device communicator."""
        return self.device_communicator.all_gather(input_, dim)

    def reduce_scatter(self,
                       input_: torch.Tensor,
                       dim: int = -1) -> torch.Tensor:
        """Reduce-scatter `input_` along `dim` across the group.

        Returns `input_` unchanged for a single-process group. Otherwise the
        call is routed through the custom op or directly to the device
        communicator, in the same way as `all_reduce`.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        assert -input_.dim() <= dim < input_.dim(), (
            f"Invalid dim ({dim}) for input tensor with shape {input_.size()}")

        if self.use_custom_op_call:
            return torch.ops.vllm.reduce_scatter(input_,
                                                 dim,
                                                 world_size,
                                                 group_name=self.unique_name)
        else:
            return self._reduce_scatter_out_place(input_, dim)

    def _reduce_scatter_out_place(self, input_: torch.Tensor,
                                  dim: int) -> torch.Tensor:
        """Out-of-place reduce-scatter through the device communicator."""
        return self.device_communicator.reduce_scatter(input_, dim)

    def gather(self,
               input_: torch.Tensor,
               dst: int = 0,
               dim: int = -1) -> Optional[torch.Tensor]:
        """
        NOTE: We assume that the input tensor is on the same device across
        all the ranks.
        NOTE: `dst` is the local rank of the destination rank.

        Returns whatever the device communicator's `gather` returns.
        NOTE(review): presumably that is None on non-destination ranks,
        given the Optional return type. Confirm this.
        """
        world_size = self.world_size
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return input_
        return self.device_communicator.gather(input_, dst, dim)

    def broadcast(self, input_: torch.Tensor, src: int = 0):
        """Broadcast the input tensor in place over the device group.
        NOTE: `src` is the local rank of the source rank.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return input_
        # Broadcast.
        torch.distributed.broadcast(input_,
                                    src=self.ranks[src],
                                    group=self.device_group)
        return input_

    def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
        """Broadcast the input object.
        NOTE: `src` is the local rank of the source rank.

        Uses the shared-memory message queue when one was created. That
        path only supports src=0. Otherwise the object is pickled by
        `torch.distributed.broadcast_object_list` over the CPU group.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj
        if self.mq_broadcaster is not None:
            assert src == 0, "Message queue broadcaster only supports src=0"
            return self.mq_broadcaster.broadcast_object(obj)
        if self.rank_in_group == src:
            torch.distributed.broadcast_object_list([obj],
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return obj
        else:
            recv = [None]
            torch.distributed.broadcast_object_list(recv,
                                                    src=self.ranks[src],
                                                    group=self.cpu_group)
            return recv[0]

    def broadcast_object_list(self,
                              obj_list: list[Any],
                              src: int = 0,
                              group: Optional[ProcessGroup] = None):
        """Broadcast the input object list in place over the device group.
        NOTE: `src` is the local rank of the source rank.
        NOTE: `group` is accepted for API compatibility but ignored. The
        coordinator's device group is always used.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        # Bypass the function if we are using only 1 GPU.
        if self.world_size == 1:
            return obj_list
        # Broadcast.
        torch.distributed.broadcast_object_list(obj_list,
                                                src=self.ranks[src],
                                                group=self.device_group)
        return obj_list

    def send_object(self, obj: Any, dst: int) -> None:
        """Send a picklable object to the destination rank.

        NOTE: `dst` is the local rank of the destination rank.

        Two messages go over the CPU group: the payload size, then the
        pickled payload. `recv_object` is the matching receive.
        """
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert dst != self.rank_in_group, (
            "Invalid destination rank. Destination rank is the same "
            "as the current rank.")

        # Serialize object to tensor and get the size as well. A bytearray
        # is used because torch.frombuffer warns on non-writable buffers
        # such as `bytes`.
        object_tensor = torch.frombuffer(bytearray(pickle.dumps(obj)),
                                         dtype=torch.uint8)

        size_tensor = torch.tensor([object_tensor.numel()],
                                   dtype=torch.long,
                                   device="cpu")

        # Send object size
        torch.distributed.send(size_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        # Send object
        torch.distributed.send(object_tensor,
                               dst=self.ranks[dst],
                               group=self.cpu_group)

        return None

    def recv_object(self, src: int) -> Any:
        """Receive an object sent with `send_object` from the source rank.

        NOTE: `src` is the local rank of the source rank.

        The payload is unpickled, so only use this with trusted peers.
        """
        assert src < self.world_size, f"Invalid src rank ({src})"

        assert src != self.rank_in_group, (
            "Invalid source rank. Source rank is the same as the current rank."
        )

        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

        # Receive object size
        rank_size = torch.distributed.recv(size_tensor,
                                           src=self.ranks[src],
                                           group=self.cpu_group)

        # Tensor to receive serialized objects into.
        object_tensor = torch.empty(  # type: ignore[call-overload]
            size_tensor.item(),  # type: ignore[arg-type]
            dtype=torch.uint8,
            device="cpu")

        rank_object = torch.distributed.recv(object_tensor,
                                             src=self.ranks[src],
                                             group=self.cpu_group)

        assert rank_object == rank_size, (
            "Received object sender rank does not match the size sender rank.")

        obj = pickle.loads(object_tensor.numpy().tobytes())

        return obj

    def broadcast_tensor_dict(
        self,
        tensor_dict: Optional[dict[str, Union[torch.Tensor, Any]]] = None,
        src: int = 0,
        group: Optional[ProcessGroup] = None,
        metadata_group: Optional[ProcessGroup] = None
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Broadcast the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank.

        The source broadcasts the metadata first, then every non-empty
        tensor. CPU tensors go over the CPU group and all other tensors go
        over the device group. Receivers allocate empty buffers from the
        metadata and fill them.

        NOTE: `group` and `metadata_group` are accepted but ignored. The
        coordinator's own device and CPU groups are always used.
        """
        # Bypass the function if we are using only 1 GPU.
        if (not torch.distributed.is_initialized() or self.world_size == 1):
            return tensor_dict

        group = self.device_group
        metadata_group = self.cpu_group
        assert src < self.world_size, f"Invalid src rank ({src})"

        rank_in_group = self.rank_in_group
        metadata_list: list[tuple[Any, Any]]
        if rank_in_group == src:
            assert isinstance(
                tensor_dict,
                dict), (f"Expecting a dictionary, got {type(tensor_dict)}")
            metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
            # `metadata_list` lives in CPU memory. `broadcast_object`
            # serializes on the CPU (message queue or CPU group), so the
            # device group is not needed here.
            self.broadcast_object(metadata_list, src=src)
            async_handles = []
            for tensor in tensor_list:
                if tensor.numel() == 0:
                    # Skip broadcasting empty tensors.
                    continue
                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=metadata_group,
                                                         async_op=True)
                else:
                    # use group for GPU tensors
                    handle = torch.distributed.broadcast(tensor,
                                                         src=self.ranks[src],
                                                         group=group,
                                                         async_op=True)
                async_handles.append(handle)
            for async_handle in async_handles:
                async_handle.wait()

        else:
            metadata_list = self.broadcast_object(None, src=src)
            tensor_dict = {}
            async_handles = []
            for key, value in metadata_list:
                if isinstance(value, TensorMetadata):
                    tensor = torch.empty(value.size,
                                         dtype=value.dtype,
                                         device=value.device)
                    if tensor.numel() == 0:
                        # Skip broadcasting empty tensors.
                        tensor_dict[key] = tensor
                        continue
                    if tensor.is_cpu:
                        # use metadata_group for CPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=metadata_group,
                            async_op=True)
                    else:
                        # use group for GPU tensors
                        handle = torch.distributed.broadcast(
                            tensor,
                            src=self.ranks[src],
                            group=group,
                            async_op=True)
                    async_handles.append(handle)
                    tensor_dict[key] = tensor
                else:
                    tensor_dict[key] = value
            for async_handle in async_handles:
                async_handle.wait()
        return tensor_dict

    def send_tensor_dict(
        self,
        tensor_dict: dict[str, Union[torch.Tensor, Any]],
        dst: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Send the input tensor dictionary.
        NOTE: `dst` is the local rank of the destination rank. It defaults
        to the next rank in the group.

        If `all_gather_group` is given, a tensor whose element count is
        divisible by that group's size is sent as a single slice. The slice
        is selected by this process's rank in `all_gather_group`. The
        receiver reassembles the tensor with an all-gather.

        Returns None, or `tensor_dict` unchanged when torch.distributed is
        not initialized or the group has a single process.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return tensor_dict

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if dst is None:
            dst = (self.rank_in_group + 1) % self.world_size
        assert dst < self.world_size, f"Invalid dst rank ({dst})"

        assert isinstance(
            tensor_dict,
            dict), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory. `send_object` serializes on
        # the CPU and sends over the CPU group.
        self.send_object(metadata_list, dst=dst)
        for tensor in tensor_list:
            if tensor.numel() == 0:
                # Skip sending empty tensors.
                continue

            # send-allgather: send only a slice, then do allgather.
            if (all_gather_group is not None
                    and tensor.numel() % all_gather_size == 0):
                tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

            if tensor.is_cpu:
                # use metadata_group for CPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=metadata_group)
            else:
                # use group for GPU tensors
                torch.distributed.send(tensor,
                                       dst=self.ranks[dst],
                                       group=group)
        return None

    def recv_tensor_dict(
        self,
        src: Optional[int] = None,
        all_gather_group: Optional["GroupCoordinator"] = None,
    ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
        """Recv the input tensor dictionary.
        NOTE: `src` is the local rank of the source rank. It defaults to the
        previous rank in the group.

        Counterpart of `send_tensor_dict`. With `all_gather_group`, each
        tensor whose element count is divisible by that group's size is
        received as one slice and then reassembled with an all-gather.

        Returns None when torch.distributed is not initialized or the group
        has a single process.
        """
        # Bypass the function if we are using only 1 GPU.
        if not torch.distributed.is_initialized() or self.world_size == 1:
            return None

        all_gather_size = (1 if all_gather_group is None else
                           all_gather_group.world_size)
        all_gather_rank = (0 if all_gather_group is None else
                           all_gather_group.rank_in_group)

        group = self.device_group
        metadata_group = self.cpu_group

        if src is None:
            src = (self.rank_in_group - 1) % self.world_size
        assert src < self.world_size, f"Invalid src rank ({src})"

        recv_metadata_list = self.recv_object(src=src)
        tensor_dict: dict[str, Any] = {}
        for key, value in recv_metadata_list:
            if isinstance(value, TensorMetadata):
                tensor = torch.empty(value.size,
                                     dtype=value.dtype,
                                     device=value.device)
                if tensor.numel() == 0:
                    # Skip receiving empty tensors; the sender skips them too.
                    tensor_dict[key] = tensor
                    continue

                # send-allgather: receive only a slice, then do allgather.
                use_all_gather = (all_gather_group is not None
                                  and tensor.numel() % all_gather_size == 0)

                if use_all_gather:
                    orig_shape = tensor.shape
                    tensor = tensor.reshape(all_gather_size,
                                            -1)[all_gather_rank]

                if tensor.is_cpu:
                    # use metadata_group for CPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=metadata_group)
                else:
                    # use group for GPU tensors
                    torch.distributed.recv(tensor,
                                           src=self.ranks[src],
                                           group=group)
                if use_all_gather:
                    # do the allgather
                    tensor = all_gather_group.all_gather(  # type: ignore
                        tensor, dim=0)
                    tensor = tensor.reshape(orig_shape)

                tensor_dict[key] = tensor
            else:
                tensor_dict[key] = value
        return tensor_dict

    def barrier(self):
        """Barrier synchronization among the group.
        NOTE: don't use `device_group` here! `barrier` in NCCL is
        terrible because it is internally a broadcast operation with
        secretly created GPU tensors. It is easy to mess up the current
        device. Use the CPU group instead.
        """
        torch.distributed.barrier(group=self.cpu_group)

    def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
        """Send `tensor` to the destination rank via the device communicator.

        NOTE: `dst` is the local rank of the destination rank. Whether the
        send is non-blocking depends on the device communicator.
        """
        self.device_communicator.send(tensor, dst)

    def recv(self,
             size: torch.Size,
             dtype: torch.dtype,
             src: Optional[int] = None) -> torch.Tensor:
        """Receive a tensor of `size`/`dtype` from the source rank.

        NOTE: `src` is the local rank of the source rank.
        """
        return self.device_communicator.recv(size, dtype, src)

    def destroy(self):
        """Release the process groups, device communicator and message queue.

        Safe to call more than once: each resource is released at most once
        and then reset to None.
        """
        if self.device_group is not None:
            torch.distributed.destroy_process_group(self.device_group)
            self.device_group = None
        if self.cpu_group is not None:
            torch.distributed.destroy_process_group(self.cpu_group)
            self.cpu_group = None
        if self.device_communicator is not None:
            self.device_communicator.destroy()
            # Drop the reference so that a second destroy() does not tear
            # the communicator down twice.
            self.device_communicator = None  # type: ignore
        # The message queue is only dereferenced here, not explicitly closed.
        self.mq_broadcaster = None

    def prepare_communication_buffer_for_model(self, model: torch.nn.Module):
        """Let the device communicator (if any) prepare buffers for `model`."""
        if self.device_communicator is not None:
            self.device_communicator.prepare_communication_buffer_for_model(
                model)

    def dispatch(
            self, hidden_states: torch.Tensor,
            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Delegate to the device communicator's `dispatch`.

        Returns the inputs unchanged when there is no communicator.
        """
        if self.device_communicator is not None:
            return self.device_communicator.dispatch(hidden_states,
                                                     router_logits)
        else:
            return hidden_states, router_logits

    def combine(self, hidden_states) -> torch.Tensor:
        """Delegate to the device communicator's `combine`.

        Returns the input unchanged when there is no communicator.
        """
        if self.device_communicator is not None:
            return self.device_communicator.combine(hidden_states)
        else:
            return hidden_states
804

805
806

_WORLD: Optional[GroupCoordinator] = None
# NOTE(review): not read in the visible code. Presumably this holds the
# number of nodes and is set elsewhere in the file. Confirm this.
_NODE_COUNT: Optional[int] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, which must be initialized."""
    world = _WORLD
    assert world is not None, ("world group is not initialized")
    return world


815
def init_world_group(ranks: list[int], local_rank: int,
                     backend: str) -> GroupCoordinator:
    """Build the coordinator for the world group.

    Args:
        ranks: all ranks that make up the world; they form one single group.
        local_rank: local rank of the calling process, passed through to the
            coordinator.
        backend: torch.distributed backend for the group.

    The world group is named "world" and is created without a device
    communicator.
    """
    only_group = [ranks]
    return GroupCoordinator(
        group_name="world",
        group_ranks=only_group,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=False,
    )


826
def init_model_parallel_group(
    group_ranks: list[list[int]],
    local_rank: int,
    backend: str,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Build a coordinator for one family of model parallel groups.

    Args:
        group_ranks: list of rank lists, one per group of this kind.
        local_rank: local rank of the calling process.
        backend: torch.distributed backend for the groups.
        use_message_queue_broadcaster: whether the coordinator should set up
            a message-queue broadcaster.
        group_name: optional name for the group (e.g. "tp", "pp").

    Unlike the world group, these groups always use a device communicator.
    """
    coordinator_kwargs = dict(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_device_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return GroupCoordinator(**coordinator_kwargs)


844
845
846
847
848
849
850
851
852
853
854
855
856
# Tensor model parallel group coordinator, created by
# `initialize_model_parallel`.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model parallel group, asserting it exists."""
    tp_group = _TP
    assert tp_group is not None, (
        "tensor model parallel group is not initialized")
    return tp_group


# kept for backward compatibility
get_tensor_model_parallel_group = get_tp_group

# Pipeline model parallel group coordinator, created by
# `initialize_model_parallel`.
_PP: Optional[GroupCoordinator] = None

857
858
859
860
861
862
863
# Data parallel group coordinator, created by `initialize_model_parallel`.
_DP: Optional[GroupCoordinator] = None


def get_dp_group() -> GroupCoordinator:
    """Return the data parallel group, asserting it exists."""
    dp_group = _DP
    assert dp_group is not None, ("data parallel group is not initialized")
    return dp_group

864

865
866
867
868
869
870
871
872
# Expert parallel group coordinator, created by `initialize_model_parallel`.
_EP: Optional[GroupCoordinator] = None


def get_ep_group() -> GroupCoordinator:
    """Return the expert parallel group, asserting it exists."""
    ep_group = _EP
    assert ep_group is not None, ("expert parallel group is not initialized")
    return ep_group


873
874
875
876
def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model parallel group, asserting it exists."""
    pp_group = _PP
    assert pp_group is not None, (
        "pipeline model parallel group is not initialized")
    return pp_group


# kept for backward compatibility
get_pipeline_model_parallel_group = get_pp_group
881
882


883
@contextmanager
def graph_capture(device: torch.device):
    """
    Context manager to wrap around code that captures a CUDA graph.

    Its main purpose is to give the parallel groups a chance to run some
    operations after the graph is captured and before it is replayed. It
    yields a `GraphCaptureContext`, which currently only carries the stream
    the capture runs on: a new CUDA stream created on `device`. Running the
    capture on this separate stream, rather than the default stream, keeps
    the kernels being captured distinct from other kernels that may be
    launched in the background on the default stream.

    The tensor parallel group's `graph_capture` context is entered first and
    the pipeline parallel group's second; they are exited in reverse order.
    """
    capture_stream = torch.cuda.Stream(device=device)
    context = GraphCaptureContext(capture_stream)
    with contextlib.ExitStack() as stack:
        stack.enter_context(get_tp_group().graph_capture(context))
        stack.enter_context(get_pp_group().graph_capture(context))
        yield context

903

904
logger = init_logger(__name__)

# Module-level custom all-reduce switch, updated by `set_custom_all_reduce`
# (its consumers are outside this part of the file).
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Store `enable` in the module-level custom all-reduce switch."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
912

Zhuohan Li's avatar
Zhuohan Li committed
913

914
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and vLLM's world group.

    Args:
        world_size: number of processes. When the current vLLM config has
            ``data_parallel_size > 1`` this is the size of one data-parallel
            replica and is replaced by ``world_size_across_dp``.
        rank: rank of this process. In the data-parallel case it is offset
            by ``data_parallel_rank * world_size``.
        distributed_init_method: ``init_method`` for
            ``torch.distributed.init_process_group``. In the data-parallel
            case it is rebuilt from the DP master IP and the next DP init
            port.
        local_rank: rank of the process on its node. ``-1`` means "derive
            it": ``envs.LOCAL_RANK`` when the (possibly adjusted)
            ``distributed_init_method`` is ``"env://"``, otherwise ``rank``.
        backend: requested torch.distributed backend. If this call is the
            one initializing torch.distributed and the backend is not
            available, gloo is used instead.

    torch.distributed is initialized only if it is not initialized yet. The
    world group coordinator and the node count are created only on the
    first call; later calls just assert that the world size is unchanged.
    """
    logger.debug(
        "world_size=%d rank=%d local_rank=%d "
        "distributed_init_method=%s backend=%s", world_size, rank, local_rank,
        distributed_init_method, backend)
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None and config.parallel_config.data_parallel_size > 1:
        parallel_config = config.parallel_config
        # adjust to take into account data parallelism
        # offset the rank by the data parallel rank
        # (uses the per-replica world_size, so it must precede the
        # world_size reassignment below)
        rank = parallel_config.data_parallel_rank * world_size + rank
        # adjust the world size to take into account data parallelism
        world_size = parallel_config.world_size_across_dp
        ip = parallel_config.data_parallel_master_ip
        port = parallel_config.get_next_dp_init_port()
        distributed_init_method = get_distributed_init_method(ip, port)
        logger.info(
            "Adjusting world_size=%d rank=%d distributed_init_method=%s for DP",
            world_size, rank, distributed_init_method)
    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment")
        if not torch.distributed.is_backend_available(backend):
            logger.warning(
                "Distributed backend %s is not available; "
                "falling back to gloo.", backend)
            assert torch.distributed.is_gloo_available(), (
                "Fallback Gloo backend is not available.")
            backend = "gloo"
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank)
    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        if distributed_init_method == "env://":
            local_rank = envs.LOCAL_RANK
        else:
            local_rank = rank
    global _WORLD, _NODE_COUNT
    if _WORLD is None:
        ranks = list(range(torch.distributed.get_world_size()))
        _WORLD = init_world_group(ranks, local_rank, backend)
        # Collective over the world CPU group: every rank must reach this.
        _NODE_COUNT = _node_count(_WORLD.cpu_group)
        logger.debug("Detected %d nodes in the distributed environment",
                     _NODE_COUNT)
    else:
        assert _WORLD.world_size == torch.distributed.get_world_size(), (
            "world group already initialized with a different world size")
977
978


Zhuohan Li's avatar
Zhuohan Li committed
979
980
981
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups. When falsy,
            the backend of the world group's device group is used.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]

    A data parallel (DP) group and an expert parallel (EP) group are built
    as well. The DP size comes from the current vLLM config (1 when no
    config is set). Each EP group contains the DP x TP ranks that share the
    same pipeline stage. The world size must be a multiple of
    DP * PP * TP, otherwise the rank reshape below fails.

    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.
    """
    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    rank = torch.distributed.get_rank()
    backend = backend or torch.distributed.get_backend(
        get_world_group().device_group)

    data_parallel_size = 1
    from vllm.config import get_current_vllm_config
    config = get_current_vllm_config()
    if config is not None:
        data_parallel_size = config.parallel_config.data_parallel_size

    # the layout order is: ExternalDP x DP x PP x TP
    # ExternalDP is the data parallel group that is not part of the model,
    # every dp rank can generate independently (in verl integration).
    # DP is the data parallel group that is part of the model,
    # all the ranks in the same DP group should generate simultaneously,
    # i.e. the `generate` call in the same DP group should be called together,
    # otherwise it will cause deadlock.
    # to get group_ranks for each dimension, transpose that dimension to the
    # last dimension, then reshape to 2D, then unbind the last dimension
    all_ranks = torch.arange(world_size).reshape(
        -1, data_parallel_size, pipeline_model_parallel_size,
        tensor_model_parallel_size)  # noqa

    # Build the tensor model-parallel groups.
    global _TP
    assert _TP is None, ("tensor model parallel group is already initialized")
    group_ranks = all_ranks.view(-1, tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    use_message_queue_broadcaster=True,
                                    group_name="tp")

    # Build the pipeline model-parallel groups.
    global _PP
    assert _PP is None, (
        "pipeline model parallel group is already initialized")
    group_ranks = all_ranks.transpose(2, 3).reshape(
        -1, pipeline_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _PP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="pp")

    # Build the data parallel groups (DP moved to the last axis).
    global _DP
    assert _DP is None, ("data parallel group is already initialized")
    group_ranks = all_ranks.transpose(1,
                                      3).reshape(-1,
                                                 data_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _DP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="dp")

    # Build the expert parallel groups: swapping DP and PP puts DP x TP in
    # the two trailing axes, so each row spans DP * TP ranks of one PP stage.
    global _EP
    assert _EP is None, ("expert parallel group is already initialized")
    group_ranks = all_ranks.transpose(1, 2).reshape(
        -1, data_parallel_size * tensor_model_parallel_size).unbind(0)
    group_ranks = [x.tolist() for x in group_ranks]
    _EP = init_model_parallel_group(group_ranks,
                                    get_world_group().local_rank,
                                    backend,
                                    group_name="ep")

    logger.info(
        "rank %s in world size %s is assigned as "
        "DP rank %s, PP rank %s, TP rank %s, EP rank %s", rank, world_size,
        _DP.rank_in_group, _PP.rank_in_group, _TP.rank_in_group,
        _EP.rank_in_group)
1083

Zhuohan Li's avatar
Zhuohan Li committed
1084

1085
1086
1087
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Create the model parallel groups, or validate the existing ones.

    If the groups are not initialized yet, `initialize_model_parallel` is
    called with the given sizes and backend. Otherwise the existing tensor
    parallel and pipeline parallel world sizes are asserted to equal the
    requested sizes.

    Args:
        tensor_model_parallel_size: requested / expected TP world size.
        pipeline_model_parallel_size: requested / expected PP world size.
        backend: torch.distributed backend for newly created groups. When
            falsy, the world group's device-group backend is looked up
            (this lookup happens even if the groups already exist).
    """
    if not backend:
        backend = torch.distributed.get_backend(
            get_world_group().device_group)

    if model_parallel_is_initialized():
        assert (
            get_tensor_model_parallel_world_size() == tensor_model_parallel_size
        ), ("tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}")

        # The variable name is part of the `=` f-string in the message below.
        pp_world_size = get_pp_group().world_size
        assert (pp_world_size == pipeline_model_parallel_size), (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}")
    else:
        initialize_model_parallel(tensor_model_parallel_size,
                                  pipeline_model_parallel_size, backend)


1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
def prepare_communication_buffer_for_model(model: torch.nn.Module):
    """Prepare the communication buffer for the model.

    Traditional communication libraries like NCCL are almost model agnostic.
    However, emerging new communication libraries like MoE all2all (DeepEP)
    usually allocate the communication buffer based on the model shape for
    optimal performance.

    Every initialized group is asked to prepare its buffers, in the order
    TP, PP, DP, EP; groups that do not exist are skipped.
    """
    for group in (_TP, _PP, _DP, _EP):
        if group is not None:
            group.prepare_communication_buffer_for_model(model)


Zhuohan Li's avatar
Zhuohan Li committed
1130
def model_parallel_is_initialized():
    """Return True when both the TP and PP groups exist.

    The DP and EP groups are not consulted.
    """
    return not (_TP is None or _PP is None)
1133
1134


1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
# True while `patch_tensor_parallel_group` has temporarily replaced `_TP`.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Temporarily replace the TP group with `tp_group` for the `with` body.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Patching is not re-entrant: nested use is rejected by an assertion. The
    original TP group is restored when the context exits, even on error.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator
    """
    global _TP_STATE_PATCHED, _TP
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group before marking the state as patched:
    # get_tp_group() asserts when no TP group exists. If the flag were set
    # first, that failure would leave it stuck at True, because the
    # try/finally below has not been entered yet.
    old_tp_group = get_tp_group()
    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group


Zhuohan Li's avatar
Zhuohan Li committed
1163
1164
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
Zhuohan Li's avatar
Zhuohan Li committed
1166
1167
1168
1169


def get_tensor_model_parallel_rank():
    """Return this process's rank inside the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
Zhuohan Li's avatar
Zhuohan Li committed
1171
1172


1173
1174
1175
1176
1177
1178
1179
def get_node_count() -> int:
    """Return the node count recorded by `init_distributed_environment`."""
    node_count = _NODE_COUNT
    assert node_count is not None, (
        "distributed environment is not initialized")
    return node_count


Zhuohan Li's avatar
Zhuohan Li committed
1180
def _destroy_group_if_set(group):
    """Call `destroy()` on a truthy `group`; always return None."""
    if group:
        group.destroy()
    return None


def destroy_model_parallel():
    """Set the groups to none and destroy them.

    The TP, PP, DP and EP groups are handled in that order. Each one is
    destroyed (if set) and only then reset to None, before moving to the
    next group.
    """
    global _TP, _PP, _DP, _EP
    _TP = _destroy_group_if_set(_TP)
    _PP = _destroy_group_if_set(_PP)
    _DP = _destroy_group_if_set(_DP)
    _EP = _destroy_group_if_set(_EP)

1203
1204

def destroy_distributed_environment():
    """Destroy the world group and shut down torch.distributed if active.

    The world group is destroyed (if set), then both the world group and
    the cached node count are cleared. Finally, the default torch.distributed
    process group is destroyed when torch.distributed is still initialized.
    """
    global _WORLD, _NODE_COUNT
    if _WORLD:
        _WORLD.destroy()
    _WORLD = _NODE_COUNT = None

    if not torch.distributed.is_initialized():
        return
    torch.distributed.destroy_process_group()
1212
1213


1214
1215
1216
1217
1218
1219
1220
1221
1222
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM's distributed state and free cached memory.

    Steps, in order:
    - destroy the model parallel groups and the distributed environment;
    - make one more attempt at destroying the default process group,
      tolerating an AssertionError;
    - optionally shut down Ray;
    - run the garbage collector;
    - ask the current platform and torch to empty their caches.

    Args:
        shutdown_ray: also call ``ray.shutdown()`` when True.
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    # Safety net; presumably raises AssertionError when nothing is left to
    # destroy, which is deliberately ignored.
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()

    from vllm.platforms import current_platform
    platform_empty_cache = current_platform.empty_cache
    if platform_empty_cache is not None:
        platform_empty_cache()

    try:
        on_cpu = current_platform.is_cpu()
        if not on_cpu:
            # Not present in older PyTorch, hence the AttributeError handler.
            torch._C._host_emptyCache()
    except AttributeError:
        logger.warning(
            "torch._C._host_emptyCache() only available in Pytorch >=2.5")
1233
1234


1235
def in_the_same_node_as(pg: Union[ProcessGroup, StatelessProcessGroup],
                        source_rank: int = 0) -> list[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    Protocol:
    - The source rank creates a shared-memory segment, writes a magic message
      into it and broadcasts the segment name.
    - Every other rank tries to open the segment by that name and checks for
      the message.
    - After a barrier the source rank unlinks the segment.
    - Per-rank flags are then summed across the group.

    Errors while creating or opening the segment are logged and treated as
    "not on the same node".

    Args:
        pg: group to test. A torch ProcessGroup must not use the NCCL
            backend.
        source_rank: rank of the reference process *within* `pg`.

    Returns:
        A list with one entry per rank of `pg`; entry i is True iff rank i is
        on the same node as `source_rank`.
    """
    if isinstance(pg, ProcessGroup):
        assert torch.distributed.get_backend(
            pg) != torch.distributed.Backend.NCCL, (
                "in_the_same_node_as should be tested with a non-NCCL group.")
        # local rank inside the group
        rank = torch.distributed.get_rank(group=pg)
        world_size = torch.distributed.get_world_size(group=pg)

        # global ranks of the processes in the group
        ranks = torch.distributed.get_process_group_ranks(pg)
    else:
        rank = pg.rank
        world_size = pg.world_size
        ranks = list(range(world_size))

    # local tensor in each process to store the result
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    magic_message = b"magic_message"
    shm = None

    try:
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[:len(magic_message)] = magic_message
                if isinstance(pg, ProcessGroup):
                    # broadcast_object_list expects the *global* source rank
                    torch.distributed.broadcast_object_list(
                        [shm.name], src=ranks[source_rank], group=pg)
                else:
                    pg.broadcast_obj(shm.name, src=source_rank)
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                if isinstance(pg, ProcessGroup):
                    recv = [None]
                    torch.distributed.broadcast_object_list(
                        recv, src=ranks[source_rank], group=pg)
                    name = recv[0]
                else:
                    name = pg.broadcast_obj(None, src=source_rank)
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch("multiprocessing.resource_tracker.register",
                           lambda *args, **kwargs: None):
                    shm = shared_memory.SharedMemory(name=name)
                if shm.buf[:len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm is not None:
            shm.close()

    # Make sure every rank is done reading before the segment is unlinked.
    if isinstance(pg, ProcessGroup):
        torch.distributed.barrier(group=pg)
    else:
        pg.barrier()

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm is not None:
            shm.unlink()

    # Each rank set only its own slot, so summing the flags gathers them.
    if isinstance(pg, ProcessGroup):
        torch.distributed.all_reduce(is_in_the_same_node, group=pg)
        aggregated_data = is_in_the_same_node
    else:
        aggregated_data = torch.zeros_like(is_in_the_same_node)
        for i in range(world_size):
            rank_data = pg.broadcast_obj(is_in_the_same_node, src=i)
            aggregated_data += rank_data

    return [x == 1 for x in aggregated_data.tolist()]
1318
1319


1320
1321
def is_global_first_rank() -> bool:
    """
    Check if the current process is the first rank globally across all
    parallelism strategies (PP, TP, DP, EP, etc.).

    Unlike group-specific checks like `get_tensor_model_parallel_rank() == 0`
    or `get_pp_group().is_first_rank`, this function checks the global rank
    across all parallelism dimensions.

    Resolution order: the world group coordinator if it exists; otherwise
    torch.distributed's global rank; an uninitialized environment counts as
    a single process.

    Returns:
        bool: True if this is the global first rank (rank 0), False otherwise.
              Returns True if distributed is not initialized (single process).
              Any exception raised during the check also yields True.
    """
    try:
        world_group = _WORLD
        if world_group is not None:
            return world_group.is_first_rank
        # No world group: an uninitialized environment is a single process,
        # otherwise compare torch's global rank with 0.
        return (not torch.distributed.is_initialized()
                or torch.distributed.get_rank() == 0)
    except Exception:
        # Deliberately permissive: on any failure assume the first rank.
        return True


1351
1352
1353
1354
1355
1356
def _node_count(pg: Union[ProcessGroup, StatelessProcessGroup]) -> int:
    """
    Returns the total number of nodes in the process group.

    Ranks are grouped greedily: the lowest rank without a node becomes the
    reference for a new node, and `in_the_same_node_as` (a collective over
    `pg`) marks every rank that shares memory with it. Every rank takes the
    same sequence of calls, because the flags are aggregated across the
    group.

    Args:
        pg: The process group to analyze

    Returns:
        int: The total number of nodes
    """
    world_size = (torch.distributed.get_world_size(group=pg)
                  if isinstance(pg, ProcessGroup) else pg.world_size)
    if world_size == 1:
        return 1

    # node_of[r] is the 1-based node id of rank r; 0 means "not yet assigned"
    node_of = [0] * world_size
    num_nodes = 0
    for reference in range(world_size):
        if node_of[reference]:
            continue
        num_nodes += 1
        node_of[reference] = num_nodes
        peers = in_the_same_node_as(pg, reference)
        for peer, shares_node in enumerate(peers):
            if shares_node and not node_of[peer]:
                node_of[peer] = num_nodes
    return num_nodes