shm_broadcast.py 29.6 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
6
import pickle
import time
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from multiprocessing import shared_memory
9
from pickle import PickleBuffer
10
from threading import Event
11
from typing import TYPE_CHECKING, Any, cast
12
13
14
15
from unittest.mock import patch

import torch
import torch.distributed as dist
16
import zmq
17
from torch.distributed import ProcessGroup
18
19
20
21
22
23
24
25
from zmq import (  # type: ignore
    IPV6,  # type: ignore
    SUB,
    SUBSCRIBE,
    XPUB,
    XPUB_VERBOSE,
    Context,
)
26
27

import vllm.envs as envs
28
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
29
from vllm.logger import init_logger
30
from vllm.platforms import current_platform
31
from vllm.utils.network_utils import (
32
33
34
35
36
    get_ip,
    get_open_port,
    get_open_zmq_ipc_path,
    is_valid_ipv6_address,
)
37

38
39
40
if TYPE_CHECKING:
    from _typeshed import SizedBuffer

# Interval, in seconds, between repeated "no available block" log messages
# while acquire_write / acquire_read are blocked. Elapsed time there is
# measured with time.monotonic(), and long_wait_time_msg reports the value as
# "in {threshold} seconds". Read once from vllm.envs at import time.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL
# Decode an unsigned big-endian integer from a bytes-like object. Used to read
# the length/count headers of the shared-memory message framing.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` big-endian bytes.

    Counterpart of ``from_bytes_big``. Raises OverflowError if ``value`` does
    not fit in ``size`` bytes.
    """
    return int.to_bytes(value, size, "big")


# Module-level logger for this file.
logger = init_logger(__name__)


def long_wait_time_msg(threshold: int) -> str:
    """Build the message logged when a ring-buffer wait runs long.

    Args:
        threshold: The wait duration, in seconds, to report in the message.

    Returns:
        A human-readable explanation that no shared-memory broadcast block
        became available within ``threshold`` seconds.
    """
    return (
        "No available shared memory broadcast block found "
        f"in {threshold} seconds. This typically happens "
        "when some processes are hanging or doing some "
        "time-consuming work (e.g. compilation, "
        "weight/kv cache quantization)."
    )


class SpinTimer:
    """Default polling strategy used while waiting on the ring buffer.

    ``record_activity`` is a no-op. ``spin`` calls ``sched_yield()`` on every
    iteration, so the polling rate never changes. Subclasses can override
    both methods to back off when idle.
    """

    def record_activity(self):
        """Note that a message was just consumed (no-op for this strategy)."""
        pass

    def spin(self):
        """Perform one wait step of a polling loop."""
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """
    Polling strategy that backs off to sleeping after a quiet period.

    Long idle stretches would otherwise keep one polling thread per GPU busy
    at 100% CPU. Reducing the polling rate when nothing happens saves power
    and leaves more CPU thermal headroom for when a request finally arrives.

    Until ``busy_loop_s`` seconds have passed since the last recorded
    activity, each ``spin`` just yields the processor. After that, each
    ``spin`` sleeps for ``wait_sleep_s`` seconds.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        self.last_activity = time.monotonic()
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s

    def record_activity(self):
        """Reset the idle clock; polling returns to busy-yield mode."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Yield while recently active; sleep once the idle window elapses."""
        busy_until = self.last_activity + self.busy_loop_s
        if time.monotonic() < busy_until:
            sched_yield()
            return
        time.sleep(self.wait_sleep_s)


99
class ShmRingBuffer:
100
101
102
103
104
    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
105
        name: str | None = None,
106
    ):
107
108
109
110
111
112
113
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.
114

115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

151
152
153
154
        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
