# shm_broadcast.py
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3

4
5
6
import pickle
import time
from contextlib import contextmanager
7
from dataclasses import dataclass, field
8
from multiprocessing import shared_memory
9
from threading import Event
10
from typing import Any, Optional, Union
11
12
13
14
from unittest.mock import patch

import torch
import torch.distributed as dist
15
import zmq
16
from torch.distributed import ProcessGroup
17
18
19
20
21
22
23
24
from zmq import (  # type: ignore
    IPV6,  # type: ignore
    SUB,
    SUBSCRIBE,
    XPUB,
    XPUB_VERBOSE,
    Context,
)
25
26

import vllm.envs as envs
27
from vllm.distributed.utils import StatelessProcessGroup, sched_yield
28
from vllm.logger import init_logger
29
30
31
32
33
34
from vllm.utils import (
    get_ip,
    get_open_port,
    get_open_zmq_ipc_path,
    is_valid_ipv6_address,
)
35
36
37
38
39
40

# Interval (in seconds) between the "No available shared memory broadcast
# block found" log messages emitted by the ring-buffer wait loops. Read once
# from the environment when this module is imported.
VLLM_RINGBUFFER_WARNING_INTERVAL = envs.VLLM_RINGBUFFER_WARNING_INTERVAL

logger = init_logger(__name__)


41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
class SpinTimer:
    """Polling policy that always yields the CPU between polls.

    It keeps no state, so activity notifications are ignored.
    """

    def record_activity(self):
        """Accept an activity notification; this policy never backs off."""
        return None

    def spin(self):
        """Give up the processor once before the caller polls again."""
        sched_yield()

class SpinSleepTimer(SpinTimer):
    """Polling policy that backs off to sleeping after a period of inactivity.

    Busy-polling pins one CPU thread at 100% for every waiting process, e.g.
    one per GPU. That wastes power and CPU thermal headroom when vLLM sits
    idle for long stretches.

    This policy keeps the low-latency ``sched_yield`` loop for
    ``busy_loop_s`` seconds after the last recorded activity. After that it
    sleeps ``wait_sleep_s`` seconds between polls until activity is recorded
    again.
    """

    def __init__(self, busy_loop_s: float = 3.0, wait_sleep_s: float = 0.1):
        self.busy_loop_s = busy_loop_s
        self.wait_sleep_s = wait_sleep_s
        self.last_activity = time.monotonic()

    def record_activity(self):
        """Restart the idle clock so the following waits busy-poll again."""
        self.last_activity = time.monotonic()

    def spin(self):
        """Wait once: yield while recently active, otherwise sleep."""
        now = time.monotonic()
        busy_until = self.last_activity + self.busy_loop_s
        if now < busy_until:
            sched_yield()
        else:
            time.sleep(self.wait_sleep_s)


