shm_broadcast.py 26.2 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
import functools
4
5
6
import pickle
import time
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from multiprocessing import shared_memory
9
from pickle import PickleBuffer
10
from threading import Event
11
from typing import TYPE_CHECKING, Any
12
13
14
15
from unittest.mock import patch

import torch
import torch.distributed as dist
16
import zmq
17
from torch.distributed import ProcessGroup
18
19
20
21
22
23
24
25
from zmq import (  # type: ignore
    IPV6,  # type: ignore
    SUB,
    SUBSCRIBE,
    XPUB,
    XPUB_VERBOSE,
    Context,
)
26
27

import vllm.envs as envs
28
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
29
from vllm.logger import init_logger
30
from vllm.utils.network_utils import (
31
32
33
34
35
    get_ip,
    get_open_port,
    get_open_zmq_ipc_path,
    is_valid_ipv6_address,
)
36

37
38
39
if TYPE_CHECKING:
    from _typeshed import SizedBuffer

40
41
# Interval, in seconds, between the "No available shared memory broadcast
# block found" log messages emitted while waiting on the ring buffer.
# Read once from ``vllm.envs`` when this module is imported.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

42
43
44
45
46
47
48
# Decode a big-endian byte sequence into an int: ``int.from_bytes`` with the
# byte order pre-bound.
from_bytes_big = functools.partial(int.from_bytes, byteorder="big")


def to_bytes_big(value: int, size: int) -> bytes:
    """Encode ``value`` as exactly ``size`` big-endian bytes."""
    return value.to_bytes(size, "big")


49
50
51
# Module-level logger, obtained through vLLM's ``init_logger`` helper.
logger = init_logger(__name__)


52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class SpinTimer:
    """Default wait strategy for readers: yield on every spin, track nothing."""

    def record_activity(self):
        """No-op hook; this timer keeps no activity state."""
        return None

    def spin(self):
        """Perform one wait step by releasing the processor to other threads."""
        sched_yield()


class SpinSleepTimer(SpinTimer):
    """
    Wait strategy that backs off to sleeping after a period of inactivity.

    Setups with long idle periods benefit from lower power consumption while
    vllm has nothing to do: it leaves more CPU thermal headroom for when a
    request eventually arrives, especially with multiple GPUs, where each GPU
    would otherwise pin one thread at 100% CPU usage.

    The approach is simply to poll less often once no activity has been
    recorded for ``busy_loop_s`` seconds: each spin then sleeps for
    ``wait_sleep_s`` seconds instead of yielding.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        # Monotonic timestamp of the most recent recorded activity.
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Restart the busy-loop window from the current time."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Sleep if the busy-loop window has elapsed, otherwise just yield."""
        now = time.monotonic()
        if now >= self.last_activity + self.busy_loop_s:
            time.sleep(self.wait_sleep_s)
            return
        sched_yield()


88
class ShmRingBuffer:
89
90
91
92
93
    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
94
        name: str | None = None,
95
    ):
96
97
98
99
100
101
102
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
         the buffer.
103

104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

140
141
142
143
        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
