# shm_broadcast.py 22.3 KB
# (file-viewer navigation residue: "Newer Older")
1
2
# SPDX-License-Identifier: Apache-2.0

3
import os
4
import pickle
5
import sys
6
7
import time
from contextlib import contextmanager
8
from dataclasses import dataclass, field
9
from multiprocessing import shared_memory
10
from typing import List, Optional, Tuple, Union
11
12
13
14
15
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
16
from zmq import IPV6  # type: ignore
17
from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context  # type: ignore
18
19

import vllm.envs as envs
20
from vllm.distributed.utils import StatelessProcessGroup
21
from vllm.logger import init_logger
22
23
from vllm.utils import (get_ip, get_open_port, get_open_zmq_ipc_path,
                        is_valid_ipv6_address)
24
25
26
27
28

logger = init_logger(__name__)

# Interval (seconds) between debug logs while a ring-buffer writer or reader
# keeps polling for a usable block.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

29
30
31
32
33
34
35
36
37
38
39
40
41
42
# We prefer os.sched_yield because it gives tighter polling loops (measured at
# around 3e-7 seconds). On earlier Python versions, however, os.sched_yield()
# does not release the GIL, so we fall back to time.sleep(0) there.
# os.sched_yield is also only provided on some (Unix) platforms; check for it
# explicitly instead of failing with AttributeError on the first poll.
USE_SCHED_YIELD = (hasattr(os, "sched_yield")
                   and ((sys.version_info[:3] >= (3, 11, 1))
                        or (sys.version_info[:2] == (3, 10)
                            and sys.version_info[2] >= 8)))


def sched_yield() -> None:
    """Briefly give up the CPU inside a busy-wait loop.

    Calls ``os.sched_yield()`` when ``USE_SCHED_YIELD`` is true, otherwise
    ``time.sleep(0)``.
    """
    if USE_SCHED_YIELD:
        os.sched_yield()
    else:
        time.sleep(0)

43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74

