"tests/entrypoints/openai/test_vision.py" did not exist on "2061f0b8a7f1a01683c4045096a092eedf6387a4"
shm_broadcast.py 20.5 KB
Newer Older
1
2
3
import pickle
import time
from contextlib import contextmanager
from dataclasses import dataclass, field
from multiprocessing import shared_memory
from typing import List, Optional
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from zmq import (PUB, REP, REQ, SUB, SUBSCRIBE, XPUB,  # type: ignore
                 XPUB_VERBOSE, Context)

import vllm.envs as envs
from vllm.logger import init_logger
from vllm.utils import get_ip, get_open_port
17
18
19

# How many seconds a writer/reader may spin on the ring buffer between two
# consecutive "No available block found" warnings (taken from vLLM's env
# configuration).
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

# Back-off, in seconds, between two polls of the ring buffer while it is full
# (writer side) or empty (reader side). Sleeping too little wastes CPU,
# sleeping too much delays the writer/reader; 0.1 microseconds is the chosen
# compromise. (time.sleep may suspend for longer than requested.)
RINGBUFFER_SLEEP_INTERVAL = 0.1e-6

logger = init_logger(__name__)


class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.

        When opening an existing buffer, the segment may be larger than requested
        (some platforms round shared memory up to the page size), so only a lower
        bound on its size is checked. If no segment with that name exists (e.g. the
        object was unpickled on another node), the error is suppressed and the
        object is left without a `shared_memory` attribute; it must not be used.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # Some platforms allocate shared memory in whole pages,
                    # so the attached block may be larger than (but never
                    # smaller than) the size the creator requested.
                    assert (self.shared_memory.size
                            >= self.total_bytes_of_buffer)
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def __reduce__(self):
        """Pickle as (constructor args + shared memory name) so that
        unpickling attaches to the same segment instead of creating one."""
        return (
            self.__class__,
            (self.n_reader, self.max_chunk_bytes, self.max_chunks,
             self.shared_memory.name),
        )

    def __del__(self):
        """Close this process's mapping; the creator also unlinks it."""
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over data chunk `current_idx`
        (`max_chunk_bytes` bytes); the view is released on exit."""
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the `metadata_size` flag bytes of block
        `current_idx` (written flag, then one flag per reader)."""
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


149
150
151
152
@dataclass
class Handle:
    """Connection information a reader needs to attach to a MessageQueue.

    The writer fills it in and readers are built from it. Every port is
    ``None`` when the writer has no reader of the corresponding kind.
    """
    # address of the writer, as advertised to readers
    connect_ip: str
    # ranks that are served through the shared memory ring buffer
    local_reader_ranks: List[int] = field(default_factory=list)

    # ring buffer shared with local readers (pickled by its shm name)
    buffer: Optional[ShmRingBuffer] = field(default=None)
    # publish / request-reply ports used by local readers
    local_subscribe_port: Optional[int] = field(default=None)
    local_sync_port: Optional[int] = field(default=None)
    # publish / request-reply ports used by remote readers
    remote_subscribe_port: Optional[int] = field(default=None)
    remote_sync_port: Optional[int] = field(default=None)