144
        """  # noqa
145
146
147
148
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
149
150
151
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
152
153
154
155
156
157
158
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
159
160
                create=True, size=self.total_bytes_of_buffer
            )
161
            # initialize the metadata section to 0
162
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
163
164
165
166
167
168
169
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
170
171
172
173
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
174
175
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
176
177
178
179
180
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
181
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
182
183
184
185
186
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass
187

188
    def handle(self):
189
190
191
192
193
194
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )
195

196
197
198
    def __reduce__(self):
        return (
            self.__class__,
199
            self.handle(),
200
201
202
        )

    def __del__(self):
203
204
205
206
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()
207
208
209
210
211

    @contextmanager
    def get_data(self, current_idx: int):
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
212
        with self.shared_memory.buf[start:end] as buf:
213
214
215
216
217
218
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
219
        with self.shared_memory.buf[start:end] as buf:
220
221
222
            yield buf


223
224
@dataclass
class Handle:
225
    local_reader_ranks: list[int] = field(default_factory=list)
226

227
228
229
    buffer_handle: tuple[int, int, int, str] | None = None
    local_subscribe_addr: str | None = None
    remote_subscribe_addr: str | None = None
230
    remote_addr_ipv6: bool = False
231
232
233
234
235
236
237


class MessageQueue:
    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: list[int] | None = None,
        # Default of 24MiB chosen to be large enough to accommodate grammar
        # bitmask tensors for large batches (1024 requests).
        max_chunk_bytes: int = 1024 * 1024 * 24,
        max_chunks: int = 10,
        connect_ip: str | None = None,
    ):
        """Create the writer end of a message queue.

        Local readers are served through a shared memory ring buffer for
        small data, plus an XPUB socket for large data. Remote readers are
        served through a second XPUB socket bound to a TCP address. The
        resulting endpoints are collected in ``self.handle``.

        Args:
            n_reader: Total number of readers (local and remote).
            n_local_reader: Number of readers that use shared memory.
            local_reader_ranks: Ranks of the local readers; defaults to
                ``range(n_local_reader)``. Its length must equal
                ``n_local_reader``.
            max_chunk_bytes: Size in bytes of each ring buffer chunk.
            max_chunks: Number of chunks in the ring buffer.
            connect_ip: IP used for the remote socket address; when empty or
                None, ``get_ip()`` is used.
        """
        if local_reader_ranks is not None:
            assert len(local_reader_ranks) == n_local_reader
        else:
            local_reader_ranks = list(range(n_local_reader))

        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        # ---- local readers ------------------------------------------------
        # State used when there are no local readers.
        self.buffer = None  # type: ignore
        self.local_socket = None
        self.current_idx = -1
        local_subscribe_addr = None
        if n_local_reader > 0:
            # For local readers we:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB, except that it can receive
            # subscription messages, which lets us confirm the number of
            # subscribers.
            self.local_socket = context.socket(XPUB)
            # Verbose mode delivers every subscription message; otherwise
            # only the first subscription would be received.
            # See http://api.zeromq.org/3-3:zmq-setsockopt for more details.
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0

        # ---- remote readers -----------------------------------------------
        # State used when there are no remote readers.
        self.remote_socket = None
        remote_subscribe_addr = None
        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # For remote readers we create a publish-subscribe socket to
            # communicate large data.
            connect_ip = connect_ip or get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                connect_ip = f"[{connect_ip}]"
            # The bind address is also the address readers connect to.
            remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(remote_subscribe_addr)

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False
        self._read_spin_timer = SpinTimer()

        buffer_handle = None if self.buffer is None else self.buffer.handle()
        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=buffer_handle,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

317
318
319
320
321
322
323
324
325
326
327
328
    def export_handle(self) -> Handle:
        """Return the ``Handle`` describing this queue's endpoints.

        Readers pass it to ``MessageQueue.create_from_handle`` to attach.
        """
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Build the reader end of a queue from a writer's exported handle.

        A ``rank`` listed in ``handle.local_reader_ranks`` becomes a local
        reader: it attaches to the shared memory ring buffer and connects a
        SUB socket to ``handle.local_subscribe_addr``. Any other rank becomes
        a remote reader that only connects a SUB socket to
        ``handle.remote_subscribe_addr``.

        Args:
            handle: Handle exported by the writer.
            rank: Rank of the calling reader.

        Returns:
            A ``MessageQueue`` created without running ``__init__``.
        """
        mq = MessageQueue.__new__(MessageQueue)
        mq.handle = handle
        mq._is_writer = False

        context = Context()

        if rank not in handle.local_reader_ranks:
            # Remote reader: no shared memory, socket only.
            mq.buffer = None  # type: ignore
            mq.current_idx = -1
            mq.local_reader_rank = -1
            mq._is_local_reader = False
            mq._is_remote_reader = True

            mq.local_socket = None

            mq.remote_socket = context.socket(SUB)
            mq.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                mq.remote_socket.setsockopt(IPV6, 1)
            remote_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", remote_addr)
            mq.remote_socket.connect(remote_addr)
            return mq

        # Local reader: attach to the writer's ring buffer.
        assert handle.buffer_handle is not None
        mq.buffer = ShmRingBuffer(*handle.buffer_handle)
        mq.current_idx = 0
        # Position in the list selects this reader's flag byte.
        mq.local_reader_rank = handle.local_reader_ranks.index(rank)
        mq._is_local_reader = True
        mq._is_remote_reader = False

        mq.local_socket = context.socket(SUB)
        mq.local_socket.setsockopt_string(SUBSCRIBE, "")
        local_addr = handle.local_subscribe_addr
        logger.debug("Connecting to %s", local_addr)
        mq.local_socket.connect(local_addr)

        mq.remote_socket = None

        if envs.VLLM_SLEEP_WHEN_IDLE:
            mq._read_spin_timer = SpinSleepTimer()
        else:
            mq._read_spin_timer = SpinTimer()

        return mq

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer waits for one subscription message per reader on each of
        its sockets and then publishes ``b"READY"`` on every socket that has
        readers. Each reader blocks until it receives that message.
        """
        if self._is_writer:
            # Local readers first, then remote readers.
            for sock, n_subscribers in (
                (self.local_socket, self.n_local_reader),
                (self.remote_socket, self.n_remote_reader),
            ):
                # wait for subscription messages from all of these readers
                for _ in range(n_subscribers):
                    sock.recv()
                if n_subscribers > 0:
                    # send a message to all of these readers
                    # to make sure the publish channel is working
                    sock.send(b"READY")
            return

        if self._is_local_reader:
            ready_socket = self.local_socket
        elif self._is_remote_reader:
            ready_socket = self.remote_socket
        else:
            return

        # wait for the writer to send a message
        recv = ready_socket.recv()
        assert recv == b"READY"
398
399

    @contextmanager
    def acquire_write(self, timeout: float | None = None):
        """Yield the data view of the next writable chunk, then publish it.

        Polls the chunk at ``self.current_idx`` until it is either unwritten
        or has been read by every reader, yielding the processor between
        polls. It then clears the written flag and yields the chunk's data
        view. When the ``with`` body completes, all read flags are reset, the
        written flag is set, and ``self.current_idx`` advances to the next
        chunk. If the body raises, the chunk stays unpublished and the index
        does not advance.

        Args:
            timeout: Maximum number of seconds to wait for a writable chunk;
                None waits without limit.

        Raises:
            TimeoutError: If no chunk became writable within ``timeout``.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message that reports
                    # the threshold actually crossed (interval x warning
                    # count) rather than the bare interval every time
                    warn_after_s = VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    if elapsed > warn_after_s:
                        logger.info(
                            "No available shared memory broadcast block found"
                            " in %s seconds. This typically happens when some"
                            " processes are hanging or doing some"
                            " time-consuming work (e.g. compilation)",
                            warn_after_s,
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Yield the data view of the next unread chunk, then mark it read.

        Polls the chunk at ``self.current_idx`` until it is written and not
        yet read by this reader, waiting between polls through
        ``self._read_spin_timer``. It then yields the chunk's data view. When
        the ``with`` body completes, this reader's flag is set,
        ``self.current_idx`` advances to the next chunk, and activity is
        recorded on the spin timer. If the body raises, the chunk is not
        marked as read and the index does not advance.

        Args:
            timeout: Maximum number of seconds to wait for a readable chunk;
                None waits without limit.
            cancel: Optional event; once set, the wait is aborted.
            indefinite: When True, the periodic long-wait log message is
                suppressed.

        Raises:
            RuntimeError: If ``cancel`` is set while waiting.
            TimeoutError: If no chunk became readable within ``timeout``.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    self._read_spin_timer.spin()

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message that reports
                    # the threshold actually crossed (interval x warning
                    # count) rather than the bare interval every time
                    warn_after_s = VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    if not indefinite and elapsed > warn_after_s:
                        logger.info(
                            "No available shared memory broadcast block found"
                            " in %s seconds. This typically happens when some"
                            " processes are hanging or doing some"
                            " time-consuming work (e.g. compilation).",
                            warn_after_s,
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._read_spin_timer.record_activity()
                break

518
    def enqueue(self, obj, timeout: float | None = None):
        """Write to message queue with optional timeout (in seconds)

        ``obj`` is pickled with the highest protocol. Pickle buffers of 1MiB
        or more are kept out-of-band; smaller ones are in-lined. For local
        readers the message goes into one ring buffer chunk when it fits;
        otherwise the chunk only carries an overflow marker and the buffers
        are sent over the local socket. Remote readers always receive the
        buffers over the remote socket.
        """
        assert self._is_writer, "Only writers can enqueue"

        oob_buffers: list[SizedBuffer] = []
        # 2 bytes for oob buffer count, 4 for main buffer size
        framing_and_oob_bytes = 6

        def keep_in_band(pickle_buf: PickleBuffer) -> bool:
            nonlocal framing_and_oob_bytes
            raw = pickle_buf.raw()
            if len(raw) < 1024 * 1024:
                # In-line buffers smaller than 1MiB.
                return True
            oob_buffers.append(raw)
            # Each out-of-band buffer also costs a 4-byte length prefix.
            framing_and_oob_bytes += 4 + len(raw)
            return False

        main_buffer = pickle.dumps(
            obj, protocol=pickle.HIGHEST_PROTOCOL, buffer_callback=keep_in_band
        )
        all_buffers: list[SizedBuffer] = [main_buffer, *oob_buffers]

        if self.n_local_reader > 0:
            needed_bytes = framing_and_oob_bytes + len(main_buffer)
            if needed_bytes < self.buffer.max_chunk_bytes:
                # Byte 0: 0
                # Bytes 1-2: Count of buffers
                # Then each buffer follows, preceded by 4 bytes containing
                # its length: [4 byte int L][L bytes of buffer content] ...
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    buf[1:3] = to_bytes_big(len(all_buffers), 2)  # oob buf count
                    offset = 3
                    for buffer in all_buffers:
                        buf_len = len(buffer)
                        # prepend each buffer with 4 bytes containing its size.
                        payload_start = offset + 4
                        payload_end = payload_start + buf_len
                        buf[offset:payload_start] = to_bytes_big(buf_len, 4)
                        buf[payload_start:payload_end] = buffer
                        offset = payload_end
            else:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send_multipart(all_buffers, copy=False)

        if self.n_remote_reader > 0:
            self.remote_socket.send_multipart(all_buffers, copy=False)
560

561
562
    def dequeue(
        self,
        timeout: float | None = None,
        cancel: Event | None = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds)

        A local reader takes the next ring buffer chunk. If the chunk is
        flagged as overflow, the message is received from the local socket
        instead. A remote reader always receives from the remote socket.

        Raises:
            RuntimeError: If this queue is neither a local nor a remote
                reader.
        """
        if not self._is_local_reader:
            if not self._is_remote_reader:
                raise RuntimeError("Only readers can dequeue")
            return MessageQueue.recv(self.remote_socket, timeout)

        with self.acquire_read(timeout, cancel, indefinite) as chunk:
            overflow = chunk[0] == 1
            if not overflow:
                # Bytes 1-2 hold the buffer count; each buffer is preceded by
                # a 4-byte length.
                n_buffers = from_bytes_big(chunk[1:3])
                parts = []
                cursor = 3
                for _ in range(n_buffers):
                    payload_start = cursor + 4
                    payload_len = from_bytes_big(chunk[cursor:payload_start])
                    cursor = payload_start + payload_len
                    parts.append(chunk[payload_start:cursor])
                # Unpickle while the chunk is still held for reading.
                obj = pickle.loads(parts[0], buffers=parts[1:])
        if overflow:
            obj = MessageQueue.recv(self.local_socket, timeout)
        return obj

589
    @staticmethod
    def recv(socket: zmq.Socket, timeout: float | None) -> Any:
        """Receive one multipart message from ``socket`` and unpickle it.

        The first frame is the pickle stream; any further frames are passed
        to ``pickle.loads`` as out-of-band buffers.

        Args:
            socket: Socket to poll and receive from.
            timeout: Seconds to wait for a message (converted to whole
                milliseconds); None is passed through to ``poll`` unchanged.

        Raises:
            TimeoutError: If ``poll`` reports no message within the timeout.
        """
        if timeout is None:
            timeout_ms = None
        else:
            timeout_ms = int(timeout * 1000)

        message_ready = socket.poll(timeout=timeout_ms)
        if not message_ready:
            raise TimeoutError

        main_frame, *oob_frames = socket.recv_multipart(copy=False)
        return pickle.loads(main_frame, buffers=oob_frames)
596

597
598
599
600
    def broadcast_object(self, obj=None):
        """Broadcast ``obj`` from the writer and return it on every caller.

        The writer enqueues ``obj`` and returns it; every other caller
        ignores its own ``obj`` argument and returns the dequeued object.
        """
        if not self._is_writer:
            return self.dequeue()
        self.enqueue(obj)
        return obj
602

603
    @staticmethod
    def create_from_process_group(
        pg: ProcessGroup | StatelessProcessGroup,
        max_chunk_bytes,
        max_chunks,
        writer_rank=0,
    ) -> "MessageQueue":
        """Create a connected message queue across all ranks of ``pg``.

        Every rank of the group must call this. The writer rank constructs
        the queue and broadcasts its handle through ``pg``; all other ranks
        build a reader from the received handle. Ranks reported by
        ``in_the_same_node_as`` as sharing the writer's node become local
        readers. The call returns once ``wait_until_ready`` has completed.

        Args:
            pg: Process group used to exchange the handle.
            max_chunk_bytes: Size in bytes of each ring buffer chunk.
            max_chunks: Number of chunks in the ring buffer.
            writer_rank: Group rank of the writer.

        Returns:
            The writer or reader ``MessageQueue`` for the calling rank.
        """
        uses_torch_pg = isinstance(pg, ProcessGroup)
        if uses_torch_pg:
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [
            rank for rank, on_same_node in enumerate(status) if on_same_node
        ]
        # Everyone except the writer is a reader.
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [rank for rank in same_node_ranks if rank != writer_rank]

        buffer_io: MessageQueue
        if group_rank != writer_rank:
            # Reader: receive the writer's handle and attach to it.
            if uses_torch_pg:
                received = [None]
                dist.broadcast_object_list(
                    received, src=global_ranks[writer_rank], group=pg
                )
                handle = received[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        else:
            # Writer: create the queue and publish its handle to the group.
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if uses_torch_pg:
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)

        buffer_io.wait_until_ready()
        return buffer_io