# kv_events_subscriber.py (3.63 KB)
1
# SPDX-License-Identifier: Apache-2.0
2
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
from typing import Any
4
5
6
7
8

import msgspec
import zmq
from msgspec.msgpack import Decoder

9
from vllm.v1.core.kv_cache_utils import ExternalBlockHash
10

11
12
13
14

#
# Types copied from vllm.distributed.kv_events
#
15
class EventBatch(msgspec.Struct, array_like=True, omit_defaults=True, gc=False):
    """A timestamped batch of events.

    ``array_like=True`` makes msgspec encode instances as positional arrays,
    so the field order below is part of the wire format and must stay
    identical to the definition in ``vllm.distributed.kv_events``.
    """

    ts: float
    events: list[Any]


20
21
22
class KVCacheEvent(
    msgspec.Struct, array_like=True, omit_defaults=True, gc=False, tag=True
):
    """Base class for all KV cache-related events.

    ``tag=True`` makes msgspec include a per-subclass type tag in the encoded
    form; that tag is what allows ``KVEventBatch.events`` to be decoded as a
    union of several event types.
    """


class BlockStored(KVCacheEvent):
    """KV cache event carrying the hashes and token ids of stored blocks.

    Encoded as a positional array (``array_like=True`` on the base class), so
    the field order is part of the wire format and must match the definition
    in ``vllm.distributed.kv_events`` exactly.
    """

    block_hashes: list[ExternalBlockHash]
    parent_block_hash: ExternalBlockHash | None
    token_ids: list[int]
    block_size: int

    lora_id: int | None
    """Deprecated: use `lora_name` for KV block key hash.
    Retained for backward compatibility.
    """

    medium: str | None
    lora_name: str | None

    extra_keys: list[tuple[Any, ...] | None] | None = None
    """Extra keys used in block hash computation, one entry per block in
    block_hashes. Each entry contains MM identifiers, LoRA name, cache_salt,
    prompt embeddings data, etc. for that specific block.
    """

46
47

class BlockRemoved(KVCacheEvent):
    """KV cache event listing the hashes of removed blocks.

    Field order is part of the positional wire format; keep it in sync with
    ``vllm.distributed.kv_events``.
    """

    block_hashes: list[ExternalBlockHash]
    medium: str | None
50
51
52
53
54
55
56


class AllBlocksCleared(KVCacheEvent):
    """KV cache event with no payload fields; only its type tag is encoded."""


class KVEventBatch(EventBatch):
    """Event batch whose ``events`` are narrowed to the KV cache event types.

    A ``Decoder(type=KVEventBatch)`` therefore yields typed ``BlockStored`` /
    ``BlockRemoved`` / ``AllBlocksCleared`` instances, distinguished by the
    type tag declared on ``KVCacheEvent``.
    """

    events: list[BlockStored | BlockRemoved | AllBlocksCleared]
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93


def process_event(event_batch):
    """Print *event_batch* to stdout.

    Writes one header line with the batch timestamp (``ts``), followed by
    one indented ``  - <event>`` line per entry in ``events``.
    """
    print(f"Received event batch at {event_batch.ts}:")
    bullet_lines = map("  - {}".format, event_batch.events)
    for line in bullet_lines:
        print(line)


def _drain_replay_replies(replay) -> None:
    """Discard replay replies still queued from an earlier, cut-short replay.

    A replay can stop before reading its end-of-replay marker (the gap was
    filled early, or a reply timed out). Without draining, those leftovers,
    including a stale empty-payload marker, would be read as the answer to
    the next request.
    """
    while replay.poll(0):
        replay.recv_multipart()


def _replay_missed(replay, decoder, last_seq, seq, timeout_ms=200) -> None:
    """Request and process the batches strictly between *last_seq* and *seq*.

    Best effort: replies are read until the end-of-replay marker arrives,
    the gap is filled, a reply at or past *seq* shows up, or nothing arrives
    within *timeout_ms*. Batch *seq* itself is never processed here; the
    caller handles it from the live stream, so it is not processed twice.
    """
    _drain_replay_replies(replay)
    # The empty first frame is the delimiter a REQ socket would add itself,
    # so the request is framed on the wire exactly like a REQ request.
    replay.send_multipart([b"", (last_seq + 1).to_bytes(8, "big")])

    while replay.poll(timeout_ms):
        # Unlike REQ, a DEALER socket keeps the empty delimiter frame.
        _, seq_bytes, replay_payload = replay.recv_multipart()
        if not replay_payload:
            # End of replay marker is sent as an empty frame
            # for the payload
            break

        replay_seq = int.from_bytes(seq_bytes, "big")
        if replay_seq >= seq:
            # Past the gap; these batches arrive on the live stream.
            break
        if replay_seq > last_seq:
            process_event(decoder.decode(replay_payload))
            last_seq = replay_seq
            if last_seq >= seq - 1:
                # Gap filled; the live batch `seq` is next.
                break


def main(
    sub_endpoint: str = "tcp://localhost:5557",
    replay_endpoint: str = "tcp://localhost:5558",
    topic: str = "kv-events",
):
    """Subscribe to KV cache event batches and print each one.

    Live messages arrive on a SUB socket as three frames: topic, an 8-byte
    big-endian sequence number, and the msgpack-encoded ``KVEventBatch``.
    When the sequence number skips ahead, the missing batches are requested
    from the replay endpoint and printed before the live batch. Runs until
    interrupted (Ctrl-C). All sockets and the ZMQ context are released on
    exit.

    Args:
        sub_endpoint: Address the SUB socket connects to.
        replay_endpoint: Address of the replay service.
        topic: Subscription prefix passed to ``zmq.SUBSCRIBE``.
    """
    decoder = Decoder(type=KVEventBatch)
    last_seq = -1

    context = zmq.Context()
    try:
        # Set up the main subscription socket
        sub = context.socket(zmq.SUB)
        sub.connect(sub_endpoint)
        sub.setsockopt_string(zmq.SUBSCRIBE, topic)

        # Replay socket. A DEALER is used instead of REQ because one replay
        # request is answered by a stream of messages ending in an
        # empty-payload marker. A REQ socket accepts only one reply per
        # request (a second recv raises), and it cannot send again if a reply
        # never arrives.
        replay = context.socket(zmq.DEALER)
        replay.connect(replay_endpoint)

        print("Listening for KV cache events on topic:", topic)

        while True:
            try:
                if sub.poll(50):
                    _, seq_bytes, payload = sub.recv_multipart()
                    seq = int.from_bytes(seq_bytes, "big")

                    # Record every received sequence number; without this,
                    # gap detection never triggers. Advancing before decoding
                    # means a bad payload or a failed replay is not
                    # re-requested on every later message.
                    prev_seq, last_seq = last_seq, seq
                    if prev_seq >= 0 and seq > prev_seq + 1:
                        missed = seq - prev_seq - 1
                        print(
                            f"Missed {missed} messages "
                            f"(last: {prev_seq}, current: {seq})"
                        )
                        _replay_missed(replay, decoder, prev_seq, seq)

                    event_batch = decoder.decode(payload)
                    process_event(event_batch)

                # ... do other periodic work or check for shutdown ...

            except KeyboardInterrupt:
                print("Interrupted")
                break
            except msgspec.DecodeError as e:
                print("Error decoding message:", e)
            except Exception as e:
                # Best-effort example loop: report and keep listening.
                print("Error processing message:", e)
    finally:
        # Close all sockets without lingering on unsent messages, then
        # terminate the context.
        context.destroy(linger=0)


if __name__ == "__main__":
    # Start the subscriber loop only when run as a script, not on import.
    main()