class MessageQueue:
    """Broadcast queue with a single writer and ``n_reader`` readers.

    Local readers (the ranks in ``local_reader_ranks``) receive objects
    through a shared memory ``ShmRingBuffer``. An object whose pickled form
    does not fit in one chunk is flagged as overflow in the ring buffer and
    sent over a zmq publish socket instead. All other readers (remote readers)
    receive every object over a separate zmq publish socket.

    The writer is built with ``MessageQueue(...)``; readers are built with
    ``create_from_handle`` from the writer's ``export_handle()``. All of them
    must then call ``wait_until_ready``.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer end of the queue and bind its sockets.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers using the shared memory buffer.
            local_reader_ranks: ranks ``create_from_handle`` treats as local
                readers; defaults to ``range(n_local_reader)``, otherwise it
                must have exactly ``n_local_reader`` entries.
            max_chunk_bytes: bytes per ring buffer chunk (the first byte of a
                chunk is the overflow flag).
            max_chunks: number of chunks in the ring buffer.
            connect_ip: address remote readers connect to. Defaults to this
                host's IP if there are remote readers, else "127.0.0.1".
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            # Local readers map the same shared memory, so they run on this
            # node: listen on loopback only. Binding on every interface would
            # let any host subscribe to the broadcast data.
            self.local_socket = self._make_publisher(context)
            local_subscribe_port = get_open_port()
            self.local_socket.bind(f"tcp://127.0.0.1:{local_subscribe_port}")

            self.local_sync_socket = context.socket(REP)
            local_sync_port = get_open_port()
            self.local_sync_socket.bind(f"tcp://127.0.0.1:{local_sync_port}")
            self.current_idx = 0

        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            local_sync_port = None
            self.local_socket = None
            self.local_sync_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = self._make_publisher(context)
            remote_subscribe_port = get_open_port()
            self.remote_socket.bind(f"tcp://*:{remote_subscribe_port}")

            self.remote_sync_socket = context.socket(REP)
            remote_sync_port = get_open_port()
            self.remote_sync_socket.bind(f"tcp://*:{remote_sync_port}")
        else:
            remote_subscribe_port = None
            remote_sync_port = None
            self.remote_socket = None
            self.remote_sync_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer=self.buffer,
            local_subscribe_port=local_subscribe_port,
            local_sync_port=local_sync_port,
            remote_subscribe_port=remote_subscribe_port,
            remote_sync_port=remote_sync_port,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    @staticmethod
    def _make_publisher(context):
        """Return a publish socket that also reports subscriptions.

        An XPUB socket sends exactly like PUB, but ``recv()`` on it returns
        the subscription messages of connecting subscribers. XPUB_VERBOSE
        makes it report duplicate subscriptions as well, so there is one
        message per reader even though they all subscribe to the same (empty)
        topic. ``wait_until_ready`` relies on this.
        """
        sock = context.socket(XPUB)
        sock.setsockopt(XPUB_VERBOSE, True)
        return sock

    def export_handle(self) -> Handle:
        """Return the handle readers pass to ``create_from_handle``."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create a reader end of the queue from the writer's handle.

        If ``rank`` is in ``handle.local_reader_ranks`` the reader uses the
        shared memory buffer (plus the local socket for oversized objects);
        otherwise it is a remote reader connecting to ``handle.connect_ip``.
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer is not None
            self.buffer = handle.buffer
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            # local readers are on the writer's node; the writer binds the
            # local sockets on loopback
            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            self.local_socket.connect(
                f"tcp://127.0.0.1:{handle.local_subscribe_port}")

            self.local_sync_socket = context.socket(REQ)
            self.local_sync_socket.connect(
                f"tcp://127.0.0.1:{handle.local_sync_port}")

            self.remote_socket = None
            self.remote_sync_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None
            self.local_sync_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            self.remote_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_subscribe_port}")

            self.remote_sync_socket = context.socket(REQ)
            self.remote_sync_socket.connect(
                f"tcp://{handle.connect_ip}:{handle.remote_sync_port}")

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer returns after every reader has finished the sync
        request/reply exchange and its subscription has reached the
        publisher. Each reader returns after it receives the writer's
        ``b"READY"`` broadcast.
        """
        if self._is_writer:
            # wait for all readers to connect
            if self.n_local_reader > 0:
                self._writer_handshake(self.local_sync_socket,
                                       self.local_socket, self.n_local_reader)
            if self.n_remote_reader > 0:
                self._writer_handshake(self.remote_sync_socket,
                                       self.remote_socket,
                                       self.n_remote_reader)
        elif self._is_local_reader:
            self._reader_handshake(self.local_sync_socket, self.local_socket)
        elif self._is_remote_reader:
            self._reader_handshake(self.remote_sync_socket,
                                   self.remote_socket)

    @staticmethod
    def _writer_handshake(sync_socket, pub_socket, n_readers):
        """Writer side of the startup handshake for one group of readers."""
        for _ in range(n_readers):
            recv = sync_socket.recv()
            assert recv == b"READY"
            sync_socket.send(b"READY")
        # A publish socket silently drops messages for subscribers whose
        # subscription has not reached it yet, and completing the sync
        # exchange does not guarantee that it has (the subscriber socket
        # connects independently). Wait for one subscription message per
        # reader before broadcasting, otherwise a reader may wait forever.
        for _ in range(n_readers):
            pub_socket.recv()
        pub_socket.send(b"READY")

    @staticmethod
    def _reader_handshake(sync_socket, sub_socket):
        """Reader side of the startup handshake."""
        sync_socket.send(b"READY")
        recv = sync_socket.recv()
        assert recv == b"READY"
        recv = sub_socket.recv()
        assert recv == b"READY"

    @contextmanager
    def acquire_write(self):
        """Yield the data chunk of the next writable ring buffer block.

        Polls the block at ``self.current_idx`` (sleeping
        ``RINGBUFFER_SLEEP_INTERVAL`` between polls) until it is unwritten or
        read by all local readers. While waiting, it logs a warning every
        ``VLLM_RINGBUFFER_WARNING_INTERVAL`` seconds. Once the caller's
        ``with`` body completes, the block is marked written for all readers
        and ``current_idx`` advances.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if (time.monotonic() - start_time >
                            VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.warning(
                            "No available block found in %s second. ",
                            VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self):
        """Yield the data chunk of the next block this local reader can read.

        Polls the block at ``self.current_idx`` until it is written and not
        yet read by this reader. The wait and warning behavior is the same as
        in ``acquire_write``. Once the caller's ``with`` body completes, this
        reader's flag is set and ``current_idx`` advances.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # wait for a while
                    time.sleep(RINGBUFFER_SLEEP_INTERVAL)

                    # if we wait for a long time, we should warn the user
                    if (time.monotonic() - start_time >
                            VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.warning(
                            "No available block found in %s second. ",
                            VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj):
        """Pickle ``obj`` and send it to every reader (writer only).

        For local readers, an object that fits in one chunk (after the 1-byte
        overflow flag) is stored in the ring buffer. Otherwise, a block with
        the overflow flag set is written and the payload goes over the local
        socket. Remote readers always get the payload over the remote socket.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write() as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write() as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self):
        """Receive and unpickle the next object (readers only).

        NOTE(security): this unpickles whatever the writer sent, so the
        writer (and the network path to it) must be trusted.
        """
        if self._is_local_reader:
            with self.acquire_read() as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    def broadcast_object(self, obj=None):
        """Writer: enqueue ``obj`` and return it. Reader: ignore ``obj`` and
        return the next object from the queue."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: ProcessGroup,
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        """Collectively build a queue over ``pg`` with ``writer_rank`` as the
        writer.

        Must be called by every rank of ``pg``. ``writer_rank`` is a rank
        within ``pg``. Ranks reported as on the writer's node by
        ``in_the_same_node_as`` become local readers. The writer's handle is
        sent with ``dist.broadcast_object_list``, and every rank returns after
        ``wait_until_ready``.
        """
        group_rank = dist.get_rank(pg)
        group_world_size = dist.get_world_size(pg)
        global_ranks = dist.get_process_group_ranks(pg)

        # imported here rather than at module level (presumably to avoid a
        # circular import with parallel_state)
        from vllm.distributed.parallel_state import in_the_same_node_as
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            dist.broadcast_object_list([handle],
                                       src=global_ranks[writer_rank],
                                       group=pg)
        else:
            recv = [None]
            dist.broadcast_object_list(recv,
                                       src=global_ranks[writer_rank],
                                       group=pg)
            handle = recv[0]  # type: ignore
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io