shm_broadcast.py 31.7 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
import pickle
5
import threading
6
7
import time
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
from multiprocessing import shared_memory
10
from pickle import PickleBuffer
11
from threading import Event
12
from typing import TYPE_CHECKING, Any, cast
13
14
15
16
from unittest.mock import patch

import torch
import torch.distributed as dist
17
import zmq
18
from torch.distributed import ProcessGroup
19
20
21
22
23
24
25
26
from zmq import (  # type: ignore
    IPV6,  # type: ignore
    SUB,
    SUBSCRIBE,
    XPUB,
    XPUB_VERBOSE,
    Context,
)
27
28

import vllm.envs as envs
29
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
30
from vllm.logger import init_logger
31
from vllm.platforms import current_platform
32
from vllm.utils.network_utils import (
33
34
35
36
37
    get_ip,
    get_open_port,
    get_open_zmq_ipc_path,
    is_valid_ipv6_address,
)
38

39
40
41
if TYPE_CHECKING:
    from _typeshed import SizedBuffer

42
43
# Interval, in seconds, between repeated "still waiting" log messages emitted
# while a reader or writer is blocked on the shared-memory ring buffer.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

44
45
46
# Decode a big-endian (unsigned, by int.from_bytes default) integer from a
# bytes-like object; used to read the count/length prefixes that
# to_bytes_big writes into ring-buffer chunks.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# Process-local lock. It never guards any shared state: memory_fence() takes
# and drops it purely for the memory-ordering side effect of doing so.
_memory_fence_lock = threading.Lock()


def memory_fence():
    """
    Issue a full memory barrier around shared-memory flag accesses.

    The ring-buffer reader and writer call this between touching the
    per-chunk metadata flags and the chunk data. The aim is that a flag
    update made by one process is observed by the others in program order.

    The barrier is a side effect of acquiring and immediately releasing a
    ``threading.Lock``. NOTE(review): Python does not formally document
    ``Lock`` as a cross-process memory barrier. This relies on the
    underlying lock implementation's atomic operations; confirm this holds
    when porting to a new platform.
    """
    _memory_fence_lock.acquire()
    _memory_fence_lock.release()


74
75
76
77
def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as an unsigned big-endian integer of exactly ``size`` bytes.

    Raises ``OverflowError`` when ``value`` is negative or needs more than
    ``size`` bytes.
    """
    return int.to_bytes(value, size, "big")


78
79
80
# Module-level logger; used for the socket bind/connect debug messages and the
# periodic "still waiting" info messages in this file.
logger = init_logger(__name__)


81
82
83
84
85
86
87
88
89
90
def long_wait_time_msg(threshold: int) -> str:
    """Build the info message logged each time a ring-buffer wait has lasted
    another ``threshold`` seconds."""
    pieces = (
        "No available shared memory broadcast block found ",
        f"in {threshold} seconds. This typically happens ",
        "when some processes are hanging or doing some ",
        "time-consuming work (e.g. compilation, ",
        "weight/kv cache quantization).",
    )
    return "".join(pieces)


91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
class SpinTimer:
    """Polling policy used while waiting on the ring buffer.

    This default policy never sleeps: every ``spin`` simply yields the CPU
    via ``sched_yield``, and activity is not tracked.
    """

    def record_activity(self):
        # Pure busy polling has no idle state, so there is nothing to record.
        return None

    def spin(self):
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """
    Polling policy that backs off to sleeping after a period of inactivity.

    Long idle stretches otherwise keep one thread per GPU busy-polling at
    100% CPU, wasting power and thermal headroom that is useful once a
    request finally arrives. This policy keeps yielding for ``busy_loop_s``
    seconds after the last recorded activity. After that, each ``spin``
    sleeps for ``wait_sleep_s`` seconds instead.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        self.last_activity = time.monotonic()

    def spin(self):
        busy_until = self.last_activity + self.busy_loop_s
        if time.monotonic() < busy_until:
            # still within the busy window: stay responsive
            sched_yield()
            return
        # idle for a while: trade latency for lower CPU usage
        time.sleep(self.wait_sleep_s)


127
class ShmRingBuffer:
128
129
130
131
132
    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
133
        name: str | None = None,
134
    ):
135
136
137
138
139
140
141
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.
142

143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

179
180
181
182
        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
