Unverified Commit 56d91ee9 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

docs: update KV event docs to use event plane terminology (#6326)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 87c5c7bf
...@@ -276,268 +276,11 @@ See [Router Design](../../design-docs/router-design.md) for architecture details ...@@ -276,268 +276,11 @@ See [Router Design](../../design-docs/router-design.md) for architecture details
## KV Event Publishing for Custom Engines ## KV Event Publishing for Custom Engines
The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can make optimal routing decisions. There are two main publishing pathways: direct NATS publishing (`KvEventPublisher`) which publishes events directly to NATS and is the simplest approach for custom engines, and ZMQ-based publishing for engines with ZMQ event output (like vLLM) which uses a ZMQ publisher in the engine and `ZmqKvEventPublisher` to forward events to NATS. For full documentation on implementing KV event publishing for custom inference engines, see the dedicated [KV Event Publishing for Custom Engines](../../integrations/kv-events-custom-engines.md) guide. It covers:
### Event Types - **Direct publishing**: Call `publish_stored()` / `publish_removed()` to push events over the Dynamo event plane
- **ZMQ relay**: For engines that emit raw KV events over ZMQ (like vLLM and SGLang), the same `KvEventPublisher` subscribes to the ZMQ socket and relays events automatically
The KV cache supports three event types: - API reference, event structure, ZMQ wire format, and best practices
| Event Type | Description | When to Publish |
|------------|-------------|-----------------|
| `BlockStored` | New blocks added to cache | After KV cache allocation succeeds |
| `BlockRemoved` | Blocks evicted from cache | When blocks are evicted or freed |
| `AllBlocksCleared` | All blocks removed | On cache reset or worker restart |
### Event Structure
Each event contains:
- **`event_id`**: Monotonically increasing identifier per worker
- **`dp_rank`**: Data parallel rank (0 if DP not enabled)
- **`data`**: One of `Stored`, `Removed`, or `Cleared`
For `BlockStored` events:
- **`token_ids`**: List of token IDs for the stored blocks
- **`block_hashes`**: List of **sequence block hashes** from the engine's block manager. These are cumulative hashes that incorporate all tokens from the start of the sequence up to and including the current block (not just the tokens within that block). This enables prefix matching across requests.
- **`num_block_tokens`**: Number of tokens per block (should all equal `kv_block_size`)
- **`parent_hash`**: Hash of the parent block. Required for all blocks except the first block in a sequence (which has no parent).
- **`lora_id`**: LoRA adapter ID (0 if not using LoRA)
For `BlockRemoved` events:
- **`block_hashes`**: List of sequence block hashes being evicted
### Option 1: Direct NATS Publishing (Recommended)
The `KvEventPublisher` class publishes events directly to NATS. This is the simplest approach for custom engines.
```mermaid
flowchart LR
subgraph Engine["Custom Engine"]
cache["KV Cache Manager"]
end
subgraph Worker["Dynamo Worker Process"]
pub["KvEventPublisher"]
end
subgraph NATS["NATS"]
subject["kv-events subject"]
end
subgraph Router["KV Router"]
indexer["KvIndexer"]
end
cache -->|"on_blocks_stored()<br/>on_blocks_removed()"| pub
pub -->|"publish to NATS"| subject
subject --> indexer
```
**When to use:**
- Building a custom inference engine from scratch
- Your engine doesn't have a ZMQ-based event system
- You want the simplest integration path
#### Basic Setup
```python
from dynamo.llm import KvEventPublisher
class CustomEnginePublisher:
def __init__(self, component, worker_id: int, block_size: int, dp_rank: int = 0):
self.block_size = block_size
self.event_id = 0
self.kv_publisher = KvEventPublisher(
component=component,
worker_id=worker_id,
kv_block_size=block_size,
dp_rank=dp_rank,
enable_local_indexer=False,
)
def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
lora_id: int = 0, parent_hash: int | None = None):
"""Call after KV cache blocks are allocated."""
self.event_id += 1
num_block_tokens = [self.block_size] * len(block_hashes)
self.kv_publisher.publish_stored(
event_id=self.event_id,
token_ids=token_ids,
num_block_tokens=num_block_tokens,
block_hashes=block_hashes,
lora_id=lora_id,
parent_hash=parent_hash,
)
def on_blocks_removed(self, block_hashes: list[int]):
"""Call when KV cache blocks are evicted."""
self.event_id += 1
self.kv_publisher.publish_removed(event_id=self.event_id, block_hashes=block_hashes)
```
#### Integration with Your Engine
```python
from dynamo.llm import register_model
async def main():
# Register your engine with Dynamo
component, endpoint = await register_model(
model="my-model",
generator=my_generate_fn,
)
# Initialize publisher
publisher = CustomEnginePublisher(
component=component,
worker_id=endpoint.connection_id(),
block_size=16, # Match your engine's block size
)
# Hook into your engine's cache events
def on_prefill_complete(request_id, token_ids, blocks):
block_hashes = [block.hash for block in blocks]
publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes)
def on_cache_eviction(evicted_blocks):
block_hashes = [block.hash for block in evicted_blocks]
publisher.on_blocks_removed(block_hashes=block_hashes)
```
### Option 2: ZMQ-based Publishing
For engines that publish events via ZMQ (like vLLM), this option uses two components that work together:
1. **ZMQ Publisher** (in your engine) - Publishes events to a ZMQ socket
2. **ZmqKvEventPublisher** (Dynamo binding) - Subscribes to ZMQ and forwards to NATS
```mermaid
flowchart LR
subgraph Engine["Custom Engine / vLLM"]
cache["KV Cache Manager"]
zmq_pub["ZMQ Publisher<br/>(Pure Python)"]
end
subgraph ZMQ["ZMQ Socket"]
socket["tcp://127.0.0.1:5557"]
end
subgraph Worker["Dynamo Worker Process"]
zmq_sub["ZmqKvEventPublisher<br/>(Rust bindings)"]
end
subgraph NATS["NATS"]
subject["kv-events subject"]
end
subgraph Router["KV Router"]
indexer["KvIndexer"]
end
cache --> zmq_pub
zmq_pub -->|"PUB"| socket
socket -->|"SUB"| zmq_sub
zmq_sub --> subject
subject --> indexer
```
**When to use:**
- Your engine already has a ZMQ-based event system (like vLLM)
- You're integrating with a consolidator (like KVBM)
- You want to decouple event publishing from your engine's main loop
#### Part 1: ZMQ Subscriber (Dynamo Bindings)
If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` to subscribe and forward to NATS:
```python
from dynamo.llm import KvEventPublisher
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = KvEventPublisher(
component=component,
kv_block_size=block_size,
zmq_endpoint="tcp://127.0.0.1:5557", # Where your engine publishes
zmq_topic="", # Subscribe to all topics
enable_local_indexer=False,
)
```
#### Part 2: ZMQ Publisher (Pure Python)
If your engine needs to publish to ZMQ (e.g., for consolidator integration), implement the ZMQ protocol:
```python
import zmq
import msgpack
import time
class ZmqKvEventPublisher:
"""Pure Python ZMQ publisher for KV events (vLLM-compatible format)."""
def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""):
self.kv_block_size = kv_block_size
self.topic = topic
self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.PUB)
self.socket.bind(zmq_endpoint)
self.sequence = 0
self.data_parallel_rank = 0
def _to_signed_i64(self, value: int | None) -> int | None:
if value is None:
return None
return value - 0x10000000000000000 if value > 0x7FFFFFFFFFFFFFFF else value
def publish_stored(self, event_id: int, token_ids: list[int], num_block_tokens: list[int],
block_hashes: list[int], lora_id: int = 0, parent_hash: int | None = None):
event = {
"type": "BlockStored",
"block_hashes": [self._to_signed_i64(h) for h in block_hashes],
"parent_block_hash": self._to_signed_i64(parent_hash),
"token_ids": token_ids,
"block_size": self.kv_block_size,
"lora_id": lora_id if lora_id != 0 else None,
}
self._publish_event(event)
def publish_removed(self, event_id: int, block_hashes: list[int]):
event = {"type": "BlockRemoved", "block_hashes": [self._to_signed_i64(h) for h in block_hashes]}
self._publish_event(event)
def publish_all_cleared(self):
self._publish_event({"type": "AllBlocksCleared"})
def _publish_event(self, event: dict):
batch = [time.time(), [event], self.data_parallel_rank]
payload = msgpack.packb(batch, use_bin_type=True)
sequence_bytes = self.sequence.to_bytes(8, byteorder="big")
self.sequence += 1
self.socket.send_multipart([self.topic.encode(), sequence_bytes, payload])
def shutdown(self):
self.socket.close()
self.ctx.term()
```
### ZMQ Wire Format
The ZMQ message format (compatible with vLLM):
| Frame | Description |
|-------|-------------|
| 1 | Topic (empty string for all topics) |
| 2 | Sequence number (8 bytes, big-endian) |
| 3 | Msgpack payload: `[timestamp, [events], dp_rank]` |
Each event in the payload is a dictionary with `type` field (`BlockStored`, `BlockRemoved`, or `AllBlocksCleared`).
### Best Practices
1. **Event IDs must be monotonically increasing** per worker (use a thread-safe counter)
2. **Block size must match** your engine's actual `kv_block_size`
3. **`parent_hash` is required** for all blocks except the first in a sequence - it links blocks to enable prefix matching
## Global Router (Hierarchical Routing) ## Global Router (Hierarchical Routing)
......
...@@ -250,7 +250,7 @@ flowchart TD ...@@ -250,7 +250,7 @@ flowchart TD
A["Distributed Inference Engine"] --> B["Dynamo KV Block Manager"] A["Distributed Inference Engine"] --> B["Dynamo KV Block Manager"]
B --> C["NIXL Storage Agent<br/>- Volume registration<br/>- get()/put() abstraction"] B --> C["NIXL Storage Agent<br/>- Volume registration<br/>- get()/put() abstraction"]
B --> D["Event Plane<br/>- NATS-based Pub/Sub<br/>- StoreEvent / RemoveEvent"] B --> D["Event Plane<br/>- Pub/Sub (NATS or ZMQ)<br/>- StoreEvent / RemoveEvent"]
C --> E["G4 Storage Infrastructure<br/>(SSD, Object store, etc.)<br/>- Store KV blocks"] C --> E["G4 Storage Infrastructure<br/>(SSD, Object store, etc.)<br/>- Store KV blocks"]
D --> F["Storage Provider Subscriber<br/>- Parse Events<br/>- Build fast tree/index<br/>- Optimize G4 tiering"] D --> F["Storage Provider Subscriber<br/>- Parse Events<br/>- Build fast tree/index<br/>- Optimize G4 tiering"]
...@@ -268,7 +268,7 @@ These abstractions allow backends to be integrated without tying into the host's ...@@ -268,7 +268,7 @@ These abstractions allow backends to be integrated without tying into the host's
#### Dynamo Event Plane (Pub/Sub Coordination Layer) #### Dynamo Event Plane (Pub/Sub Coordination Layer)
To support external storage optimizations without modifying KVBM logic, we provide an **event plane** built on NATS.io that emits lifecycle events for all block operations: To support external storage optimizations without modifying KVBM logic, we provide an **event plane** (supporting NATS and ZMQ transports) that emits lifecycle events for all block operations:
- **StoreEvent**: Emitted when a KV block is registered - **StoreEvent**: Emitted when a KV block is registered
- **RemoveEvent**: Emitted when a KV block is released or evicted - **RemoveEvent**: Emitted when a KV block is released or evicted
...@@ -295,7 +295,7 @@ External storage systems are not tightly coupled with Dynamo's execution pipelin ...@@ -295,7 +295,7 @@ External storage systems are not tightly coupled with Dynamo's execution pipelin
1. Storage volumes are pre-provisioned and mounted by the storage provider 1. Storage volumes are pre-provisioned and mounted by the storage provider
2. These volumes are registered with Dynamo through the NIXL Storage Agent using `registerVolume()` APIs 2. These volumes are registered with Dynamo through the NIXL Storage Agent using `registerVolume()` APIs
3. Dynamo KV Block Manager interacts only with logical block-level APIs (`get()` and `put()`) 3. Dynamo KV Block Manager interacts only with logical block-level APIs (`get()` and `put()`)
4. The Event Plane asynchronously broadcasts KV lifecycle events using a NATS-based pub/sub channel 4. The Event Plane asynchronously broadcasts KV lifecycle events via pub/sub (NATS or ZMQ)
5. Storage vendors implement a lightweight subscriber process that listens to these events 5. Storage vendors implement a lightweight subscriber process that listens to these events
To enable fast lookup and dynamic tiering, storage vendors may build internal data structures using the received event stream: To enable fast lookup and dynamic tiering, storage vendors may build internal data structures using the received event stream:
......
...@@ -119,7 +119,7 @@ The two types of events are: ...@@ -119,7 +119,7 @@ The two types of events are:
- KV stored event - KV stored event
- KV removed event - KV removed event
The publisher can be initialized and used through C bindings or Python bindings. The publisher can be initialized and used through Python bindings.
### Deterministic Event IDs ### Deterministic Event IDs
...@@ -224,7 +224,7 @@ graph TD ...@@ -224,7 +224,7 @@ graph TD
E3[Engine 3<br/>LocalKvIndexer] E3[Engine 3<br/>LocalKvIndexer]
end end
subgraph "NATS Core" subgraph "Event Plane (NATS / ZMQ)"
NC[KV Events Pub/Sub<br/>- Block created<br/>- Block removed] NC[KV Events Pub/Sub<br/>- Block created<br/>- Block removed]
end end
......
...@@ -12,10 +12,12 @@ This document explains how to implement KV event publishing for custom inference ...@@ -12,10 +12,12 @@ This document explains how to implement KV event publishing for custom inference
The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can make optimal routing decisions. The KV Router relies on real-time events from backend workers to track which KV cache blocks are stored on each worker. When your custom engine allocates or evicts KV cache blocks, it should publish these events so the router can make optimal routing decisions.
There are two main publishing pathways: Events are published over the **Dynamo event plane**, a transport-agnostic pub/sub layer that supports both NATS and ZMQ backends (see [Event Plane](../design-docs/event-plane.md) for details). The `KvEventPublisher` binding handles all transport concerns — your engine code does not interact with the event plane directly.
1. **Direct NATS publishing** (`KvEventPublisher`) - Publishes events directly to NATS. Simplest approach for custom engines. `KvEventPublisher` supports two publishing modes:
2. **ZMQ-based publishing** - For engines with ZMQ event output (like vLLM). Uses a ZMQ publisher in the engine and `ZmqKvEventPublisher` to forward events to NATS.
1. **Direct publishing** — Your engine calls `publish_stored()` / `publish_removed()` to push events directly over the event plane. Simplest approach for custom engines.
2. **ZMQ relay** — For engines that emit raw KV events over a ZMQ socket (like vLLM and SGLang). The publisher subscribes to the ZMQ endpoint and relays events to the event plane automatically.
## Event Types ## Event Types
...@@ -30,7 +32,7 @@ The KV cache supports three event types: ...@@ -30,7 +32,7 @@ The KV cache supports three event types:
### Event Structure ### Event Structure
Each event contains: Each event contains:
- **`event_id`**: Monotonically increasing identifier per worker - **`event_id`**: Monotonically increasing identifier per worker (managed internally by the publisher)
- **`dp_rank`**: Data parallel rank (0 if DP not enabled) - **`dp_rank`**: Data parallel rank (0 if DP not enabled)
- **`data`**: One of `Stored`, `Removed`, or `Cleared` - **`data`**: One of `Stored`, `Removed`, or `Cleared`
...@@ -44,9 +46,9 @@ For `BlockStored` events: ...@@ -44,9 +46,9 @@ For `BlockStored` events:
For `BlockRemoved` events: For `BlockRemoved` events:
- **`block_hashes`**: List of sequence block hashes being evicted - **`block_hashes`**: List of sequence block hashes being evicted
## Option 1: Direct NATS Publishing (Recommended) ## Direct Publishing (Recommended for Custom Engines)
The `KvEventPublisher` class publishes events directly to NATS. This is the simplest approach for custom engines. Call `publish_stored()` and `publish_removed()` directly from your engine code. The publisher handles event IDs, serialization, and transport.
```mermaid ```mermaid
flowchart LR flowchart LR
...@@ -58,17 +60,17 @@ flowchart LR ...@@ -58,17 +60,17 @@ flowchart LR
pub["KvEventPublisher"] pub["KvEventPublisher"]
end end
subgraph NATS["NATS"] subgraph EP["Dynamo Event Plane"]
subject["kv-events subject"] topic["kv-events topic"]
end end
subgraph Router["KV Router"] subgraph Router["KV Router"]
indexer["KvIndexer"] indexer["KvIndexer"]
end end
cache -->|"on_blocks_stored()<br/>on_blocks_removed()"| pub cache -->|"publish_stored()<br/>publish_removed()"| pub
pub -->|"publish to NATS"| subject pub -->|"event plane"| topic
subject --> indexer topic --> indexer
``` ```
**When to use:** **When to use:**
...@@ -82,12 +84,10 @@ flowchart LR ...@@ -82,12 +84,10 @@ flowchart LR
from dynamo.llm import KvEventPublisher from dynamo.llm import KvEventPublisher
class CustomEnginePublisher: class CustomEnginePublisher:
def __init__(self, component, worker_id: int, block_size: int, dp_rank: int = 0): def __init__(self, component, block_size: int, dp_rank: int = 0):
self.block_size = block_size self.block_size = block_size
self.event_id = 0
self.kv_publisher = KvEventPublisher( self.kv_publisher = KvEventPublisher(
component=component, component=component,
worker_id=worker_id,
kv_block_size=block_size, kv_block_size=block_size,
dp_rank=dp_rank, dp_rank=dp_rank,
) )
...@@ -95,10 +95,8 @@ class CustomEnginePublisher: ...@@ -95,10 +95,8 @@ class CustomEnginePublisher:
def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int], def on_blocks_stored(self, token_ids: list[int], block_hashes: list[int],
lora_id: int = 0, parent_hash: int | None = None): lora_id: int = 0, parent_hash: int | None = None):
"""Call after KV cache blocks are allocated.""" """Call after KV cache blocks are allocated."""
self.event_id += 1
num_block_tokens = [self.block_size] * len(block_hashes) num_block_tokens = [self.block_size] * len(block_hashes)
self.kv_publisher.publish_stored( self.kv_publisher.publish_stored(
event_id=self.event_id,
token_ids=token_ids, token_ids=token_ids,
num_block_tokens=num_block_tokens, num_block_tokens=num_block_tokens,
block_hashes=block_hashes, block_hashes=block_hashes,
...@@ -108,8 +106,7 @@ class CustomEnginePublisher: ...@@ -108,8 +106,7 @@ class CustomEnginePublisher:
def on_blocks_removed(self, block_hashes: list[int]): def on_blocks_removed(self, block_hashes: list[int]):
"""Call when KV cache blocks are evicted.""" """Call when KV cache blocks are evicted."""
self.event_id += 1 self.kv_publisher.publish_removed(block_hashes=block_hashes)
self.kv_publisher.publish_removed(event_id=self.event_id, block_hashes=block_hashes)
``` ```
### Integration with Your Engine ### Integration with Your Engine
...@@ -118,20 +115,16 @@ class CustomEnginePublisher: ...@@ -118,20 +115,16 @@ class CustomEnginePublisher:
from dynamo.llm import register_model from dynamo.llm import register_model
async def main(): async def main():
# Register your engine with Dynamo
component, endpoint = await register_model( component, endpoint = await register_model(
model="my-model", model="my-model",
generator=my_generate_fn, generator=my_generate_fn,
) )
# Initialize publisher
publisher = CustomEnginePublisher( publisher = CustomEnginePublisher(
component=component, component=component,
worker_id=endpoint.connection_id(),
block_size=16, # Match your engine's block size block_size=16, # Match your engine's block size
) )
# Hook into your engine's cache events
def on_prefill_complete(request_id, token_ids, blocks): def on_prefill_complete(request_id, token_ids, blocks):
block_hashes = [block.hash for block in blocks] block_hashes = [block.hash for block in blocks]
publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes) publisher.on_blocks_stored(token_ids=token_ids, block_hashes=block_hashes)
...@@ -141,18 +134,15 @@ async def main(): ...@@ -141,18 +134,15 @@ async def main():
publisher.on_blocks_removed(block_hashes=block_hashes) publisher.on_blocks_removed(block_hashes=block_hashes)
``` ```
## Option 2: ZMQ-based Publishing ## ZMQ Relay (For Engines with Raw KV Events)
For engines that publish events via ZMQ (like vLLM), this option uses two components that work together:
1. **ZMQ Publisher** (in your engine) - Publishes events to a ZMQ socket For engines that already publish raw KV events over a ZMQ socket (like vLLM and SGLang), use the same `KvEventPublisher` with a `zmq_endpoint`. The publisher subscribes to the ZMQ socket and relays events to the event plane automatically.
2. **ZmqKvEventPublisher** (Dynamo binding) - Subscribes to ZMQ and forwards to NATS
```mermaid ```mermaid
flowchart LR flowchart LR
subgraph Engine["Custom Engine / vLLM"] subgraph Engine["Custom Engine / vLLM / SGLang"]
cache["KV Cache Manager"] cache["KV Cache Manager"]
zmq_pub["ZMQ Publisher<br/>(Pure Python)"] zmq_pub["ZMQ Publisher"]
end end
subgraph ZMQ["ZMQ Socket"] subgraph ZMQ["ZMQ Socket"]
...@@ -160,11 +150,11 @@ flowchart LR ...@@ -160,11 +150,11 @@ flowchart LR
end end
subgraph Worker["Dynamo Worker Process"] subgraph Worker["Dynamo Worker Process"]
zmq_sub["ZmqKvEventPublisher<br/>(Rust bindings)"] relay["KvEventPublisher<br/>(relay mode)"]
end end
subgraph NATS["NATS"] subgraph EP["Dynamo Event Plane"]
subject["kv-events subject"] topic["kv-events topic"]
end end
subgraph Router["KV Router"] subgraph Router["KV Router"]
...@@ -173,24 +163,22 @@ flowchart LR ...@@ -173,24 +163,22 @@ flowchart LR
cache --> zmq_pub cache --> zmq_pub
zmq_pub -->|"PUB"| socket zmq_pub -->|"PUB"| socket
socket -->|"SUB"| zmq_sub socket -->|"SUB"| relay
zmq_sub --> subject relay -->|"event plane"| topic
subject --> indexer topic --> indexer
``` ```
**When to use:** **When to use:**
- Your engine already has a ZMQ-based event system (like vLLM) - Your engine already publishes KV events via ZMQ (like vLLM or SGLang)
- You're integrating with a consolidator (like KVBM)
- You want to decouple event publishing from your engine's main loop - You want to decouple event publishing from your engine's main loop
### Part 1: ZMQ Subscriber (Dynamo Bindings) ### Setup
If your engine already publishes to ZMQ, use `KvEventPublisher` with `zmq_endpoint` (and optional `zmq_topic`) to subscribe and forward to NATS: Pass `zmq_endpoint` (and optional `zmq_topic`) to the same `KvEventPublisher`:
```python ```python
from dynamo.llm import KvEventPublisher from dynamo.llm import KvEventPublisher
# Create publisher - it automatically subscribes to ZMQ and forwards to NATS
kv_publisher = KvEventPublisher( kv_publisher = KvEventPublisher(
component=component, component=component,
kv_block_size=block_size, kv_block_size=block_size,
...@@ -199,66 +187,11 @@ kv_publisher = KvEventPublisher( ...@@ -199,66 +187,11 @@ kv_publisher = KvEventPublisher(
) )
``` ```
### Part 2: ZMQ Publisher (Pure Python) No further calls to `publish_stored()` / `publish_removed()` are needed — the publisher reads events from the ZMQ socket and forwards them automatically.
If your engine needs to publish to ZMQ (e.g., for consolidator integration), implement the ZMQ protocol:
```python
import zmq
import msgpack
import time
class ZmqKvEventPublisher:
"""Pure Python ZMQ publisher for KV events (vLLM-compatible format)."""
def __init__(self, zmq_endpoint: str, kv_block_size: int, topic: str = ""):
self.kv_block_size = kv_block_size
self.topic = topic
self.ctx = zmq.Context()
self.socket = self.ctx.socket(zmq.PUB)
self.socket.bind(zmq_endpoint)
self.sequence = 0
self.data_parallel_rank = 0
def _to_signed_i64(self, value: int | None) -> int | None:
if value is None:
return None
return value - 0x10000000000000000 if value > 0x7FFFFFFFFFFFFFFF else value
def publish_stored(self, event_id: int, token_ids: list[int], num_block_tokens: list[int],
block_hashes: list[int], lora_id: int = 0, parent_hash: int | None = None):
event = {
"type": "BlockStored",
"block_hashes": [self._to_signed_i64(h) for h in block_hashes],
"parent_block_hash": self._to_signed_i64(parent_hash),
"token_ids": token_ids,
"block_size": self.kv_block_size,
"lora_id": lora_id if lora_id != 0 else None,
}
self._publish_event(event)
def publish_removed(self, event_id: int, block_hashes: list[int]):
event = {"type": "BlockRemoved", "block_hashes": [self._to_signed_i64(h) for h in block_hashes]}
self._publish_event(event)
def publish_all_cleared(self):
self._publish_event({"type": "AllBlocksCleared"})
def _publish_event(self, event: dict):
batch = [time.time(), [event], self.data_parallel_rank]
payload = msgpack.packb(batch, use_bin_type=True)
sequence_bytes = self.sequence.to_bytes(8, byteorder="big")
self.sequence += 1
self.socket.send_multipart([self.topic.encode(), sequence_bytes, payload])
def shutdown(self):
self.socket.close()
self.ctx.term()
```
### ZMQ Wire Format ### ZMQ Wire Format
The ZMQ message format (compatible with vLLM): The ZMQ message format (compatible with vLLM / SGLang):
| Frame | Description | | Frame | Description |
|-------|-------------| |-------|-------------|
...@@ -266,18 +199,99 @@ The ZMQ message format (compatible with vLLM): ...@@ -266,18 +199,99 @@ The ZMQ message format (compatible with vLLM):
| 2 | Sequence number (8 bytes, big-endian) | | 2 | Sequence number (8 bytes, big-endian) |
| 3 | Msgpack payload: `[timestamp, [events], dp_rank]` | | 3 | Msgpack payload: `[timestamp, [events], dp_rank]` |
Each event in the payload is a dictionary with `type` field (`BlockStored`, `BlockRemoved`, or `AllBlocksCleared`). Each event in the payload is a dictionary with a `type` field (`BlockStored`, `BlockRemoved`, or `AllBlocksCleared`).
For `BlockStored`:
```python
{
"type": "BlockStored",
"block_hashes": [signed_i64, ...], # Sequence block hashes
"parent_block_hash": signed_i64 | None, # Parent hash
"token_ids": [int, ...], # Token IDs
"block_size": int, # Tokens per block
"lora_id": int | None, # LoRA adapter ID
}
```
For `BlockRemoved`:
```python
{
"type": "BlockRemoved",
"block_hashes": [signed_i64, ...],
}
```
For `AllBlocksCleared`:
```python
{"type": "AllBlocksCleared"}
```
## API Reference
### `KvEventPublisher`
```python
KvEventPublisher(
component: Component,
kv_block_size: int,
dp_rank: int = 0,
enable_local_indexer: bool = False,
zmq_endpoint: str | None = None, # Set for relay mode
zmq_topic: str | None = None, # Defaults to "" when zmq_endpoint is set
)
```
| Parameter | Description |
|-----------|-------------|
| `component` | The Dynamo component this publisher belongs to |
| `kv_block_size` | Number of tokens per block (must be > 0, must match your engine) |
| `dp_rank` | Data parallel rank (defaults to 0) |
| `enable_local_indexer` | Enable a worker-local KV indexer for direct overlap queries |
| `zmq_endpoint` | ZMQ endpoint to subscribe to for relay mode (e.g. `"tcp://127.0.0.1:5557"`) |
| `zmq_topic` | ZMQ topic filter (defaults to `""` = all topics) |
#### `publish_stored()`
```python
publish_stored(
token_ids: list[int],
num_block_tokens: list[int],
block_hashes: list[int],
lora_id: int,
parent_hash: int | None = None,
)
```
Publish a block-stored event. Event IDs are managed internally.
#### `publish_removed()`
```python
publish_removed(block_hashes: list[int])
```
Publish a block-removed event. Event IDs are managed internally.
#### `shutdown()`
```python
shutdown()
```
Stop background tasks (ZMQ listener, event forwarding).
## Best Practices ## Best Practices
1. **Event IDs must be monotonically increasing** per worker (use a thread-safe counter) 1. **`kv_block_size` must match** your engine's actual block size.
2. **`parent_hash` is required** for all blocks except the first in a sequence — it links blocks to enable prefix matching.
2. **Block size must match** your engine's actual `kv_block_size` 3. **Block hashes are signed 64-bit integers** in the Python API. The publisher handles conversion internally.
3. **`parent_hash` is required** for all blocks except the first in a sequence - it links blocks to enable prefix matching 4. **Event ordering is automatic** — the publisher assigns monotonically increasing event IDs. You do not need to track event IDs yourself.
## See Also ## See Also
- **[Router README](../components/router/README.md)**: Quick start guide for the KV Router - **[Event Plane](../design-docs/event-plane.md)**: Transport options (NATS, ZMQ) and configuration
- **[Router Guide](../components/router/router-guide.md)**: Configuration, tuning, and production setup - **[Router Guide](../components/router/router-guide.md)**: Configuration, tuning, and production setup
- **[Router Design](../design-docs/router-design.md)**: Architecture details and event transport modes - **[Router Design](../design-docs/router-design.md)**: Architecture details and event transport modes
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment