---
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
title: KV Events for Custom Engines
---

This document explains how to implement KV event publishing for custom inference engines, enabling them to participate in Dynamo's KV cache-aware routing.

## Overview

The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can account for each worker's cached prefixes when deciding where to send a request.

Events are published over the **Dynamo event plane**, a transport-agnostic pub/sub layer that supports both NATS and ZMQ backends (see [Event Plane](../design-docs/event-plane.md) for details). The `KvEventPublisher` binding handles all transport concerns — your engine code does not interact with the event plane directly.

`KvEventPublisher` supports two publishing modes:

1. **Direct publishing** — Your engine calls `publish_stored()` / `publish_removed()` to push events directly over the event plane. Simplest approach for custom engines.
2. **ZMQ relay** — For engines that emit raw KV events over a ZMQ socket (like SGLang and vLLM). The publisher subscribes to the ZMQ endpoint and relays events to the event plane automatically.

## Event Types

The KV cache supports three event types:

| Event Type | Description | When to Publish |
|------------|-------------|-----------------|
| `BlockStored` | New blocks added to cache | After KV cache allocation succeeds |
| `BlockRemoved` | Blocks evicted from cache | When blocks are evicted or freed |
| `AllBlocksCleared` | All blocks removed | On cache reset or worker restart |
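
To make the three event types concrete, the sketch below shows a toy, in-memory model of the per-worker state that a router-side index could maintain. It is a simplification for illustration only and is not the `KvIndexer` implementation. The class and method names are hypothetical.

```python
from collections import defaultdict


class ToyKvIndex:
    """Hypothetical router-side view: which sequence block hashes each worker holds."""

    def __init__(self) -> None:
        self.blocks: dict[int, set[int]] = defaultdict(set)

    def on_stored(self, worker_id: int, block_hashes: list[int]) -> None:
        # BlockStored: the worker now holds these blocks.
        self.blocks[worker_id].update(block_hashes)

    def on_removed(self, worker_id: int, block_hashes: list[int]) -> None:
        # BlockRemoved: the worker evicted these blocks.
        self.blocks[worker_id].difference_update(block_hashes)

    def on_cleared(self, worker_id: int) -> None:
        # AllBlocksCleared: forget everything about this worker.
        self.blocks[worker_id].clear()

    def overlap(self, worker_id: int, request_hashes: list[int]) -> int:
        """Count leading request blocks already cached on the worker.

        Each hash is cumulative over the whole prefix, so a matching hash
        implies that every token up to and including that block matches.
        """
        cached = self.blocks[worker_id]
        count = 0
        for h in request_hashes:
            if h not in cached:
                break
            count += 1
        return count
```

`overlap` stops at the first miss. A later block is only reusable if every block before it is cached as well.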

### Event Structure

Each event contains:
- **`event_id`**: Monotonically increasing identifier per worker (managed internally by the publisher)
- **`dp_rank`**: Data parallel rank (0 if DP not enabled)
- **`data`**: One of `Stored`, `Removed`, or `Cleared`

For `BlockStored` events:
- **`token_ids`**: List of token IDs for the stored blocks
- **`block_hashes`**: List of **sequence block hashes** from the engine's block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
- **`num_block_tokens`**: Number of tokens per block (should all equal `kv_block_size`)
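- **`parent_hash`**: Hash of the block immediately preceding the first block in `block_hashes`. Required whenever the stored blocks continue an existing sequence. Omit it (`None`) only when they start the sequence, since the first block has no parent.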
- **`lora_name`**: LoRA adapter name string (omit or `None` for base model). When set, the adapter name is incorporated into block hash computation so that blocks for different LoRA adapters (or the base model) are never conflated.

For `BlockRemoved` events:
- **`block_hashes`**: List of sequence block hashes being evicted

## Direct Publishing (Recommended for Custom Engines)

Call `publish_stored()` and `publish_removed()` directly from your engine code. The publisher handles event IDs, serialization, and transport.

```mermaid
flowchart LR
    subgraph Engine["Custom Engine"]
        cache["KV Cache Manager"]
    end

    subgraph Worker["Dynamo Worker Process"]
        pub["KvEventPublisher"]
    end

    subgraph EP["Dynamo Event Plane"]
        topic["kv-events topic"]
    end

    subgraph Router["KV Router"]
        indexer["KvIndexer"]
    end

    cache -->|"publish_stored()<br/>publish_removed()"| pub
    pub -->|"event plane"| topic
    topic --> indexer
```

**When to use:**
- Building a custom inference engine from scratch
- Your engine doesn't have a ZMQ-based event system
- You want the simplest integration path

### Basic Setup

```python
from dynamo.llm import KvEventPublisher

class CustomEnginePublisher:
    def __init__(self, component, block_size: int, dp_rank: int = 0):
        self.block_size = block_size
        self.kv_publisher = KvEventPublisher(
            component=component,
            kv_block_size=block_size,
            dp_rank=dp_rank,
        )

    def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
                         parent_hash: int | None = None,
                         lora_name: str | None = None):
        """Call after KV cache blocks are allocated."""
        num_block_tokens = [self.block_size] * len(block_hashes)
        self.kv_publisher.publish_stored(
            token_ids=token_ids,
            num_block_tokens=num_block_tokens,
            block_hashes=block_hashes,
            parent_hash=parent_hash,
            lora_name=lora_name,
        )

    def on_blocks_removed(self, block_hashes: list[int]):
        """Call when KV cache blocks are evicted."""
        self.kv_publisher.publish_removed(block_hashes=block_hashes)
```

### Integration with Your Engine

```python
from dynamo.llm import register_model

async def main():
    component, endpoint = await register_model(
        model="my-model",
        generator=my_generate_fn,  # your engine's request handler
    )

    publisher = CustomEnginePublisher(
        component=component,
        block_size=16,  # Match your engine's block size
    )

    # Wire these callbacks into your engine's KV cache manager.
    # `blocks` / `evicted_blocks` stand in for your engine's block objects.
    def on_prefill_complete(request_id, token_ids, blocks, parent_hash=None):
        # parent_hash: hash of the cached block just before blocks[0],
        # or None if blocks[0] is the first block of the sequence.
        block_hashes = [block.hash for block in blocks]
        publisher.on_blocks_stored(
            token_ids=token_ids,
            block_hashes=block_hashes,
            parent_hash=parent_hash,
        )

    def on_cache_eviction(evicted_blocks):
        block_hashes = [block.hash for block in evicted_blocks]
        publisher.on_blocks_removed(block_hashes=block_hashes)
```

## ZMQ Relay (For Engines with Raw KV Events)

For engines that already publish raw KV events over a ZMQ socket (like SGLang and vLLM), use the same `KvEventPublisher` with a `zmq_endpoint`. The publisher subscribes to the ZMQ socket and relays events to the event plane automatically.

```mermaid
flowchart LR
    subgraph Engine["Custom Engine / SGLang / vLLM"]
        cache["KV Cache Manager"]
        zmq_pub["ZMQ Publisher"]
    end

    subgraph ZMQ["ZMQ Socket"]
        socket["tcp://127.0.0.1:5557"]
    end

    subgraph Worker["Dynamo Worker Process"]
        relay["KvEventPublisher<br/>(relay mode)"]
    end

    subgraph EP["Dynamo Event Plane"]
        topic["kv-events topic"]
    end

    subgraph Router["KV Router"]
        indexer["KvIndexer"]
    end

    cache --> zmq_pub
    zmq_pub -->|"PUB"| socket
    socket -->|"SUB"| relay
    relay -->|"event plane"| topic
    topic --> indexer
```

**When to use:**
- Your engine already publishes KV events via ZMQ (like SGLang or vLLM)
- You want to decouple event publishing from your engine's main loop

### Setup

Pass `zmq_endpoint` (and optionally `zmq_topic`) to the same `KvEventPublisher`:

```python
from dynamo.llm import KvEventPublisher

kv_publisher = KvEventPublisher(
    component=component,
    kv_block_size=block_size,
    zmq_endpoint="tcp://127.0.0.1:5557",  # Where your engine publishes
    zmq_topic="",                          # Subscribe to all topics
)
```

No further calls to `publish_stored()` / `publish_removed()` are needed — the publisher reads events from the ZMQ socket and forwards them automatically.

### ZMQ Wire Format

The ZMQ message format (compatible with SGLang / vLLM):

| Frame | Description |
|-------|-------------|
| 1 | Topic (empty string for all topics) |
| 2 | Sequence number (8 bytes, big-endian) |
| 3 | Msgpack payload: `[timestamp, [events], dp_rank]` |
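
When debugging an engine's output, the two header frames can be checked with the standard library alone. The hypothetical helper below splits one multipart message into topic, sequence number, and raw payload bytes. You would then decode the payload with a msgpack library to obtain `[timestamp, [events], dp_rank]`.

```python
def split_kv_message(frames: list[bytes]) -> tuple[str, int, bytes]:
    """Split a 3-frame KV event message into (topic, sequence number, payload)."""
    if len(frames) != 3:
        raise ValueError(f"expected 3 frames, got {len(frames)}")
    topic, seq_frame, payload = frames
    if len(seq_frame) != 8:
        raise ValueError(f"sequence frame must be 8 bytes, got {len(seq_frame)}")
    # Frame 2 is an 8-byte big-endian sequence number.
    return topic.decode("utf-8"), int.from_bytes(seq_frame, "big"), payload
```

With pyzmq, a SUB socket's `recv_multipart()` returns the list of frames this helper expects.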

Each event in the payload is a dictionary with a `type` field (`BlockStored`, `BlockRemoved`, or `AllBlocksCleared`).

For `BlockStored`:
```python
{
    "type": "BlockStored",
    "block_hashes": [signed_i64, ...],      # Sequence block hashes
    "parent_block_hash": signed_i64 | None,  # Parent hash
    "token_ids": [int, ...],                 # Token IDs
    "block_size": int,                       # Tokens per block
    "lora_name": str | None,                 # LoRA adapter name
}
```

For `BlockRemoved`:
```python
{
    "type": "BlockRemoved",
    "block_hashes": [signed_i64, ...],
}
```

For `AllBlocksCleared`:
```python
{"type": "AllBlocksCleared"}
```
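
A custom engine that prefers the relay path has to emit messages in this format itself. The sketch below is one way to do that with `pyzmq` and `msgpack`; both are assumed to be installed, and neither is part of Dynamo. It does three things:

- binds a PUB socket on the endpoint used in the setup example,
- builds event dictionaries matching the shapes above,
- sends each batch as the three frames from the table.

Several details are assumptions rather than requirements stated here: the class and function names, using `time.time()` as the timestamp, starting the sequence at 0, and binding rather than connecting.

```python
import time


def block_stored(block_hashes, parent_hash, token_ids, block_size, lora_name=None):
    return {
        "type": "BlockStored",
        "block_hashes": block_hashes,        # signed 64-bit sequence hashes
        "parent_block_hash": parent_hash,    # None if the blocks start the sequence
        "token_ids": token_ids,
        "block_size": block_size,
        "lora_name": lora_name,
    }


def block_removed(block_hashes):
    return {"type": "BlockRemoved", "block_hashes": block_hashes}


def all_blocks_cleared():
    return {"type": "AllBlocksCleared"}


class RawKvEventEmitter:
    """Hypothetical engine-side ZMQ publisher for the relay wire format."""

    def __init__(self, endpoint="tcp://127.0.0.1:5557", topic="", dp_rank=0):
        import msgpack  # third-party: pip install msgpack
        import zmq      # third-party: pip install pyzmq

        self._packb = msgpack.packb
        self._sock = zmq.Context.instance().socket(zmq.PUB)
        self._sock.bind(endpoint)
        self._topic = topic.encode("utf-8")
        self._dp_rank = dp_rank
        self._seq = 0

    def send(self, events):
        # Frame 3 payload: [timestamp, [events], dp_rank]
        payload = self._packb([time.time(), events, self._dp_rank])
        self._sock.send_multipart(
            [self._topic, self._seq.to_bytes(8, "big"), payload]
        )
        self._seq += 1

    def close(self):
        self._sock.close(linger=0)
```

After each allocation, the engine would call something like `emitter.send([block_stored(hashes, None, tokens, 16)])`, with a `KvEventPublisher(..., zmq_endpoint="tcp://127.0.0.1:5557")` on the Dynamo side. ZMQ PUB sockets silently drop messages sent before a subscriber has connected, so events emitted before the relay attaches are lost.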

## API Reference

### `KvEventPublisher`

```python
KvEventPublisher(
    component: Component,
    kv_block_size: int,
    dp_rank: int = 0,
    enable_local_indexer: bool = False,
    zmq_endpoint: str | None = None,   # Set for relay mode
    zmq_topic: str | None = None,      # Defaults to "" when zmq_endpoint is set
)
```

| Parameter | Description |
|-----------|-------------|
| `component` | The Dynamo component this publisher belongs to |
| `kv_block_size` | Number of tokens per block (must be > 0, must match your engine) |
| `dp_rank` | Data parallel rank (defaults to 0) |
| `enable_local_indexer` | Enable a worker-local KV indexer for direct overlap queries |
| `zmq_endpoint` | ZMQ endpoint to subscribe to for relay mode (e.g. `"tcp://127.0.0.1:5557"`) |
| `zmq_topic` | ZMQ topic filter (defaults to `""` = all topics) |
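
For an engine running with data parallelism, one option is to create one publisher per DP rank, each tagged with its own `dp_rank`. The sketch below assumes a hypothetical endpoint layout in which rank `r` publishes raw events on `base_port + r`; adapt it to however your engine actually lays out its endpoints. The `factory` argument is also hypothetical and exists only so the wiring can be exercised without a running Dynamo component.

```python
def make_dp_publishers(component, kv_block_size: int, dp_size: int,
                       base_port: int = 5557, host: str = "127.0.0.1",
                       factory=None):
    """Create one relay-mode publisher per data-parallel rank (hypothetical layout)."""
    if factory is None:
        from dynamo.llm import KvEventPublisher  # imported lazily
        factory = KvEventPublisher
    return [
        factory(
            component=component,
            kv_block_size=kv_block_size,
            dp_rank=rank,
            zmq_endpoint=f"tcp://{host}:{base_port + rank}",
        )
        for rank in range(dp_size)
    ]
```

Call `shutdown()` on each publisher when the worker exits.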

#### `publish_stored()`

```python
publish_stored(
    token_ids: list[int],
    num_block_tokens: list[int],
    block_hashes: list[int],
    parent_hash: int | None = None,
    block_mm_infos: list[dict | None] | None = None,
    lora_name: str | None = None,
)
```

Publish a block-stored event. Event IDs are managed internally. When `lora_name` is provided, the adapter name is mixed into block hash computation so blocks cached under different adapters produce distinct hashes.

#### `publish_removed()`

```python
publish_removed(block_hashes: list[int])
```

Publish a block-removed event. Event IDs are managed internally.

#### `shutdown()`

```python
shutdown()
```

Stop background tasks (ZMQ listener, event forwarding).

## Best Practices

1. **`kv_block_size` must match** your engine's actual block size.

2. **`parent_hash` is required** unless the stored blocks start the sequence. Pass the hash of the block immediately preceding the first stored block; this link is what enables prefix matching.

3. **Block hashes are signed 64-bit integers** in the Python API. The publisher handles conversion internally.

4. **Event ordering is automatic** — the publisher assigns monotonically increasing event IDs. You do not need to track event IDs yourself.
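
If your engine's block manager produces unsigned 64-bit hashes, reinterpret them into the signed range before passing them to the Python API. A two's-complement round trip needs only integer arithmetic. The function names below are illustrative.

```python
_U64 = 1 << 64
_I64_MIN = -(1 << 63)
_I64_MAX = (1 << 63) - 1


def u64_to_i64(value: int) -> int:
    """Reinterpret an unsigned 64-bit hash as a signed 64-bit integer."""
    if not 0 <= value < _U64:
        raise ValueError(f"not an unsigned 64-bit value: {value}")
    return value - _U64 if value > _I64_MAX else value


def i64_to_u64(value: int) -> int:
    """Inverse of u64_to_i64."""
    if not _I64_MIN <= value <= _I64_MAX:
        raise ValueError(f"not a signed 64-bit value: {value}")
    return value + _U64 if value < 0 else value
```

Values below 2^63 pass through unchanged, and values at or above it map to negative numbers, so the mapping is a bijection between the two ranges.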

## See Also

- **[Event Plane](../design-docs/event-plane.md)**: Transport options (NATS, ZMQ) and configuration
- **[Router Guide](../components/router/router-guide.md)**: Configuration, tuning, and production setup
- **[Router Design](../design-docs/router-design.md)**: Architecture details and event transport modes