183
        """  # noqa
184
185
186
187
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
188
189
190
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
191
192
193
194
195
196
197
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
198
199
                create=True, size=self.total_bytes_of_buffer
            )
200
            # initialize the metadata section to 0
201
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
202
203
204
205
206
207
208
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
209
210
211
212
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
213
214
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
215
216
217
218
219
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
220
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
221
222
223
224
225
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass
226

227
    def handle(self):
228
229
230
231
232
233
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )
234

235
236
237
    def __reduce__(self):
        return (
            self.__class__,
238
            self.handle(),
239
240
241
        )

    def __del__(self):
242
243
244
245
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()
246
247
248
249
250

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
251
        with self.shared_memory.buf[start:end] as buf:
252
253
254
255
256
257
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
258
        with self.shared_memory.buf[start:end] as buf:
259
260
261
            yield buf


262
263
@dataclass
class Handle:
264
    local_reader_ranks: list[int] = field(default_factory=list)
265

266
267
268
    buffer_handle: tuple[int, int, int, str] | None = None
    local_subscribe_addr: str | None = None
    remote_subscribe_addr: str | None = None
269
    remote_addr_ipv6: bool = False
270
271
272
273
274
275
276


class MessageQueue:
    """Single-writer, multi-reader broadcast queue.

    Readers on the writer's node ("local" readers) receive messages through a
    ShmRingBuffer. Messages too large for one chunk are flagged as overflow in
    the ring and sent over a ZMQ XPUB socket bound to an address from
    ``get_open_zmq_ipc_path()``. Readers elsewhere ("remote" readers) always
    receive messages over a ZMQ XPUB socket bound to a ``tcp://`` address.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: list[int] | None = None,
        # Default of 24MiB chosen to be large enough to accommodate grammar
        # bitmask tensors for large batches (1024 requests).
        max_chunk_bytes: int = 1024 * 1024 * 24,
        max_chunks: int = 10,
        connect_ip: str | None = None,
    ):
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                connect_ip = f"[{connect_ip}]"
            socket_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(socket_addr)
            # readers connect to exactly the address we bound
            remote_subscribe_addr = socket_addr
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False
        self._read_spin_timer = SpinTimer()

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle() if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.debug("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the Handle readers pass to ``create_from_handle``."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Attach to the queue described by ``handle`` as reader ``rank``.

        If ``rank`` is listed in ``handle.local_reader_ranks``, the result is a
        local reader (shared memory plus the local socket). Otherwise it is a
        remote reader that subscribes to ``handle.remote_subscribe_addr``.
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None

            self._read_spin_timer = (
                SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
            )
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: float | None = None):
        """Yield the next writable ring-buffer chunk (writer only).

        Spins until the chunk at ``current_idx`` is unwritten or read by every
        local reader. Raises TimeoutError once ``timeout`` seconds have elapsed.
        After the caller's ``with`` body completes, the chunk is published and
        ``current_idx`` advances.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                # Memory fence ensures we see the latest read flags from readers.
                # Without this, we may read stale flags from our CPU cache and
                # spin indefinitely even though readers have completed.
                memory_fence()
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                # Memory fence ensures the write is visible to readers on other cores
                # before we proceed. Without this, readers may spin indefinitely
                # waiting for a write that's stuck in our CPU's store buffer.
                memory_fence()
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Yield the next unread ring-buffer chunk (local readers only).

        Raises TimeoutError after ``timeout`` seconds, and RuntimeError if
        ``cancel`` is set while waiting. ``indefinite=True`` suppresses the
        periodic long-wait log message. After the caller's ``with`` body
        completes, this reader's flag is set and ``current_idx`` advances.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                # Memory fence ensures we see the latest writes from the writer.
                # Without this, we may read stale flags from our CPU cache
                # and spin indefinitely even though writer has updated them.
                memory_fence()
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    self._read_spin_timer.spin()

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if not indefinite and (
                        elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                # Memory fence ensures the read flag is visible to the writer.
                # Without this, writer may not see our read completion and
                # could wait indefinitely for all readers to finish.
                memory_fence()
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._read_spin_timer.record_activity()
                break

    def enqueue(self, obj, timeout: float | None = None):
        """Write to message queue with optional timeout (in seconds)"""
        assert self._is_writer, "Only writers can enqueue"

        # Frame 0 is the in-band pickle stream; any further frames are large
        # out-of-band buffers collected by oob_callback (zero-copy on send).
        all_buffers: list[SizedBuffer] = [b""]
        # Bytes needed to inline the message in a chunk, excluding the main
        # stream itself: 2 bytes for the frame count, 4 for the main frame's
        # length prefix, and (added below) 4 + len for every out-of-band frame.
        # The 1-byte overflow flag is not counted here; the `>=` comparison
        # below accounts for it, so an exactly-fitting message stays inline.
        total_bytes = 6

        def oob_callback(buf: PickleBuffer) -> bool:
            raw_buf = buf.raw()
            if len(raw_buf) < 1024 * 1024:
                # In-line buffers smaller than 1MiB.
                return True
            all_buffers.append(raw_buf)
            nonlocal total_bytes
            total_bytes += len(raw_buf) + 4
            return False

        all_buffers[0] = pickle.dumps(
            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
        )
        if self.n_local_reader > 0:
            if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send_multipart(all_buffers, copy=False)
            else:
                # Byte 0: 0
                # Bytes 1-2: Count of buffers
                # Then each buffer follows, preceded by 4 bytes containing its length:
                # [4 byte int L][L bytes of buffer content] ...
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    offset = 3
                    buf[1:offset] = to_bytes_big(len(all_buffers), 2)  # oob buf count
                    for buffer in all_buffers:
                        buf_len = len(buffer)
                        # prepend each buffer with 4 bytes containing its size.
                        buf_offset = offset + 4
                        buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
                        buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

        if self.n_remote_reader > 0:
            self.remote_socket.send_multipart(all_buffers, copy=False)

    def dequeue(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds)"""
        if self._is_local_reader:
            with self.acquire_read(timeout, cancel, indefinite) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    offset = 3
                    buf_count = from_bytes_big(buf[1:offset])
                    all_buffers = []
                    for _ in range(buf_count):
                        buf_offset = offset + 4
                        buf_len = from_bytes_big(buf[offset:buf_offset])
                        offset = buf_offset + buf_len
                        all_buffers.append(buf[buf_offset:offset])
                    # The chunk goes back to the writer as soon as this `with`
                    # exits. Unpickled objects may keep references to the
                    # out-of-band buffers they were given (e.g. zero-copy
                    # arrays), so hand pickle private copies. Otherwise the
                    # returned object would alias ring memory that a later
                    # enqueue overwrites.
                    oob_buffers = [bytearray(view) for view in all_buffers[1:]]
                    obj = pickle.loads(all_buffers[0], buffers=oob_buffers)
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: float | None) -> Any:
        """Receive and unpickle one multipart message from ``socket``.

        Raises TimeoutError if nothing arrives within ``timeout`` seconds;
        ``timeout=None`` waits indefinitely.
        """
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv, *recv_oob = socket.recv_multipart(copy=False)
        return pickle.loads(recv, buffers=recv_oob)

    def broadcast_object(self, obj=None):
        """Writer: enqueue ``obj`` and return it. Reader: return the next message."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        return self.dequeue()

    @staticmethod
    def create_from_process_group_single_reader(
        pg: ProcessGroup,
        max_chunk_bytes,
        max_chunks,
        reader_rank: int = 0,
        blocking: bool = False,
    ) -> tuple["MessageQueue", list[Handle]]:
        """
        Creates a MessageQueue for a process group with a single reader.

        This method is designed for scenarios where only one process (the reader)
        will consume messages, and all other processes are writers. It sets up
        the shared memory buffer and communication handles accordingly, and
        gathers the handles from all processes to the reader.

        Args:
            pg (ProcessGroup): The torch distributed process group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            reader_rank (int, optional): The global rank that will act as the reader.
                Defaults to 0.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to False.

        Returns:
            tuple[MessageQueue, list[Handle]]:
            The MessageQueue instance for the calling process,
            and a list of handles (only non-empty for the reader process).
        """
        local_size = current_platform.device_count()
        rank = dist.get_rank()
        same_node = rank // local_size == reader_rank // local_size
        buffer_io = MessageQueue(
            n_reader=1,
            n_local_reader=1 if same_node else 0,
            max_chunk_bytes=max_chunk_bytes,
            max_chunks=max_chunks,
        )
        handle = buffer_io.export_handle()
        handles = [None] * dist.get_world_size(pg) if rank == reader_rank else None
        dist.gather_object(handle, handles, dst=reader_rank, group=pg)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io, cast(list[Handle], handles or [])

    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup | StatelessProcessGroup,
        max_chunk_bytes,
        max_chunks,
        writer_rank: int = 0,
        external_writer_handle=None,
        blocking: bool = True,
    ) -> "MessageQueue":
        """
        Creates a MessageQueue for a distributed process group with one writer and
        multiple readers.

        This method is designed for scenarios where one process (the writer) sends
        messages, and all other processes (the readers) receive messages. It sets up
        the shared memory buffer and socket communication handles accordingly, and
        broadcasts the handle from the writer to all readers.

        Args:
            pg (ProcessGroup | StatelessProcessGroup): The torch distributed process
                group.
            max_chunk_bytes (int): Maximum size in bytes for each chunk in the buffer.
            max_chunks (int): Maximum number of chunks in the buffer.
            writer_rank (int, optional): The global rank that will act as the writer.
                Defaults to 0.
            external_writer_handle (Handle, optional): Used when there is a handle
                from an external Message Queue. If provided, use this handle to init
                PG writer message queue instead of creating a new one. Defaults to None.
            blocking (bool, optional): If True, blocks until all processes are ready.
                Defaults to True.

        Returns:
            MessageQueue: The MessageQueue instance for the calling process.

        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))
        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        if group_rank == writer_rank:
            if external_writer_handle is not None:
                buffer_io = MessageQueue.create_from_handle(
                    external_writer_handle, group_rank
                )
            else:
                same_node_ranks = [i for i, s in enumerate(status) if s]
                n_reader = group_world_size - 1
                n_local_reader = len(same_node_ranks) - 1
                local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
                buffer_io = MessageQueue(
                    n_reader=n_reader,
                    n_local_reader=n_local_reader,
                    local_reader_ranks=local_reader_ranks,
                    max_chunk_bytes=max_chunk_bytes,
                    max_chunks=max_chunks,
                )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(
                    recv, src=global_ranks[writer_rank], group=pg
                )
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        if blocking:
            buffer_io.wait_until_ready()
        return buffer_io