class ShmRingBuffer:

    def __init__(self,
                 n_reader: int,
                 max_chunk_bytes: int,
                 max_chunks: int,
                 name: Optional[str] = None):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer, are known in advance.
        In this case, we don't need locks to synchronize access to the buffer;
        the per-chunk flags described below coordinate writer and readers.

        Args:
            n_reader: number of readers; each owns one flag byte per chunk.
            max_chunk_bytes: size in bytes of each data chunk.
            max_chunks: number of chunks in the ring.
            name: ``None`` to create a new shared memory segment, or the name
                of an existing segment to attach to.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for caller to read.
        Only after the caller finishes reading the block, the reader can mark the block as read.
        Readers only mark the block as read (from 0 to 1), the writer marks the block as ready to read (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for caller to write. After the caller finishes writing the block, the writer
        can reset the reader flags to 0, and mark the block as written (from 0 to 1).
        NOTE: the order is important here: first reset the reader flags (so that we are still in case 1), then mark the
        block as written. Each single-byte flag write is assumed to be atomic. If we did it in the reverse order, the
        block would go through case 3 and then back to case 2, and readers might observe the intermediate case 3,
        which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """# noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (self.max_chunk_bytes +
                                      self.metadata_size) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer)
            # initialize the metadata section to 0. Stop at the logical end of
            # our layout: the segment itself may be larger on platforms that
            # round shared memory up to whole pages.
            with memoryview(self.shared_memory.buf[
                    self.metadata_offset:self.total_bytes_of_buffer]
                            ) as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch("multiprocessing.resource_tracker.register",
                       lambda *args, **kwargs: None):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # Some platforms allocate shared memory in whole pages, so
                    # the attached segment may be larger than (but never
                    # smaller than) the layout computed above.
                    assert (self.shared_memory.size
                            >= self.total_bytes_of_buffer)
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def handle(self):
        """Return the constructor arguments needed to attach to this buffer.

        The tuple is ``(n_reader, max_chunk_bytes, max_chunks, name)``.
        """
        return (self.n_reader, self.max_chunk_bytes, self.max_chunks,
                self.shared_memory.name)

    def __reduce__(self):
        # Pickling only transfers the handle; unpickling re-attaches to the
        # named shared memory segment instead of copying its contents.
        return (
            self.__class__,
            self.handle(),
        )

    def __del__(self):
        # Every process closes its own mapping; only the creator unlinks the
        # segment. `shared_memory` is missing if attaching failed above.
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over the data bytes of chunk ``current_idx``.

        The view is released when the context exits.
        """
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the flag bytes of chunk ``current_idx``.

        Byte 0 is the written flag; bytes 1..n_reader are the reader flags.
        The view is released when the context exits.
        """
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with memoryview(self.shared_memory.buf[start:end]) as buf:
            yield buf


167
168
169
@dataclass
class Handle:
    """Connection info exported by a writer-side ``MessageQueue``.

    Readers pass it to ``MessageQueue.create_from_handle`` to attach.
    """

    # ranks that read through the shared memory ring buffer
    local_reader_ranks: List[int] = field(default_factory=list)

    # ShmRingBuffer.handle() tuple:
    # (n_reader, max_chunk_bytes, max_chunks, shm_name); None without
    # local readers
    buffer_handle: Optional[Tuple[int, int, int, str]] = None
    # ZMQ address local readers subscribe to (oversized messages)
    local_subscribe_addr: Optional[str] = None
    # ZMQ TCP address remote readers subscribe to
    remote_subscribe_addr: Optional[str] = None
    # whether readers must enable IPv6 on their remote socket
    remote_addr_ipv6: bool = False
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204


class MessageQueue:
    """One-writer, many-reader broadcast queue.

    Readers on the writer's node ("local" readers) receive messages through a
    ``ShmRingBuffer``. Messages too large for a chunk go through a ZMQ XPUB/SUB
    socket bound to an IPC path. Readers elsewhere ("remote" readers) receive
    every message through a ZMQ XPUB/SUB socket over TCP.

    NOTE(security): messages are serialized with ``pickle``, and ``dequeue``
    unpickles whatever arrives on the sockets. The remote socket is bound on
    all interfaces (``tcp://*``), so every peer that can reach it must be
    trusted.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[List[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers served through shared memory.
            local_reader_ranks: ranks of the local readers; defaults to
                ``range(n_local_reader)``. Must have ``n_local_reader`` items.
            max_chunk_bytes: size of each ring-buffer chunk. One byte is used
                for the overflow flag, so serialized objects of
                ``max_chunk_bytes`` bytes or more go through the local socket.
            max_chunks: number of chunks in the ring buffer.
            connect_ip: address remote readers connect to; ``get_ip()`` is
                used when it is falsy.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes,
                                        max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            # for the writer, `current_idx` is the next chunk to write
            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                # ZMQ TCP endpoints take IPv6 literals in brackets,
                # e.g. "tcp://[::1]:5555".
                connect_ip = f"[{connect_ip}]"
            socket_addr = f"tcp://*:{remote_subscribe_port}"
            self.remote_socket.bind(socket_addr)
            remote_subscribe_addr = (f"tcp://{connect_ip}:"
                                     f"{remote_subscribe_port}")
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle()
            if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the handle that readers pass to ``create_from_handle``."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create the reader side of a queue from a writer's ``Handle``.

        The reader is local (shared memory + IPC socket) when ``rank`` is in
        ``handle.local_reader_ranks``, otherwise remote (TCP socket only).
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            # for readers, `current_idx` is the next chunk to read
            self.current_idx = 0
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer waits for one subscription message per reader on each
        socket, then sends ``b"READY"``. Readers block until they receive it.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: Optional[float] = None):
        """Yield a writable view of the next free ring-buffer chunk.

        Busy-waits (yielding the CPU between polls) until the chunk at
        ``current_idx`` is unwritten or fully read. After the caller's block
        finishes, the chunk is published to readers and ``current_idx``
        advances.

        Args:
            timeout: maximum seconds to wait for a free chunk; ``None`` waits
                forever.

        Raises:
            TimeoutError: if no chunk becomes free within ``timeout``.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.debug("No available block found in %s second. ",
                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(self, timeout: Optional[float] = None):
        """Yield a read-only-by-convention view of the next unread chunk.

        Busy-waits until the chunk at ``current_idx`` is written and not yet
        read by this reader. After the caller's block finishes, this reader's
        flag is set and ``current_idx`` advances.

        Args:
            timeout: maximum seconds to wait; ``None`` waits forever.

        Raises:
            TimeoutError: if no chunk becomes readable within ``timeout``.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    sched_yield()

                    # if we wait for a long time, log a message
                    if (time.monotonic() - start_time
                            > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning):
                        logger.debug("No available block found in %s second. ",
                                     VLLM_RINGBUFFER_WARNING_INTERVAL)
                        n_warning += 1

                    # if we time out, raise an exception
                    if (timeout is not None
                            and time.monotonic() - start_time > timeout):
                        raise TimeoutError

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx +
                                    1) % self.buffer.max_chunks
                break

    def enqueue(self, obj, timeout: Optional[float] = None):
        """ Write to message queue with optional timeout (in seconds) """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            # byte 0 of the chunk is the overflow flag, so the payload fits
            # only if it is strictly smaller than the chunk
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                # payload travels over the local socket instead
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    buf[1:len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(self, timeout: Optional[float] = None):
        """ Read from message queue with optional timeout (in seconds) """
        # NOTE(security): pickle.loads below runs on data read from shared
        # memory / sockets; safe only when every peer is trusted.
        if self._is_local_reader:
            with self.acquire_read(timeout) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    obj = pickle.loads(buf[1:])
            if overflow:
                recv = self.local_socket.recv()
                obj = pickle.loads(recv)
        elif self._is_remote_reader:
            recv = self.remote_socket.recv()
            obj = pickle.loads(recv)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    def broadcast_object(self, obj=None):
        """Writer enqueues and returns ``obj``; readers return the dequeued
        object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(pg: Union[ProcessGroup,
                                            StatelessProcessGroup],
                                  max_chunk_bytes,
                                  max_chunks,
                                  writer_rank=0) -> "MessageQueue":
        """Build a queue across all ranks of ``pg`` with ``writer_rank`` as
        the writer.

        Ranks on the writer's node become local (shared memory) readers; the
        rest become remote readers. The writer's handle is broadcast over
        ``pg``, and every rank calls ``wait_until_ready`` before returning,
        so this must be called collectively.
        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as
        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list([handle],
                                           src=global_ranks[writer_rank],
                                           group=pg)
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(recv,
                                           src=global_ranks[writer_rank],
                                           group=pg)
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io