155
        """  # noqa
156
157
158
159
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
160
161
162
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
163
164
165
166
167
168
169
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
170
171
                create=True, size=self.total_bytes_of_buffer
            )
172
            # initialize the metadata section to 0
173
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
174
175
176
177
178
179
180
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
181
182
183
184
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
185
186
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
187
188
189
190
191
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
192
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
193
194
195
196
197
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass
198

199
    def handle(self):
200
201
202
203
204
205
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )
206

207
208
209
    def __reduce__(self):
        return (
            self.__class__,
210
            self.handle(),
211
212
213
        )

    def __del__(self):
214
215
216
217
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()
218
219
220
221
222

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
223
        with self.shared_memory.buf[start:end] as buf:
224
225
226
227
228
229
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
230
        with self.shared_memory.buf[start:end] as buf:
231
232
233
            yield buf


234
235
@dataclass
class Handle:
236
    local_reader_ranks: list[int] = field(default_factory=list)
237

238
239
240
    buffer_handle: tuple[int, int, int, str] | None = None
    local_subscribe_addr: str | None = None
    remote_subscribe_addr: str | None = None
241
    remote_addr_ipv6: bool = False
242
243
244
245
246
247
248


class MessageQueue:
    """Broadcast queue with a single writer and ``n_reader`` readers.

    Local readers read small messages from a ShmRingBuffer; a message that
    does not fit in one chunk is flagged as overflow in the ring buffer and
    sent over a ZMQ XPUB/SUB socket instead. Remote readers always receive
    over a TCP XPUB/SUB socket. The writer constructs the queue and exports a
    Handle; readers attach with ``create_from_handle``.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: list[int] | None = None,
        # Default of 24MiB chosen to be large enough to accommodate grammar
        # bitmask tensors for large batches (1024 requests).
        max_chunk_bytes: int = 1024 * 1024 * 24,
        max_chunks: int = 10,
        connect_ip: str | None = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: Total number of readers (local + remote).
            n_local_reader: Number of readers that use the shared-memory
                ring buffer.
            local_reader_ranks: Ranks of the local readers; defaults to
                ``range(n_local_reader)``.
            max_chunk_bytes: Size of each ring-buffer chunk in bytes.
            max_chunks: Number of chunks in the ring buffer.
            connect_ip: Address to bind the remote socket on; defaults to
                ``get_ip()``.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                connect_ip = f"[{connect_ip}]"
            # The bind address is also the address published to readers.
            remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(remote_subscribe_addr)
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False
        self._read_spin_timer = SpinTimer()

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle() if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.debug("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the Handle readers need to attach to this queue."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Attach a reader to the queue described by ``handle``.

        Ranks listed in ``handle.local_reader_ranks`` attach through shared
        memory plus the local socket; any other rank attaches through the
        remote socket.
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None

            self._read_spin_timer = (
                SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
            )
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: float | None = None):
        """Yield the next writable ring-buffer chunk (writer only).

        Blocks until the chunk at ``current_idx`` is either unwritten or read
        by every local reader, then advances ``current_idx`` after the caller
        finishes writing.

        Args:
            timeout: Seconds to wait before raising TimeoutError; None waits
                forever.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Yield the next readable ring-buffer chunk (local readers only).

        Args:
            timeout: Seconds to wait before raising TimeoutError; None waits
                forever.
            cancel: If set while waiting, raise RuntimeError("cancelled").
            indefinite: Suppress the periodic long-wait log message.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    self._read_spin_timer.spin()

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if not indefinite and (
                        elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._read_spin_timer.record_activity()
                break

    def enqueue(self, obj, timeout: float | None = None):
        """Write to message queue with optional timeout (in seconds)"""
        assert self._is_writer, "Only writers can enqueue"
        all_buffers: list[SizedBuffer] = [b""]
        total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size

        def oob_callback(buf: PickleBuffer) -> bool:
            # Returning True keeps the buffer in-band (inside the pickle
            # stream); returning False sends it out-of-band via all_buffers.
            raw_buf = buf.raw()
            if len(raw_buf) < 1024 * 1024:
                # In-line buffers smaller than 1MiB.
                return True
            all_buffers.append(raw_buf)
            nonlocal total_bytes
            total_bytes += len(raw_buf) + 4
            return False

        all_buffers[0] = pickle.dumps(
            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
        )
        if self.n_local_reader > 0:
            # +1 for the overflow flag byte: the message fits only if
            # total_bytes + len(main) + 1 <= max_chunk_bytes.
            if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send_multipart(all_buffers, copy=False)
            else:
                # Byte 0: 0
                # Bytes 1-2: Count of buffers
                # Then each buffer follows, preceded by 4 bytes containing its length:
                # [4 byte int L][L bytes of buffer content] ...
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    offset = 3
                    buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # oob buf count
                    for buffer in all_buffers:
                        buf_len = len(buffer)
                        # prepend each buffer with 4 bytes containing its size.
                        buf_offset = offset + 4
                        buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
                        buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

        if self.n_remote_reader > 0:
            self.remote_socket.send_multipart(all_buffers, copy=False)

    def dequeue(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds)"""
        if self._is_local_reader:
            with self.acquire_read(timeout, cancel, indefinite) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    offset = 3
                    buf_count = from_bytes_big(buf[1:offset])
                    all_buffers = []
                    for _ in range(buf_count):
                        buf_offset = offset + 4
                        buf_len = from_bytes_big(buf[offset:buf_offset])
                        offset = buf_offset + buf_len
                        all_buffers.append(buf[buf_offset:offset])
                    # NOTE(review): the out-of-band buffers passed below are
                    # memoryview slices of the shared-memory chunk, not copies.
                    # If an unpickled object keeps such a buffer without
                    # copying it, it aliases a chunk the writer may overwrite
                    # once this reader marks it read. TODO confirm which
                    # payload types rely on this zero-copy path.
                    obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: float | None) -> Any:
        """Receive and unpickle one multipart message from ``socket``.

        Raises TimeoutError if nothing arrives within ``timeout`` seconds
        (None waits forever).
        """
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv, *recv_oob = socket.recv_multipart(copy=False)
        return pickle.loads(recv, buffers=recv_oob)

    def broadcast_object(self, obj=None):
        """Writer enqueues and returns ``obj``; readers return the dequeued object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        return self.dequeue()

    @staticmethod
    def create_from_process_group_single_reader(
        pg: ProcessGroup,
        max_chunk_bytes,
        max_chunks,
        reader_rank: int = 0,
        blocking: bool = False,
    ) -> tuple["MessageQueue", list[Handle]]:
        """
        Creates a MessageQueue for a process group with a single reader.

        This method is designed for scenarios where only one process (the reader)
        will consume messages, and all other processes are writers. It sets up
        the shared memory buffer and communication handles accordingly, and
        gathers the handles from all processes to the reader.

        Args:
            pg (ProcessGroup): The torch distributed process group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            reader_rank (int, optional): The global rank that will act as the reader.
                Defaults to 0.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to False.

        Returns:
            tuple[MessageQueue, list[Handle]]:
            The MessageQueue instance for the calling process,
            and a list of handles (only non-empty for the reader process).
        """
        local_size = current_platform.device_count()
        rank = dist.get_rank()
        # NOTE(review): assumes global ranks are laid out contiguously per
        # node, with `local_size` ranks on each node.
        same_node = rank // local_size == reader_rank // local_size
        buffer_io = MessageQueue(
            n_reader=1,
            n_local_reader=1 if same_node else 0,
            max_chunk_bytes=max_chunk_bytes,
            max_chunks=max_chunks,
        )
        handle = buffer_io.export_handle()
        handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io, cast(list[Handle], handles or [])

    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup | StatelessProcessGroup,
        max_chunk_bytes,
        max_chunks,
        writer_rank: int = 0,
        external_writer_handle=None,
        blocking: bool = True,
    ) -> "MessageQueue":
        """
        Creates a MessageQueue for a distributed process group with one writer and
        multiple readers.

        This method is designed for scenarios where one process (the writer) sends
        messages, and all other processes (the readers) receive messages. It sets up
        the shared memory buffer and socket communication handles accordingly, and
        broadcasts the handle from the writer to all readers.

        Args:
            pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
                group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            writer_rank (int, optional): The global rank that will act as the writer.
                Defaults to 0.
            external_writer_handle (Handle, optional): Used when there is a handle
                from an external Message Queue. If provided, use this handle to init
                PG writer message queue instead of creating a new one. Defaults to None.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to True.

        Returns:
            MessageQueue: The MessageQueue instance for the calling process.

        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))
        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        if group_rank == writer_rank:
            if external_writer_handle is not None:
                buffer_io = MessageQueue.create_from_handle(
                    external_writer_handle, group_rank
                )
            else:
                same_node_ranks = [i for i, s in enumerate(status) if s]
                n_reader = group_world_size - 1
                n_local_reader = len(same_node_ranks) - 1
                local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
                buffer_io = MessageQueue(
                    n_reader=n_reader,
                    n_local_reader=n_local_reader,
                    local_reader_ranks=local_reader_ranks,
                    max_chunk_bytes=max_chunk_bytes,
                    max_chunks=max_chunks,
                )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(
                    recv, src=global_ranks[writer_rank], group=pg
                )
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io