shm_broadcast.py 22.2 KB
Newer Older
1
import os
2
import pickle
3
import sys
4
5
import time
from contextlib import contextmanager
6
from dataclasses import dataclass, field
7
from multiprocessing import shared_memory
8
from typing import List, Optional, Tuple, Union
9
10
11
12
13
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
14
from zmq import IPV6  # type: ignore
15
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
16
17

import vllm.envs as envs
18
from vllm.distributed.utils import StatelessProcessGroup
19
from vllm.logger import init_logger
20
from vllm.utils import get_ip, get_open_port, is_valid_ipv6_address
21
22
23
24
25

# Threshold used by the ring-buffer polling loops in this module: each time
# the elapsed time.monotonic() delta (seconds) exceeds another multiple of
# this value while waiting for a block, a log message is emitted.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

# Module-level logger obtained from vLLM's logging helper.
logger = init_logger(__name__)

26
27
28
29
30
31
32
33
34
35
36
37
38
39
# os.sched_yield() gives the tightest polling loop (measured at roughly
# 3e-7 seconds per iteration), but on older Python releases it does not
# release the GIL. It is therefore only enabled on 3.10.8+ within the 3.10
# series and on 3.11.1 or newer; everything else uses time.sleep(0).
USE_SCHED_YIELD = (sys.version_info[:3] >= (3, 11, 1)
                   or (3, 10, 8) <= sys.version_info[:3] < (3, 11))


def sched_yield():
    """Give up the processor so another thread or process can run.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true and
    ``time.sleep(0)`` otherwise.
    """
    if not USE_SCHED_YIELD:
        time.sleep(0)
        return
    os.sched_yield()

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71

class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer are known in advance.
        In this case, we don't need to synchronize the access to
        the buffer.

        Args:
            n_reader: number of readers; each gets one flag byte per chunk.
            max_chunk_bytes: size in bytes of every data chunk.
            max_chunks: number of chunks (and metadata records) in the ring.
            name: name of an existing shared memory segment to attach to.
                When None, a new segment is created and owned by this object.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here, first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0
            with memoryview(self.shared_memory.buf[self.metadata_offset:]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # Some platforms allocate shared memory in multiples of
                    # the page size, so the attached segment may be larger
                    # than what was requested. Offsets are derived from the
                    # requested sizes, so a larger segment is still valid;
                    # only a smaller one is an error.
                    assert (self.shared_memory.size
                            >= self.total_bytes_of_buffer), (
                                "shared memory segment is smaller than the "
                                "expected ring buffer size")
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def handle(self):
        """Return the constructor arguments needed to attach to this buffer:
        (n_reader, max_chunk_bytes, max_chunks, shared memory name)."""
        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
                self.shared_memory.name)

    def __reduce__(self):
        """Pickle as a call to the constructor with `handle()` so that the
        unpickling process attaches to the same shared memory by name."""
        return (
            self.__class__,
            self.handle(),
        )

    def __del__(self):
        # `shared_memory` is absent when attaching failed with
        # FileNotFoundError in __init__, hence the hasattr guard.
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            # only the creating side removes the segment
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over the data chunk at index `current_idx`
        (exactly `max_chunk_bytes` bytes); the view is released on exit."""
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the flag bytes of chunk `current_idx`
        (written flag followed by one flag per reader); released on exit."""
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


164
165
166
167
@dataclass
class Handle:
    """Picklable description of a MessageQueue's endpoints.

    The writer builds one of these and ships it to the readers, which use it
    to attach to the shared memory buffer and/or connect their sockets.
    """
    # IP address that remote readers connect to
    connect_ip: str
    # ranks that read through shared memory on the writer's node
    local_reader_ranks: List[int] = field(default_factory=list)

    # (n_reader, max_chunk_bytes, max_chunks, shared memory name), or None
    # when there is no shared memory buffer
    buffer_handle: Optional[Tuple[int, int, int, str]] = None
    # port of the writer's socket for local readers, if any
    local_subscribe_port: Optional[int] = None
    # port of the writer's socket for remote readers, if any
    remote_subscribe_port: Optional[int] = None


class MessageQueue:
    """Single-writer, multi-reader broadcast queue.

    Readers on the writer's node ("local readers") receive messages through
    a `ShmRingBuffer`; messages too large for one chunk are sent over a local
    ZeroMQ publish/subscribe socket instead. All other readers ("remote
    readers") always receive messages over a ZeroMQ socket.

    The writer is created with the constructor; readers are created with
    `create_from_handle` from the writer's exported `Handle`.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers served via shared memory.
            local_reader_ranks: ranks of the local readers; defaults to
                `range(n_local_reader)`. Its length must equal
                `n_local_reader`.
            max_chunk_bytes: size of each shared memory chunk.
            max_chunks: number of chunks in the shared memory ring.
            connect_ip: address advertised to remote readers. Defaults to
                `get_ip()` when there are remote readers, otherwise to
                "127.0.0.1".
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        if connect_ip is None:
            connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_port = get_open_port()
            socket_addr = f"tcp://127.0.0.1:{local_subscribe_port}"
            logger.debug("Binding to %s", socket_addr)
            self.local_socket.bind(socket_addr)

            self.current_idx = 0

        else:
            self.buffer = None  # type: ignore
            local_subscribe_port = None
            self.local_socket = None
            self.current_idx = -1

        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = f"tcp://*:{remote_subscribe_port}"
            self.remote_socket.bind(socket_addr)

        else:
            remote_subscribe_port = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            connect_ip=connect_ip,
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle()
            if self.buffer is not None else None,
            local_subscribe_port=local_subscribe_port,
            remote_subscribe_port=remote_subscribe_port,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the `Handle` that readers need to join this queue."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create the reader side of a queue from the writer's handle.

        A rank listed in `handle.local_reader_ranks` becomes a local reader
        (shared memory + local socket); any other rank becomes a remote
        reader (remote socket only).
        """
        mq = MessageQueue.__new__(MessageQueue)
        mq.handle = handle
        mq._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            mq.buffer = ShmRingBuffer(*handle.buffer_handle)
            mq.current_idx = 0
            # position of this rank among the local readers selects its
            # reader flag in the ring buffer metadata
            mq.local_reader_rank = handle.local_reader_ranks.index(rank)
            mq._is_local_reader = True
            mq._is_remote_reader = False

            mq.local_socket = context.socket(SUB)
            mq.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = f"tcp://127.0.0.1:{handle.local_subscribe_port}"
            logger.debug("Connecting to %s", socket_addr)
            mq.local_socket.connect(socket_addr)

            mq.remote_socket = None
        else:
            mq.buffer = None  # type: ignore
            mq.current_idx = -1
            mq.local_reader_rank = -1
            mq._is_local_reader = False
            mq._is_remote_reader = True

            mq.local_socket = None

            mq.remote_socket = context.socket(SUB)
            mq.remote_socket.setsockopt_string(SUBSCRIBE, "")
            connect_ip = handle.connect_ip
            if is_valid_ipv6_address(connect_ip):
                mq.remote_socket.setsockopt(IPV6, 1)
                # IPv6 literals must be bracketed in a ZeroMQ tcp endpoint
                # so the address colons are not confused with the port
                # separator
                connect_ip = f"[{connect_ip}]"
            socket_addr = f"tcp://{connect_ip}:{handle.remote_subscribe_port}"
            logger.debug("Connecting to %s", socket_addr)
            mq.remote_socket.connect(socket_addr)

        return mq

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer waits for one subscription message per reader and then
        publishes b"READY"; each reader blocks until it receives b"READY".
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: Optional[float] = None):
        """Yield the next writable data chunk of the ring buffer.

        Polls until the chunk at `self.current_idx` is either unwritten or
        read by all readers. After the caller's `with` body finishes, the
        chunk is published and `self.current_idx` advances.

        Args:
            timeout: maximum number of seconds to wait; None waits forever.

        Raises:
            TimeoutError: if no chunk became writable within `timeout`.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        # report the total time waited so far, not just
                        # the length of one warning interval
                        logger.debug("No available block found in %s second. ",
                                     VLLM_RINGBUFFER_WARNING_INTERVAL *
                                     n_warning)
                        n_warning += 1

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self, timeout: Optional[float] = None):
        """Yield the next readable data chunk for this local reader.

        Polls until the chunk at `self.current_idx` is written and not yet
        read by this reader. After the caller's `with` body finishes, this
        reader's flag is set and `self.current_idx` advances.

        Args:
            timeout: maximum number of seconds to wait; None waits forever.

        Raises:
            TimeoutError: if no chunk became readable within `timeout`.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        # report the total time waited so far, not just
                        # the length of one warning interval
                        logger.debug("No available block found in %s second. ",
                                     VLLM_RINGBUFFER_WARNING_INTERVAL *
                                     n_warning)
                        n_warning += 1

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj, timeout: Optional[float] = None):
        """ Write to message queue with optional timeout (in seconds)

        The object is pickled once. For local readers, byte 0 of the chunk
        is an overflow marker: 0 means the payload follows in the chunk,
        1 means the payload did not fit and is sent on the local socket.
        Remote readers always receive the payload on the remote socket.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            # one byte of the chunk is reserved for the overflow marker
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self, timeout: Optional[float] = None):
        """ Read from message queue with optional timeout (in seconds)

        `timeout` only applies to waiting on the shared memory buffer; the
        socket receives block without a timeout.

        Raises:
            TimeoutError: if a local reader finds no readable chunk in time.
            RuntimeError: if called on the writer.
        """
        # NOTE(security): payloads are unpickled as received, so the peers
        # of this queue must be trusted.
        if self._is_local_reader:
            with self.acquire_read(timeout) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    def broadcast_object(self, obj=None):
        """Collective broadcast: the writer enqueues `obj` and returns it;
        every reader ignores its argument and returns the dequeued object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: Union[ProcessGroup,
                                            StatelessProcessGroup],
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        """Build a queue spanning all ranks of `pg` (collective call).

        The rank equal to `writer_rank` creates the writer and broadcasts
        its handle through `pg`; every other rank creates a reader from the
        received handle. Ranks reported by `in_the_same_node_as` as being on
        the writer's node become local readers. Returns after
        `wait_until_ready()` has completed.
        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as
        # `status` is indexed by group rank; truthy entries mark ranks that
        # share a node with the writer
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        # every rank except the writer is a reader
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list([handle],
                                           src=global_ranks[writer_rank],
                                           group=pg)
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(recv,
                                           src=global_ranks[writer_rank],
                                           group=pg)
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io