shm_broadcast.py 23 KB
Newer Older
1
2
# SPDX-License-Identifier: Apache-2.0

3
4
5
import pickle
import time
from contextlib import contextmanager
6
from dataclasses import dataclass, field
7
from multiprocessing import shared_memory
8
from threading import Event
9
from typing import Any, Optional, Union
10
11
12
13
from unittest.mock import patch

import torch
import torch.distributed as dist
14
import zmq
15
from torch.distributed import ProcessGroup
16
from zmq import IPV6  # type: ignore
17
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
18
19

import vllm.envs as envs
20
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
21
from vllm.logger import init_logger
22
23
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                        is_valid_ipv6_address)
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60

logger = init_logger(__name__)

# Seconds between repeated "still waiting" debug messages while a reader or
# writer spins on the shared-memory ring buffer (read once at import time).
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.

        Args:
            n_reader: number of readers; each chunk carries one flag byte per
                reader in addition to the written flag.
            max_chunk_bytes: size in bytes of every data chunk.
            max_chunks: number of chunks in the ring.
            name: name of an existing shared memory segment to attach to, or
                None to create (and later unlink) a new segment.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        # all data chunks come first, the metadata section follows them
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
                    assert (self.shared_memory.size
                            >= self.total_bytes_of_buffer)
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    # (`self.shared_memory` stays unset; `__del__` tolerates
                    # that via its hasattr check)
                    pass

    def handle(self):
        """Return the constructor arguments needed to attach to this buffer:
        ``(n_reader, max_chunk_bytes, max_chunks, shared_memory_name)``."""
        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
                self.shared_memory.name)

    def __reduce__(self):
        # Pickling only transfers the handle; unpickling re-runs __init__
        # with `name` set, so the receiver attaches instead of creating.
        return (
            self.__class__,
            self.handle(),
        )

    def __del__(self):
        # Detach from the segment; only the creating process unlinks it.
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over data chunk `current_idx`.

        The view covers exactly `max_chunk_bytes` bytes and is released when
        the context exits.
        """
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the flag bytes of chunk `current_idx`.

        Byte 0 is the written flag, bytes 1..n_reader are the reader flags.
        The view is released when the context exits.
        """
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


158
159
@dataclass
class Handle:
    """Connection information exported by a `MessageQueue` writer.

    Readers pass it to `MessageQueue.create_from_handle` to attach either to
    the shared-memory ring buffer plus local socket, or to the remote socket.
    """
    # group ranks of the readers that read through shared memory
    local_reader_ranks: list[int] = field(default_factory=list)

    # arguments for ShmRingBuffer(*buffer_handle):
    # (n_reader, max_chunk_bytes, max_chunks, shm name); None if no local reader
    buffer_handle: Optional[tuple[int, int, int, str]] = None

    # ZMQ address local readers connect to (None if no local reader)
    local_subscribe_addr: Optional[str] = None
    # ZMQ TCP address remote readers connect to (None if no remote reader)
    remote_subscribe_addr: Optional[str] = None
    # True when remote_subscribe_addr holds an IPv6 address
    remote_addr_ipv6: bool = False


