shm_broadcast.py 26 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
6
import pickle
import time
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from multiprocessing import shared_memory
9
from pickle import PickleBuffer
10
from threading import Event
11
from typing import TYPE_CHECKING, Any
12
13
14
15
from unittest.mock import patch

import torch
import torch.distributed as dist
16
import zmq
17
from torch.distributed import ProcessGroup
18
19
20
21
22
23
24
25
from zmq import (  # type: ignore
    IPV6,  # type: ignore
    SUB,
    SUBSCRIBE,
    XPUB,
    XPUB_VERBOSE,
    Context,
)
26
27

import vllm.envs as envs
28
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
29
from vllm.logger import init_logger
30
from vllm.utils.network_utils import (
31
32
33
34
35
    get_ip,
    get_open_port,
    get_open_zmq_ipc_path,
    is_valid_ipv6_address,
)
36

37
38
39
if TYPE_CHECKING:
    from _typeshed import SizedBuffer

40
41
# Number of seconds a reader/writer waits on a ring-buffer block before each
# successive "long wait" log message is emitted (cached from vllm.envs).
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

42
43
44
45
46
47
48
# Decode a big-endian byte sequence into an unsigned int
# (the inverse of `to_bytes_big`).
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` bytes in big-endian order.

    Raises:
        OverflowError: if ``value`` is negative or does not fit in ``size``
            bytes.
    """
    encoded = value.to_bytes(length=size, byteorder="big")
    return encoded


49
50
51
# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)


52
53
54
55
56
57
58
59
60
61
def long_wait_time_msg(threshold: int) -> str:
    """Return the log message used when no ring-buffer block became available.

    Args:
        threshold: number of seconds reported in the message.
    """
    fragments = (
        "No available shared memory broadcast block found ",
        f"in {threshold} seconds. This typically happens ",
        "when some processes are hanging or doing some ",
        "time-consuming work (e.g. compilation, ",
        "weight/kv cache quantization).",
    )
    return "".join(fragments)


62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class SpinTimer:
    """Default wait strategy: always yield the processor, never sleep."""

    def record_activity(self):
        """Do nothing; this strategy keeps no activity state."""
        return None

    def spin(self):
        """Give up the processor once via ``sched_yield``."""
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """
    Wait strategy that backs off to sleeping after a period of inactivity.

    When vllm is idle for long stretches, busy-polling wastes power and CPU
    thermal headroom; with several GPUs attached each one would otherwise pin
    a thread at 100% CPU usage. To avoid that, polling frequency is reduced
    once no activity has been recorded for `busy_loop_s` seconds: each call
    to `spin` then sleeps for `wait_sleep_s` seconds instead of yielding.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        # How long after the last activity we keep busy-polling, and how long
        # each sleep lasts once that window has elapsed.
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Restart the busy-polling window from the current time."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Sleep if the busy-polling window has elapsed, otherwise yield."""
        now = time.monotonic()
        busy_deadline = self.last_activity + self.busy_loop_s
        if now >= busy_deadline:
            time.sleep(self.wait_sleep_s)
            return
        sched_yield()


98
class ShmRingBuffer:
99
100
101
102
103
    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
104
        name: str | None = None,
105
    ):
106
107
108
109
110
111
112
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.
113

114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

150
151
152
153
        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
154
        """  # noqa
155
156
157
158
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
159
160
161
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
162
163
164
165
166
167
168
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
169
170
                create=True, size=self.total_bytes_of_buffer
            )
171
            # initialize the metadata section to 0
172
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
173
174
175
176
177
178
179
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
180
181
182
183
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
184
185
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
186
187
188
189
190
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
191
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
192
193
194
195
196
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass
197

198
    def handle(self):
199
200
201
202
203
204
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )
205

206
207
208
    def __reduce__(self):
        return (
            self.__class__,
209
            self.handle(),
210
211
212
        )

    def __del__(self):
213
214
215
216
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()
217
218
219
220
221

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
222
        with self.shared_memory.buf[start:end] as buf:
223
224
225
226
227
228
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
229
        with self.shared_memory.buf[start:end] as buf:
230
231
232
            yield buf


233
234
@dataclass
class Handle:
235
    local_reader_ranks: list[int] = field(default_factory=list)
236

237
238
239
    buffer_handle: tuple[int, int, int, str] | None = None
    local_subscribe_addr: str | None = None
    remote_subscribe_addr: str | None = None
240
    remote_addr_ipv6: bool = False
241
242
243
244
245
246
247


class MessageQueue:
    """Single-writer, multi-reader broadcast queue.

    Local readers receive messages through a `ShmRingBuffer`; messages that do
    not fit into one chunk are sent over a local ZMQ XPUB/SUB socket instead.
    Remote readers always receive messages over a TCP XPUB/SUB socket.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: list[int] | None = None,
        # Default of 24MiB chosen to be large enough to accommodate grammar
        # bitmask tensors for large batches (1024 requests).
        max_chunk_bytes: int = 1024 * 1024 * 24,
        max_chunks: int = 10,
        connect_ip: str | None = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers using the shared memory buffer.
            local_reader_ranks: ranks of the local readers; defaults to
                ``range(n_local_reader)``.
            max_chunk_bytes: size of each ring-buffer chunk.
            max_chunks: number of chunks in the ring buffer.
            connect_ip: IP advertised to remote readers; when falsy,
                ``get_ip()`` is used.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                connect_ip = f"[{connect_ip}]"
            # the address we bind to is also the one advertised to readers
            remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(remote_subscribe_addr)
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False
        self._read_spin_timer = SpinTimer()

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle() if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.debug("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the `Handle` readers use with `create_from_handle`."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create the reader side of a queue from the writer's handle.

        A rank listed in ``handle.local_reader_ranks`` becomes a local
        (shared memory) reader; any other rank becomes a remote reader.
        """
        queue = MessageQueue.__new__(MessageQueue)
        queue.handle = handle
        queue._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            queue.buffer = ShmRingBuffer(*handle.buffer_handle)
            queue.current_idx = 0
            # position of this rank's read flag within each metadata block
            queue.local_reader_rank = handle.local_reader_ranks.index(rank)
            queue._is_local_reader = True
            queue._is_remote_reader = False

            queue.local_socket = context.socket(SUB)
            queue.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            queue.local_socket.connect(socket_addr)

            queue.remote_socket = None

            queue._read_spin_timer = (
                SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
            )
        else:
            queue.buffer = None  # type: ignore
            queue.current_idx = -1
            queue.local_reader_rank = -1
            queue._is_local_reader = False
            queue._is_remote_reader = True

            queue.local_socket = None

            queue.remote_socket = context.socket(SUB)
            queue.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                queue.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            queue.remote_socket.connect(socket_addr)

        return queue

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: float | None = None):
        """Yield the next writable data chunk, then publish it to readers.

        Args:
            timeout: seconds to wait for a free chunk; None waits forever.

        Raises:
            TimeoutError: if no chunk became writable within `timeout`.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Yield the next unread data chunk, then mark it read by this reader.

        Args:
            timeout: seconds to wait for a written chunk; None waits forever.
            cancel: optional event; when set while waiting, the wait aborts.
            indefinite: when True, suppress the periodic long-wait log message.

        Raises:
            RuntimeError: if `cancel` is set while waiting.
            TimeoutError: if no chunk became readable within `timeout`.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    self._read_spin_timer.spin()

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if not indefinite and (
                        elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.info(
                            long_wait_time_msg(VLLM_RINGBUFFER_WARNING_INTERVAL)
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._read_spin_timer.record_activity()
                break

    def enqueue(self, obj, timeout: float | None = None):
        """Write to message queue with optional timeout (in seconds)"""
        assert self._is_writer, "Only writers can enqueue"

        # all_buffers[0] is the main pickle stream; any following entries are
        # out-of-band buffers collected by `oob_callback`.
        all_buffers: list[SizedBuffer] = [b""]
        total_bytes = 6  # 2 bytes for oob buffer count, 4 for main buffer size

        def oob_callback(buf: PickleBuffer) -> bool:
            raw_buf = buf.raw()
            if len(raw_buf) < 1024 * 1024:
                # In-line buffers smaller than 1MiB.
                return True
            all_buffers.append(raw_buf)
            nonlocal total_bytes
            total_bytes += len(raw_buf) + 4
            return False

        all_buffers[0] = pickle.dumps(
            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=oob_callback
        )
        if self.n_local_reader > 0:
            if total_bytes + len(all_buffers[0]) >= self.buffer.max_chunk_bytes:
                # too large for one chunk: flag the chunk as overflow and send
                # the payload over the local socket instead
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send_multipart(all_buffers, copy=False)
            else:
                # Byte 0: 0
                # Bytes 1-2: Count of buffers
                # Then each buffer follows, preceded by 4 bytes containing its length:
                # [4 byte int L][L bytes of buffer content] ...
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    offset = 3
                    # total buffer count (main pickle stream + oob buffers)
                    buf[1:offset] = to_bytes_big(len(all_buffers), 2)
                    for buffer in all_buffers:
                        buf_len = len(buffer)
                        # prepend each buffer with 4 bytes containing its size.
                        buf_offset = offset + 4
                        buf[offset:buf_offset] = to_bytes_big(buf_len, 4)
                        buf[buf_offset : (offset := buf_offset + buf_len)] = buffer

        if self.n_remote_reader > 0:
            self.remote_socket.send_multipart(all_buffers, copy=False)

    def dequeue(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds)

        `cancel` and `indefinite` only affect the shared memory wait of local
        readers; socket receives use `timeout` alone.

        Raises:
            TimeoutError: if nothing arrived within `timeout`.
            RuntimeError: if cancelled, or if called on a non-reader.
        """
        if self._is_local_reader:
            with self.acquire_read(timeout, cancel, indefinite) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # inverse of the chunk layout written by `enqueue`
                    offset = 3
                    buf_count = from_bytes_big(buf[1:offset])
                    all_buffers = []
                    for _ in range(buf_count):
                        buf_offset = offset + 4
                        buf_len = from_bytes_big(buf[offset:buf_offset])
                        offset = buf_offset + buf_len
                        all_buffers.append(buf[buf_offset:offset])
                    obj = pickle.loads(all_buffers[0], buffers=all_buffers[1:])
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: float | None) -> Any:
        """Receive and unpickle one multipart message from `socket`.

        Args:
            socket: socket to poll and read from.
            timeout: seconds to wait; None blocks until a message arrives.

        Raises:
            TimeoutError: if the poll reports no message.
        """
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv, *recv_oob = socket.recv_multipart(copy=False)
        # NOTE: pickle.loads can execute arbitrary code; the peer on this
        # socket must be trusted.
        return pickle.loads(recv, buffers=recv_oob)

    def broadcast_object(self, obj=None):
        """Writer enqueues and returns `obj`; readers return the dequeued one."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        return self.dequeue()

    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup | StatelessProcessGroup,
        max_chunk_bytes,
        max_chunks,
        writer_rank=0,
    ) -> "MessageQueue":
        """Collectively build a queue across all ranks of `pg`.

        The rank equal to `writer_rank` creates the queue and broadcasts its
        handle; every other rank attaches as a reader. All ranks then call
        `wait_until_ready` before the queue is returned.
        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        # every rank except the writer is a reader
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(
                    recv, src=global_ranks[writer_rank], group=pg
                )
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io