kv_events.py 16.3 KB
Newer Older
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
4
5
6
7

import queue
import threading
import time
from abc import ABC, abstractmethod
8
from collections import Counter, deque
9
from collections.abc import Callable
10
from dataclasses import asdict
11
12
from itertools import count
from queue import Queue
13
from typing import Any
14
15
16
17

import msgspec
import zmq

18
from vllm.config.kv_events import KVEventsConfig
19
from vllm.logger import init_logger
20
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
21
22
23
24
25

# Module-level logger created through vLLM's `init_logger` helper.
logger = init_logger(__name__)


class EventBatch(
    msgspec.Struct,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """Batch of events together with a float `ts` (presumably a timestamp).

    The struct is encoded as an array (``array_like=True``), so the field
    order below is part of the wire format and must not be rearranged.
    Fields still holding their default may be omitted on encode
    (``omit_defaults=True``).
    """

    ts: float
    events: list[Any]
    # Filled in by the publisher when the producer left it unset.
    data_parallel_rank: int | None = None
34
35
36


class KVCacheEvent(
    msgspec.Struct,
    tag=True,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """Common base of every event that describes a KV cache change.

    ``tag=True`` makes msgspec tag each concrete subclass so a decoder can
    tell the event types apart when they appear together in a union.
    """


46
47
48
# Storage-medium label for GPU-resident blocks; presumably the value placed in
# the `medium` field of BlockStored / BlockRemoved (callers are not shown here).
MEDIUM_GPU = "GPU"


49
class BlockStored(KVCacheEvent):
    """Event reporting that one or more KV cache blocks were stored.

    Events are encoded positionally (``array_like`` on the base struct), so
    the field order below is part of the wire format.
    """

    block_hashes: list[ExternalBlockHash]
    parent_block_hash: ExternalBlockHash | None
    token_ids: list[int]
    block_size: int

    lora_id: int | None
    """Deprecated: use `lora_name` for KV block key hash.
    Retained for backward compatibility.
    """

    medium: str | None
    lora_name: str | None

    extra_keys: list[tuple[Any, ...] | None] | None = None
    """Extra keys used in block hash computation, one entry per block in
    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
    prompt embedding hashes, etc. for that specific block. Exposed for external
    KV cache consumers to reconstruct block hashes.
    """

    def __hash__(self) -> int:
        """Hash over all fields so events can be counted in a ``Counter``.

        KVEventAggregator keys a ``Counter`` by events, so equal events must
        hash equally. List fields are frozen into tuples because lists are
        unhashable. ``lora_name`` is included because it, not the deprecated
        ``lora_id``, identifies the LoRA adapter in the block key; leaving it
        out made events that differ only by LoRA name collide.
        """
        return hash(
            (
                tuple(self.block_hashes),
                self.parent_block_hash,
                tuple(self.token_ids),
                self.block_size,
                self.lora_id,
                self.medium,
                self.lora_name,
                tuple(self.extra_keys) if self.extra_keys else None,
            )
        )

83
84

class BlockRemoved(KVCacheEvent):
    """Event reporting that the listed KV cache blocks were removed."""

    block_hashes: list[ExternalBlockHash]
    medium: str | None

    def __hash__(self) -> int:
        # Lists are unhashable, so freeze the block hashes into a tuple.
        key = (tuple(self.block_hashes), self.medium)
        return hash(key)

91
92
93
94
95
96

class AllBlocksCleared(KVCacheEvent):
    """Event signalling that every KV cache block was cleared; it has no fields."""

    pass


class KVEventBatch(EventBatch):
    """EventBatch whose ``events`` are narrowed to the concrete KV cache
    event types, letting a decoder resolve each entry by its msgspec tag."""

    events: list[BlockStored | BlockRemoved | AllBlocksCleared]
98
99


100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
class KVEventAggregator:
    """
    Combine the KV cache events reported by several workers.

    Each incoming event is counted; an event is considered common to all
    workers when its occurrence count equals the configured worker count.
    """

    __slots__ = ("_event_counter", "_num_workers")

    def __init__(self, num_workers: int) -> None:
        if num_workers <= 0:
            raise ValueError("num_workers must be greater than zero.")
        self._num_workers: int = num_workers
        self._event_counter: Counter[KVCacheEvent] = Counter()

    def add_events(self, events: list[KVCacheEvent]) -> None:
        """
        Count every event of one worker's batch.

        :param events: List of KVCacheEvent objects.
        :raises TypeError: If ``events`` is not a list.
        """
        if isinstance(events, list):
            self._event_counter.update(events)
            return
        raise TypeError("events must be a list of KVCacheEvent.")

    def get_common_events(self) -> list[KVCacheEvent]:
        """
        Return the events whose count equals the number of workers.

        :return: Events seen exactly once per worker, in first-seen order.
        """
        target = self._num_workers
        return [
            evt for evt, seen in self._event_counter.items() if seen == target
        ]

    def get_all_events(self) -> list[KVCacheEvent]:
        """
        Return every counted event, repeated as many times as it was seen.

        :return: Flat list of all events from all workers.
        """
        return [*self._event_counter.elements()]

    def clear_events(self) -> None:
        """
        Forget all counted events (the worker count is kept).
        """
        self._event_counter.clear()

    def increment_workers(self, count: int = 1) -> None:
        """
        Raise the number of contributing workers.

        :param count: How many workers to add; must be positive.
        :raises ValueError: If ``count`` is not positive.
        """
        if count <= 0:
            raise ValueError("count must be positive.")
        self._num_workers += count

    def reset_workers(self) -> None:
        """
        Set the worker count back to 1.
        """
        self._num_workers = 1

    def get_number_of_workers(self) -> int:
        """
        Return the current worker count.

        :return: int number of workers.
        """
        return self._num_workers

    def __repr__(self) -> str:
        n_workers = self._num_workers
        n_events = len(self._event_counter)
        return f"<KVEventAggregator workers={n_workers}, events={n_events}>"


class KVConnectorKVEvents(ABC):
    """
    Interface for a container that holds the KV cache events produced by a
    KV connector, possibly reported by several workers.
    """

    @abstractmethod
    def add_events(self, events: list[KVCacheEvent]) -> None:
        """Add ``events`` to this container."""
        raise NotImplementedError

    @abstractmethod
    def aggregate(self) -> "KVConnectorKVEvents":
        """Return an aggregated container (semantics defined by subclasses)."""
        raise NotImplementedError

    @abstractmethod
    def increment_workers(self, count: int = 1) -> None:
        """Raise the number of contributing workers by ``count``."""
        raise NotImplementedError

    @abstractmethod
    def get_all_events(self) -> list[KVCacheEvent]:
        """Return every event currently held."""
        raise NotImplementedError

    @abstractmethod
    def get_number_of_workers(self) -> int:
        """Return the number of contributing workers."""
        raise NotImplementedError

    @abstractmethod
    def clear_events(self) -> None:
        """Drop every event currently held."""
        raise NotImplementedError

    def merge(self, other: "KVConnectorKVEvents") -> "KVConnectorKVEvents":
        """Copy all events held by ``other`` into ``self`` and return ``self``."""
        incoming = other.get_all_events()
        self.add_events(incoming)
        return self

216

217
class EventPublisher(ABC):
    """Base class for publishers of EventBatch objects with data-parallel
    (DP) awareness.

    Every DP rank owns its own publisher instance. That keeps events from
    being duplicated and lets each batch be attributed to the rank that
    produced it:

    - one publisher is created per DP rank;
    - publishers tag outgoing batches with their ``data_parallel_rank``;
    - consumers can therefore distinguish events from different ranks.

    Attaching the DP metadata is the publisher's job, so the scheduler can
    stay unaware of the DP topology.
    """

    def __init__(self, data_parallel_rank: int = 0) -> None:
        # Rank that implementations stamp onto outgoing batches.
        self._data_parallel_rank = data_parallel_rank

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Send ``events``, preserving publication order.

        Implementations should provide at-least-once delivery and a
        monotonically increasing order (for example, sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Stop publishing and release any resources held."""


class NullEventPublisher(EventPublisher):
    """Publisher that silently discards everything.

    Used when KV cache events are disabled.
    """

    def publish(self, events) -> None:
        """Drop ``events`` without sending them anywhere."""

    def shutdown(self) -> None:
        """Nothing to release."""


class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    data_parallel_rank:
        DP rank of this publisher. It is stamped onto batches that carry no
        rank yet. It also offsets the endpoint ports (or suffixes
        inproc/ipc addresses) so that every rank gets its own sockets.
    endpoint:
        PUB address. Use `tcp://*:5557` to bind or `tcp://host:5557` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of batches waiting in the in-memory queue; `publish`
        blocks while the queue is full.
    topic:
        Topic to publish events to.
    """

    SHUTDOWN_TIMEOUT: float = 1.0
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        data_parallel_rank: int,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: str | None = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        super().__init__(data_parallel_rank)

        # Storage
        self._event_queue = Queue[EventBatch | None](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: zmq.Socket | None = None
        self._replay: zmq.Socket | None = None
        self._dp_rank = data_parallel_rank

        self._endpoint = self.offset_endpoint_port(endpoint, self._dp_rank)
        self._replay_endpoint = self.offset_endpoint_port(
            replay_endpoint, self._dp_rank
        )
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

    def publish(self, events: EventBatch) -> None:
        """Queue ``events`` for the background thread to send.

        If the batch has no DP rank yet, this publisher's rank is written
        into it in place. Blocks while the queue is full.

        Raises:
            RuntimeError: If the publisher has been shut down.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        if events.data_parallel_rank is None:
            events.data_parallel_rank = self._data_parallel_rank
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources.

        Waits up to ``SHUTDOWN_TIMEOUT`` seconds for the queue to drain and
        up to ``SHUTDOWN_TIMEOUT`` more for the thread to exit, then closes
        this publisher's sockets. The shared ZMQ context is not terminated.
        """
        self._running = False
        try:
            # Sentinel: the thread exits after sending everything ahead of it.
            self._event_queue.put_nowait(None)
        except queue.Full:
            # With `_running` False the publisher thread still stops once
            # the queue is empty, so a missing sentinel only delays exit.
            logger.warning(
                "Event queue is full at shutdown; publisher thread will stop "
                "after draining it"
            )

        # Monotonic clock: the timeout must not be skewed by wall-clock jumps.
        deadline = time.monotonic() + self.SHUTDOWN_TIMEOUT
        pending_items = not self._event_queue.empty()
        while pending_items and time.monotonic() < deadline:
            time.sleep(0.1)
            pending_items = not self._event_queue.empty()

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Close our own sockets only; do not terminate the context, other
        # sockets may use it. The nested finally ensures the replay socket
        # is closed even if closing the PUB socket raises.
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
        finally:
            if self._replay is not None:
                self._replay.close(linger=0)

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if self._endpoint is not None and (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith(("ipc://", "inproc://"))
            ):
                self._pub.bind(self._endpoint)
            elif self._endpoint is not None:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request → many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        # After shutdown clears `_running`, keep going until the queue is empty.
        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                if event is None:
                    break  # Sentinel received, exit thread

                seq = next(self._seq_gen)
                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
                self._buffer.append((seq, payload))
            except Exception as e:
                # Publishing failed; back off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)
            finally:
                # Balance every successful get(), including the sentinel and
                # failed sends, so the queue's unfinished-task count stays right.
                self._event_queue.task_done()

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))

    @staticmethod
    def offset_endpoint_port(
        endpoint: str | None, data_parallel_rank: int
    ) -> str | None:
        """Helper function to offset the port in an endpoint by the data
        parallel rank.

        Args:
            endpoint: The endpoint string
                (e.g., "tcp://*:5557", "inproc://cache" or "ipc:///tmp/kv")
            data_parallel_rank: The data parallel rank to offset by

        Returns:
            The endpoint with the port offset by data_parallel_rank, or with
            a "_dp<rank>" suffix appended for inproc/ipc endpoints.

        Raises:
            ValueError: If the transport is not inproc, ipc or tcp, or a tcp
                endpoint has no numeric port.
        """
        # Do nothing if input is None or data_parallel_rank is 0
        if not endpoint or data_parallel_rank == 0:
            return endpoint

        # Match on the scheme prefix, not a substring, so that e.g.
        # "tcp://inproc-host:5557" is not mistaken for an inproc endpoint.
        if endpoint.startswith(("inproc://", "ipc://")):
            # These transports have no port: give each rank its own name.
            return f"{endpoint}_dp{data_parallel_rank}"
        if endpoint.startswith("tcp://"):
            # The port is whatever follows the last colon (IPv6 safe).
            base_addr, _, port = endpoint.rpartition(":")
            try:
                base_port = int(port)
            except ValueError:
                raise ValueError(
                    f"Invalid tcp endpoint (missing numeric port): {endpoint!r}"
                ) from None
            new_port = base_port + data_parallel_rank
            return f"{base_addr}:{new_port}"
        raise ValueError(
            "Invalid endpoint: must start with 'inproc://', 'ipc://' or 'tcp://'"
        )

484
485
486
487
488
489
490
491

class EventPublisherFactory:
    """Name-keyed registry that builds EventPublisher instances from config."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(
        cls, name: str, ctor: Callable[..., EventPublisher]
    ) -> None:
        """Register ``ctor`` under ``name``.

        Raises:
            KeyError: If ``name`` is already registered.
        """
        registry = cls._registry
        if name not in registry:
            registry[name] = ctor
            return
        raise KeyError(f"publisher '{name}' already registered")

    @classmethod
    def create(
        cls, config: KVEventsConfig | None, data_parallel_rank: int = 0
    ) -> EventPublisher:
        """Create publisher from a config mapping.

        Returns a NullEventPublisher when there is no config, KV cache events
        are disabled, or the "null" publisher is requested. Otherwise every
        remaining config field is passed as a keyword to the registered
        constructor.

        Raises:
            ValueError: If ``config.publisher`` names no registered publisher.
        """
        disabled = (
            config is None
            or not config.enable_kv_cache_events
            or config.publisher == "null"
        )
        if disabled:
            return NullEventPublisher()

        ctor_kwargs = asdict(config)
        kind = ctor_kwargs.pop("publisher")
        ctor_kwargs.pop("enable_kv_cache_events")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(data_parallel_rank=data_parallel_rank, **ctor_kwargs)