class MessageQueue:
    """Broadcast queue with a single writer and many readers.

    Readers on the writer's node ("local" readers) receive messages through a
    `ShmRingBuffer`; messages too large for one chunk are sent over a ZMQ
    XPUB/SUB socket bound to an IPC address instead. Readers elsewhere
    ("remote" readers) receive every message over a ZMQ XPUB/SUB TCP socket.

    SECURITY NOTE(review): messages are serialized with `pickle`, and readers
    call `pickle.loads` on bytes received from the ZMQ sockets. Whoever
    controls (or can impersonate) the publisher at the handle's addresses can
    therefore execute arbitrary code in the readers. Only deploy on trusted,
    isolated networks.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[list[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers using shared memory.
            local_reader_ranks: ranks of the local readers; defaults to
                ``range(n_local_reader)`` and must have that many entries.
            max_chunk_bytes: size of one shared-memory chunk. Pickled objects
                of ``>= max_chunk_bytes`` bytes go over the local socket.
            max_chunks: number of chunks in the ring buffer.
            connect_ip: IP the remote socket binds to; ``get_ip()`` if unset.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                # IPv6 literals must be bracketed inside a tcp:// URL
                connect_ip = f"[{connect_ip}]"
            # the bind address is also the address readers connect to
            remote_subscribe_addr = (
                f"tcp://{connect_ip}:{remote_subscribe_port}")
            self.remote_socket.bind(remote_subscribe_addr)
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle()
            if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the handle readers need to attach to this queue."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create the reader side of a queue from a writer's `Handle`.

        `rank` is looked up in `handle.local_reader_ranks`. If present, the
        reader attaches to the shared-memory buffer and local socket.
        Otherwise it is treated as a remote reader and connects to the TCP
        socket.
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            self.current_idx = 0
            # index of this reader's flag byte (minus one) in each chunk
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer waits for one subscription message per reader on each
        socket, then sends ``b"READY"``; each reader waits for that message.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: Optional[float] = None):
        """Yield the data view of the next writable ring-buffer chunk.

        Spins, calling `sched_yield` between polls, until the chunk at
        `self.current_idx` is unwritten or has been read by every local
        reader. When the caller's ``with`` body finishes, the reader flags are
        cleared, the chunk is marked written, and `current_idx` advances.

        Raises:
            TimeoutError: if more than `timeout` seconds pass while waiting.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.debug(
                            ("No available shared memory broadcast block found"
                             " in %s second."),
                            VLLM_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self,
                     timeout: Optional[float] = None,
                     cancel: Optional[Event] = None):
        """Yield the data view of the next chunk this local reader must read.

        Spins, calling `sched_yield` between polls, until the chunk at
        `self.current_idx` is written and not yet read by this reader. When
        the caller's ``with`` body finishes, this reader's flag is set and
        `current_idx` advances.

        Raises:
            RuntimeError: if `cancel` is set while waiting.
            TimeoutError: if more than `timeout` seconds pass while waiting.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.debug(
                            ("No available shared memory broadcast block found"
                             " in %s second."),
                            VLLM_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj, timeout: Optional[float] = None):
        """ Write to message queue with optional timeout (in seconds) """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            # byte 0 of the chunk is the overflow flag, the payload follows it
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                # payload does not fit: readers fetch it from the socket
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self,
                timeout: Optional[float] = None,
                cancel: Optional[Event] = None):
        """ Read from message queue with optional timeout (in seconds)

        Raises:
            TimeoutError: if no message arrives within `timeout` seconds.
            RuntimeError: if `cancel` is set while a local reader waits, or
                if called on the writer.
        """
        if self._is_local_reader:
            with self.acquire_read(timeout, cancel) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
        """Receive one pickled message from `socket` and unpickle it.

        Raises:
            TimeoutError: if nothing arrives within `timeout` seconds
                (``None`` waits indefinitely).
        """
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv = socket.recv(copy=False)
        # SECURITY NOTE(review): unpickling bytes received over ZMQ (TCP for
        # remote readers) is arbitrary code execution if the sender is not
        # trusted; see the class docstring.
        return pickle.loads(recv.buffer)

    def broadcast_object(self, obj=None):
        """Writer: enqueue `obj` and return it. Readers: return the next
        dequeued object (`obj` is ignored)."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: Union[ProcessGroup,
                                            StatelessProcessGroup],
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        """Build a queue across every rank of `pg` (a collective call).

        Ranks on the writer's node become local (shared-memory) readers and
        all others become remote readers. The writer creates the queue and
        broadcasts its `Handle`; every rank then calls `wait_until_ready`.
        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                # `src` of broadcast_object_list is a global rank
                dist.broadcast_object_list([handle],
                                           src=global_ranks[writer_rank],
                                           group=pg)
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(recv,
                                           src=global_ranks[writer_rank],
                                           group=pg)
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io