77
class ShmRingBuffer:
    """Shared-memory ring buffer of fixed-size chunks plus per-chunk flags.

    This class only owns the memory and exposes views onto it. The
    read/write protocol built on the flags is described in the ``__init__``
    docstring.
    """

    def __init__(
        self,
        n_reader: int,
        max_chunk_bytes: int,
        max_chunks: int,
        name: Optional[str] = None,
    ):
        """
        A shared memory ring buffer implementation for broadcast communication.
        Essentially, it is a queue where only one will `enqueue` and multiple
        will `dequeue`. The max size of each item, together with the max number
        of items that can be stored in the buffer, are known in advance.
        In this case, we don't need to synchronize access to the buffer.

        Buffer memory layout:
                  data                                 metadata
                    |                                      |
                    | (current_idx)                        | (current_idx)
                    v                                      v
        +-------------------------------+----------------------------------------+
        | chunk0 | chunk1 | ... | chunk | metadata0 | metadata1 | ... | metadata |
        +-------------------------------+----------------------------------------+
        | max_chunks x max_chunk_bytes  | max_chunks x (1 + n_reader) bytes      |

        metadata memory layout: each byte is a flag, the first byte is the written
        flag, and the rest are reader flags. The flags are set to 0 by default.
        +--------------+--------------+--------------+-----+--------------+
        | written_flag | reader0_flag | reader1_flag | ... | readerN_flag |
        +--------------+--------------+--------------+-----+--------------+

        The state of metadata is as follows:

        (case 1) 0???...???: the block is not written yet, cannot read, can write
        (case 2) 1000...000: the block is just written, can read, cannot write
        (case 3) 1???...???: the block is written and read by some readers, can read if not read, cannot write
        (case 4) 1111...111: the block is written and read by all readers, cannot read, can write

        State transition for readers:

        When a reader finds a block that it can read (case 2 or 3), it can yield the block for the caller to read.
        Only after the caller finishes reading the block can the reader mark the block as read.
        Readers only set their own reader flag (from 0 to 1); only the writer resets reader flags (from 1 to 0).

        State transition for writer:

        When the writer writes to a block (case 1 or 4), it first resets the written flag to 0, converting either case
        to case 1. Then it can yield the block for the caller to write. After the caller finishes writing the block, the writer
        resets the reader flags to 0, and marks the block as written (from 0 to 1).
        NOTE: the order is important here: first reset the reader flags (so that we are still in case 1), then mark the block as written. The state transition is atomic. If we do it in the reverse order, it will go through case 3 and then back to case 2, and readers might read the intermediate case 3, which is not correct.

        During creation, `name` is None and the buffer is created. We can pass the
        created object to other processes by pickling it. The other processes will
        get the name of the shared memory and open it, so that they can access the
        same shared memory buffer.
        """  # noqa
        self.n_reader = n_reader
        self.metadata_size = 1 + n_reader
        self.max_chunk_bytes = max_chunk_bytes
        self.max_chunks = max_chunks
        self.total_bytes_of_buffer = (
            self.max_chunk_bytes + self.metadata_size
        ) * self.max_chunks
        self.data_offset = 0
        self.metadata_offset = self.max_chunk_bytes * self.max_chunks

        if name is None:
            # we are creating a buffer
            self.is_creator = True
            self.shared_memory = shared_memory.SharedMemory(
                create=True, size=self.total_bytes_of_buffer
            )
            # initialize the metadata section to 0
            with self.shared_memory.buf[self.metadata_offset :] as metadata_buffer:
                torch.frombuffer(metadata_buffer, dtype=torch.uint8).fill_(0)
        else:
            # we are opening an existing buffer
            self.is_creator = False
            # fix to https://stackoverflow.com/q/62748654/9191338
            # Python incorrectly tracks shared memory even if it is not
            # created by the process. The following patch is a workaround.
            with patch(
                "multiprocessing.resource_tracker.register",
                lambda *args, **kwargs: None,
            ):
                try:
                    self.shared_memory = shared_memory.SharedMemory(name=name)
                    # See https://docs.python.org/3/library/multiprocessing.shared_memory.html # noqa
                    # Some platforms allocate memory based on page size,
                    # so the shared memory block size may be larger or equal
                    # to the requested size. The size parameter is ignored
                    # when attaching to an existing block.
                    assert self.shared_memory.size >= self.total_bytes_of_buffer
                except FileNotFoundError:
                    # we might deserialize the object in a different node
                    # in this case, this object is not used,
                    # and we should suppress the error
                    pass

    def handle(self) -> tuple[int, int, int, str]:
        """Return the constructor arguments that re-open this buffer.

        The tuple is ``(n_reader, max_chunk_bytes, max_chunks, name)``.
        Another process can pass it to ``ShmRingBuffer(*handle)`` to attach
        to the same shared memory.
        """
        return (
            self.n_reader,
            self.max_chunk_bytes,
            self.max_chunks,
            self.shared_memory.name,
        )

    def __reduce__(self):
        # Pickle by handle, so unpickling attaches to the existing segment
        # (name is not None) instead of allocating a new one.
        return (
            self.__class__,
            self.handle(),
        )

    def __del__(self):
        # `shared_memory` is missing when attaching failed with
        # FileNotFoundError (object deserialized on another node).
        if hasattr(self, "shared_memory"):
            self.shared_memory.close()
            # Only the creating process removes the segment.
            if self.is_creator:
                self.shared_memory.unlink()

    @contextmanager
    def get_data(self, current_idx: int):
        """Yield a memoryview over the data bytes of chunk ``current_idx``.

        The view is released when the ``with`` block exits.
        """
        start = self.data_offset + current_idx * self.max_chunk_bytes
        end = start + self.max_chunk_bytes
        with self.shared_memory.buf[start:end] as buf:
            yield buf

    @contextmanager
    def get_metadata(self, current_idx: int):
        """Yield a memoryview over the flag bytes of chunk ``current_idx``.

        Byte 0 is the written flag and bytes ``1..n_reader`` are the reader
        flags. The view is released when the ``with`` block exits.
        """
        start = self.metadata_offset + current_idx * self.metadata_size
        end = start + self.metadata_size
        with self.shared_memory.buf[start:end] as buf:
            yield buf


212
213
@dataclass
class Handle:
    """Connection information a reader needs to join a ``MessageQueue``.

    The writer builds it and readers pass it to
    ``MessageQueue.create_from_handle``.
    """

    # ranks of the readers that receive data through shared memory
    local_reader_ranks: list[int] = field(default_factory=list)

    # (n_reader, max_chunk_bytes, max_chunks, shm name) as returned by
    # ShmRingBuffer.handle(); None when there are no local readers
    buffer_handle: Optional[tuple[int, int, int, str]] = field(default=None)

    # address the writer's local XPUB socket is bound to, if any
    local_subscribe_addr: Optional[str] = field(default=None)
    # tcp:// address the writer's remote XPUB socket is bound to, if any
    remote_subscribe_addr: Optional[str] = field(default=None)
    # True when the remote address is IPv6 (readers must enable IPV6)
    remote_addr_ipv6: bool = field(default=False)
220
221
222
223
224
225
226


class MessageQueue:
    """Broadcast queue from a single writer to many readers.

    Routing by reader location:

    - Readers on the writer's node ("local readers") get small messages
      through a ``ShmRingBuffer``. Messages too large for a chunk go through
      a ZMQ XPUB socket bound to an IPC path.
    - Readers elsewhere ("remote readers") get every message through a ZMQ
      XPUB socket bound to a TCP address.

    Setup:

    - The writer constructs the queue directly and shares
      ``export_handle()``.
    - Readers build their side with ``create_from_handle``.
    - Every participant must then call ``wait_until_ready``.
    """

    def __init__(
        self,
        n_reader,  # number of all readers
        n_local_reader,  # number of local readers through shared memory
        local_reader_ranks: Optional[list[int]] = None,
        max_chunk_bytes: int = 1024 * 1024 * 10,
        max_chunks: int = 10,
        connect_ip: Optional[str] = None,
    ):
        """Create the writer side of the queue.

        Args:
            n_reader: total number of readers (local + remote).
            n_local_reader: number of readers that read through shared memory.
            local_reader_ranks: ranks of the local readers. Defaults to
                ``range(n_local_reader)``. When given, it must have exactly
                ``n_local_reader`` entries.
            max_chunk_bytes: size of each shared-memory chunk. The first byte
                of a chunk is an overflow flag, so pickled messages of
                ``max_chunk_bytes`` bytes or more go through the local
                socket instead.
            max_chunks: number of chunks in the ring buffer.
            connect_ip: IP to bind the remote socket to. Defaults to
                ``get_ip()``.
        """
        if local_reader_ranks is None:
            local_reader_ranks = list(range(n_local_reader))
        else:
            assert len(local_reader_ranks) == n_local_reader
        self.n_local_reader = n_local_reader
        n_remote_reader = n_reader - n_local_reader
        self.n_remote_reader = n_remote_reader

        context = Context()

        if n_local_reader > 0:
            # for local readers, we will:
            # 1. create a shared memory ring buffer to communicate small data
            # 2. create a publish-subscribe socket to communicate large data
            self.buffer = ShmRingBuffer(n_local_reader, max_chunk_bytes, max_chunks)

            # XPUB is very similar to PUB,
            # except that it can receive subscription messages
            # to confirm the number of subscribers
            self.local_socket = context.socket(XPUB)
            # set the verbose option so that we can receive every subscription
            # message. otherwise, we will only receive the first subscription
            # see http://api.zeromq.org/3-3:zmq-setsockopt for more details
            self.local_socket.setsockopt(XPUB_VERBOSE, True)
            local_subscribe_addr = get_open_zmq_ipc_path()
            logger.debug("Binding to %s", local_subscribe_addr)
            self.local_socket.bind(local_subscribe_addr)

            # for writers, `current_idx` is the next chunk to write
            self.current_idx = 0
        else:
            self.buffer = None  # type: ignore
            local_subscribe_addr = None
            self.local_socket = None
            self.current_idx = -1

        remote_addr_ipv6 = False
        if n_remote_reader > 0:
            # for remote readers, we will:
            # create a publish-subscribe socket to communicate large data
            if not connect_ip:
                connect_ip = get_ip()
            self.remote_socket = context.socket(XPUB)
            self.remote_socket.setsockopt(XPUB_VERBOSE, True)
            remote_subscribe_port = get_open_port()
            if is_valid_ipv6_address(connect_ip):
                self.remote_socket.setsockopt(IPV6, 1)
                remote_addr_ipv6 = True
                # bracket IPv6 literals so the ":port" suffix is unambiguous
                connect_ip = f"[{connect_ip}]"
            # the bind address is also the address readers connect to
            remote_subscribe_addr = f"tcp://{connect_ip}:{remote_subscribe_port}"
            self.remote_socket.bind(remote_subscribe_addr)
        else:
            remote_subscribe_addr = None
            self.remote_socket = None

        self._is_writer = True
        self._is_local_reader = False
        self.local_reader_rank = -1
        # rank does not matter for remote readers
        self._is_remote_reader = False
        self._read_spin_timer = SpinTimer()

        self.handle = Handle(
            local_reader_ranks=local_reader_ranks,
            buffer_handle=self.buffer.handle() if self.buffer is not None else None,
            local_subscribe_addr=local_subscribe_addr,
            remote_subscribe_addr=remote_subscribe_addr,
            remote_addr_ipv6=remote_addr_ipv6,
        )

        logger.info("vLLM message queue communication handle: %s", self.handle)

    def export_handle(self) -> Handle:
        """Return the handle readers pass to ``create_from_handle``."""
        return self.handle

    @staticmethod
    def create_from_handle(handle: Handle, rank) -> "MessageQueue":
        """Create the reader side of a queue from the writer's handle.

        If ``rank`` is in ``handle.local_reader_ranks``, this becomes a local
        reader. It attaches to the shared-memory ring buffer and connects a
        SUB socket to ``handle.local_subscribe_addr``. Otherwise it becomes a
        remote reader and connects a SUB socket to
        ``handle.remote_subscribe_addr``.
        """
        self = MessageQueue.__new__(MessageQueue)
        self.handle = handle
        self._is_writer = False

        context = Context()

        if rank in handle.local_reader_ranks:
            assert handle.buffer_handle is not None
            self.buffer = ShmRingBuffer(*handle.buffer_handle)
            # for readers, `current_idx` is the next chunk to read
            self.current_idx = 0
            # position of this reader's flag within each chunk's metadata
            self.local_reader_rank = handle.local_reader_ranks.index(rank)
            self._is_local_reader = True
            self._is_remote_reader = False

            self.local_socket = context.socket(SUB)
            self.local_socket.setsockopt_string(SUBSCRIBE, "")
            socket_addr = handle.local_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.local_socket.connect(socket_addr)

            self.remote_socket = None

            self._read_spin_timer = (
                SpinSleepTimer() if envs.VLLM_SLEEP_WHEN_IDLE else SpinTimer()
            )
        else:
            self.buffer = None  # type: ignore
            self.current_idx = -1
            self.local_reader_rank = -1
            self._is_local_reader = False
            self._is_remote_reader = True

            self.local_socket = None

            self.remote_socket = context.socket(SUB)
            self.remote_socket.setsockopt_string(SUBSCRIBE, "")
            if handle.remote_addr_ipv6:
                self.remote_socket.setsockopt(IPV6, 1)
            socket_addr = handle.remote_subscribe_addr
            logger.debug("Connecting to %s", socket_addr)
            self.remote_socket.connect(socket_addr)

        return self

    def wait_until_ready(self):
        """This is a collective operation. All processes (including the
        readers and the writer) should call this function.

        The writer blocks until it has received one subscription message per
        reader on each socket, then sends ``b"READY"``. Each reader blocks
        until it receives that message.
        """
        if self._is_writer:
            # wait for all readers to connect

            # local readers
            for _ in range(self.n_local_reader):
                # wait for subscription messages from all local readers
                self.local_socket.recv()
            if self.n_local_reader > 0:
                # send a message to all local readers
                # to make sure the publish channel is working
                self.local_socket.send(b"READY")

            # remote readers
            for _ in range(self.n_remote_reader):
                # wait for subscription messages from all remote readers
                self.remote_socket.recv()
            if self.n_remote_reader > 0:
                # send a message to all remote readers
                # to make sure the publish channel is working
                self.remote_socket.send(b"READY")
        elif self._is_local_reader:
            # wait for the writer to send a message
            recv = self.local_socket.recv()
            assert recv == b"READY"
        elif self._is_remote_reader:
            # wait for the writer to send a message
            recv = self.remote_socket.recv()
            assert recv == b"READY"

    @contextmanager
    def acquire_write(self, timeout: Optional[float] = None):
        """Yield a writable view of the next ring-buffer chunk.

        Spins until the chunk at ``current_idx`` is either unwritten or read
        by all local readers. Raises ``TimeoutError`` once more than
        ``timeout`` seconds have elapsed.

        After the ``with`` body finishes, the chunk is published to readers
        and ``current_idx`` advances.
        """
        assert self._is_writer, "Only writers can acquire write"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_count = sum(metadata_buffer[1:])
                written_flag = metadata_buffer[0]
                if written_flag and read_count != self.buffer.n_reader:
                    # this block is written and not read by all readers
                    # for writers, `self.current_idx` is the next block to write
                    # if this block is not ready to write,
                    # we need to wait until it is read by all readers

                    # Release the processor to other threads
                    sched_yield()

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning:
                        logger.info(
                            "No available shared memory broadcast block found"
                            " in %s seconds. This typically happens when some"
                            " processes are hanging or doing some"
                            " time-consuming work (e.g. compilation)",
                            VLLM_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    continue
                # found a block that is either
                # (1) not written
                # (2) read by all readers

                # mark the block as not written
                metadata_buffer[0] = 0
                # let caller write to the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has written to the buffer
                # NOTE: order is important here
                # first set the read flags to 0
                # then set the written flag to 1
                # otherwise, the readers may think they already read the block
                for i in range(1, self.buffer.n_reader + 1):
                    # set read flag to 0, meaning it is not read yet
                    metadata_buffer[i] = 0
                # mark the block as written
                metadata_buffer[0] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks
                break

    @contextmanager
    def acquire_read(
        self,
        timeout: Optional[float] = None,
        cancel: Optional[Event] = None,
        indefinite: bool = False,
    ):
        """Yield a view of the next ring-buffer chunk this reader has not read.

        Waits until the chunk at ``current_idx`` is written and unread by
        this reader. Each wait step is delegated to ``self._read_spin_timer``.
        The wait stops early as follows:

        - ``RuntimeError("cancelled")`` if ``cancel`` is set.
        - ``TimeoutError`` once more than ``timeout`` seconds have elapsed.

        ``indefinite=True`` suppresses the periodic "still waiting" log.
        After the ``with`` body finishes, this reader's flag is set and
        ``current_idx`` advances.
        """
        assert self._is_local_reader, "Only readers can acquire read"
        start_time = time.monotonic()
        n_warning = 1
        while True:
            with self.buffer.get_metadata(self.current_idx) as metadata_buffer:
                read_flag = metadata_buffer[self.local_reader_rank + 1]
                written_flag = metadata_buffer[0]
                if not written_flag or read_flag:
                    # this block is either
                    # (1) not written
                    # (2) already read by this reader

                    # for readers, `self.current_idx` is the next block to read
                    # if this block is not ready,
                    # we need to wait until it is written

                    # Release the processor to other threads
                    self._read_spin_timer.spin()

                    if cancel is not None and cancel.is_set():
                        raise RuntimeError("cancelled")

                    # if we time out, raise an exception
                    elapsed = time.monotonic() - start_time
                    if timeout is not None and elapsed > timeout:
                        raise TimeoutError

                    # if we wait for a long time, log a message
                    if not indefinite and (
                        elapsed > VLLM_RINGBUFFER_WARNING_INTERVAL * n_warning
                    ):
                        logger.info(
                            "No available shared memory broadcast block found"
                            " in %s seconds. This typically happens when some"
                            " processes are hanging or doing some"
                            " time-consuming work (e.g. compilation).",
                            VLLM_RINGBUFFER_WARNING_INTERVAL,
                        )
                        n_warning += 1

                    continue
                # found a block that is not read by this reader
                # let caller read from the buffer
                with self.buffer.get_data(self.current_idx) as buf:
                    yield buf

                # caller has read from the buffer
                # set the read flag
                metadata_buffer[self.local_reader_rank + 1] = 1
                self.current_idx = (self.current_idx + 1) % self.buffer.max_chunks

                self._read_spin_timer.record_activity()
                break

    def enqueue(self, obj, timeout: Optional[float] = None):
        """Write to message queue with optional timeout (in seconds).

        The object is pickled once.

        - Local readers: if it fits in a chunk (after the 1-byte overflow
          flag), it is copied into shared memory. Otherwise the chunk only
          carries the overflow flag and the bytes are sent on the local
          socket.
        - Remote readers: the bytes are always sent on the remote socket.
        """
        assert self._is_writer, "Only writers can enqueue"
        serialized_obj = pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL)
        if self.n_local_reader > 0:
            if len(serialized_obj) >= self.buffer.max_chunk_bytes:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 1  # overflow
                self.local_socket.send(serialized_obj)
            else:
                with self.acquire_write(timeout) as buf:
                    buf[0] = 0  # not overflow
                    buf[1 : len(serialized_obj) + 1] = serialized_obj
        if self.n_remote_reader > 0:
            self.remote_socket.send(serialized_obj)

    def dequeue(
        self,
        timeout: Optional[float] = None,
        cancel: Optional[Event] = None,
        indefinite: bool = False,
    ):
        """Read from message queue with optional timeout (in seconds).

        NOTE: for an overflowed local message, the socket receive gets its
        own full ``timeout`` after the shared-memory wait. The worst-case
        total wait is therefore about twice ``timeout``.
        """
        if self._is_local_reader:
            with self.acquire_read(timeout, cancel, indefinite) as buf:
                overflow = buf[0] == 1
                if not overflow:
                    # no need to know the size of serialized object
                    # pickle format contains the size information internally
                    # see https://docs.python.org/3/library/pickle.html
                    # NOTE(security): pickle.loads can execute arbitrary code;
                    # every process with access to this shared memory must be
                    # trusted.
                    obj = pickle.loads(buf[1:])
            if overflow:
                obj = MessageQueue.recv(self.local_socket, timeout)
        elif self._is_remote_reader:
            obj = MessageQueue.recv(self.remote_socket, timeout)
        else:
            raise RuntimeError("Only readers can dequeue")
        return obj

    @staticmethod
    def recv(socket: zmq.Socket, timeout: Optional[float]) -> Any:
        """Receive and unpickle one message from ``socket``.

        ``timeout`` is in seconds (``None`` waits forever). Raises
        ``TimeoutError`` if nothing arrives in time.
        """
        timeout_ms = None if timeout is None else int(timeout * 1000)
        if not socket.poll(timeout=timeout_ms):
            raise TimeoutError
        recv = socket.recv(copy=False)
        # NOTE(security): pickle.loads can execute arbitrary code; the peer
        # publishing on this socket must be trusted.
        return pickle.loads(recv.buffer)

    def broadcast_object(self, obj=None):
        """Writer: enqueue and return ``obj``. Reader: return the dequeued object."""
        if self._is_writer:
            self.enqueue(obj)
            return obj
        else:
            return self.dequeue()

    @staticmethod
    def create_from_process_group(
        pg: Union[ProcessGroup, StatelessProcessGroup],
        max_chunk_bytes,
        max_chunks,
        writer_rank=0,
    ) -> "MessageQueue":
        """Collectively build a queue over ``pg`` with ``writer_rank`` as writer.

        Ranks on the writer's node become local readers; all other ranks
        become remote readers. The writer broadcasts its handle over ``pg``,
        readers build their side from it, and every rank waits until ready.
        """
        if isinstance(pg, ProcessGroup):
            group_rank = dist.get_rank(pg)
            group_world_size = dist.get_world_size(pg)
            global_ranks = dist.get_process_group_ranks(pg)
        else:
            group_rank = pg.rank
            group_world_size = pg.world_size
            global_ranks = list(range(pg.world_size))

        from vllm.distributed.parallel_state import in_the_same_node_as

        status = in_the_same_node_as(pg, source_rank=writer_rank)
        same_node_ranks = [i for i, s in enumerate(status) if s]
        n_reader = group_world_size - 1
        n_local_reader = len(same_node_ranks) - 1
        local_reader_ranks = [i for i in same_node_ranks if i != writer_rank]
        buffer_io: MessageQueue
        if group_rank == writer_rank:
            buffer_io = MessageQueue(
                n_reader=n_reader,
                n_local_reader=n_local_reader,
                local_reader_ranks=local_reader_ranks,
                max_chunk_bytes=max_chunk_bytes,
                max_chunks=max_chunks,
            )
            handle = buffer_io.export_handle()
            if isinstance(pg, ProcessGroup):
                dist.broadcast_object_list(
                    [handle], src=global_ranks[writer_rank], group=pg
                )
            else:
                pg.broadcast_obj(handle, writer_rank)
        else:
            if isinstance(pg, ProcessGroup):
                recv = [None]
                dist.broadcast_object_list(
                    recv, src=global_ranks[writer_rank], group=pg
                )
                handle = recv[0]  # type: ignore
            else:
                handle = pg.broadcast_obj(None, writer_rank)
            buffer_io = MessageQueue.create_from_handle(handle, group_rank)
        buffer_io.wait_until_ready()
        return buffer_io