Unverified Commit 481dc636 authored by zhongdaor-nv's avatar zhongdaor-nv Committed by GitHub
Browse files

feat: add multimodal support to KV router with standalone trtllm example (#4577)


Signed-off-by: default avatarzhongdaor <zhongdaor@nvidia.com>
Signed-off-by: default avatarzhongdaor-nv <zhongdaor@nvidia.com>
parent f17fcb15
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Router Standalone - TensorRT-LLM
A standalone implementation of KvRouter that demonstrates usage with TensorRT-LLM workers, without dependency on the dynamo runtime, etcd control plane, or nats event plane.
## Overview
This example shows how to use KvRouter with TensorRT-LLM workers to intelligently route requests across multiple GPUs based on KV cache overlap and load metrics. The router maintains a view of each worker's cached blocks and routes new requests to the worker with the best combination of cache overlap and available capacity.
Key features:
- **KV cache-aware routing**: Routes requests to workers with matching cached blocks
- **Multimodal support**: Handles vision-language models (e.g., Qwen2-VL) with image inputs
- **MM hash routing**: Identical images produce identical hashes for cache reuse
## How It Works
### Core Architecture
The router uses a **RadixTree** data structure (written in Rust) to efficiently track which blocks each worker has cached. When a new request arrives, the router:
1. Tokenizes the request and computes block hashes (including MM hashes for images)
2. Uses `find_matches` to calculate overlap scores between the request and each worker's cached blocks
3. Combines this with current load metrics to select the optimal worker
4. Routes the request to the chosen worker for processing
### Multimodal Routing
For vision-language models:
1. Images are processed using `default_multimodal_input_loader` from TensorRT-LLM
2. Image placeholders are expanded to visual tokens using HuggingFace `AutoProcessor`
3. `apply_mm_hashes` computes a content hash for each image
4. The MM hash is included in block hash computation, so identical images produce cache hits
### Event-Driven Updates
The router receives two types of events from TensorRT-LLM engines:
1. **KV Events**: Emitted automatically when blocks are stored/removed from cache (includes `mm_keys` for multimodal)
2. **Load Metrics**: GPU cache usage and waiting request count
## Components
### `worker.py`
- **TrtllmWorkers**: Manages multiple TensorRT-LLM worker processes
- Each worker runs on a separate GPU with KV cache event emission enabled
- Publishes metrics and KV events over ZMQ
- Extracts `mm_hash` from TRTLLM's `mm_keys` field for multimodal routing
### `router.py`
- **KvRouter**: Core routing logic using RadixTree
- Subscribes to KV cache events and load metrics from workers
- Implements `get_best_worker()` to select optimal routing destination
### `api.py`
- **ServiceAPI**: FastAPI server providing OpenAI-compatible chat completions endpoint
- Handles multimodal inputs (images) via `default_multimodal_input_loader`
- Computes block hashes including MM hashes for routing decisions
- Streams responses in OpenAI format
### `test_router.py`
- Comprehensive test suite for router functionality
- Includes local hash computation tests and server-side multimodal tests
- Run with `--mm-only` for multimodal-specific tests
## Requirements
- **TensorRT-LLM >= 1.2.0rc6**: You need TensorRT-LLM version 1.2.0rc6 or later, which includes multimodal information (`mm_keys`) in KV cache events. This is required for MM hash-based routing. See [PR #9604](https://github.com/NVIDIA/TensorRT-LLM/pull/9604) for details.
- TensorRT-LLM with pytorch backend
- Multiple GPUs (one per worker)
- Python 3.10+
- Required packages: fastapi, uvicorn, httpx, zmq, tensorrt_llm, transformers
## Usage
### 1. Start the API Server
```bash
python api.py \
--model Qwen/Qwen2-VL-2B-Instruct \
--num-workers 2 \
--block-size 32 \
--base-kv-events-port 5557 \
--base-metrics-port 5657 \
--router-port 7000 \
--http-port 8000
```
This will:
- Initialize TensorRT-LLM engines on each GPU
- Start ZMQ publishers for metrics and KV events
- Start the router service
- Start the OpenAI-compatible API server
### 2. Test with curl
**Text-only request:**
```bash
curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_tokens": 100,
"stream": false
}' | jq
```
**Multimodal request (with images):**
```bash
curl -s -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "Describe both images in detail."},
{"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg"}},
{"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/test2017/000000000001.jpg"}}
]
}],
"max_tokens": 500,
"stream": false
}' | jq
```
### 3. Run Tests
```bash
# Run all tests
python test_router.py
# Run multimodal tests only
python test_router.py --mm-only
# Verbose output
python test_router.py -v
```
### 4. Check endpoint health
```bash
./ping.sh
```
## Configuration
### Command-line Arguments
- `--model`: HuggingFace model name (default: Qwen/Qwen2-VL-2B-Instruct)
- `--num-workers`: Number of GPU workers (default: 2)
- `--block-size`: KV cache block size (default: 32, TensorRT-LLM's default)
- `--base-kv-events-port`: Base port for KV events ZMQ (default: 5557)
- `--base-metrics-port`: Base port for metrics ZMQ (default: 5657)
- `--router-port`: Router HTTP service port (default: 7000)
- `--http-port`: API server port (default: 8000)
### Environment Variables
- `DYNAMO_DEBUG=1`: Enable debug file dumps to `/tmp/debug_*.txt`
- `LOGLEVEL=DEBUG`: Set logging level (DEBUG, INFO, WARNING, ERROR)
- `TRANSFORMERS_ATTN_IMPLEMENTATION=eager`: Disable FlashAttention (set automatically)
### Port Assignment
Workers use sequential ports:
- Worker 0: KV events on 5557, metrics on 5657
- Worker 1: KV events on 5558, metrics on 5658
- Worker N: KV events on 5557+N, metrics on 5657+N
## Architecture Diagram
```
┌─────────────┐
│ Client │
└──────┬──────┘
│ HTTP
┌─────────────────┐
│ API Server │
│ (api.py) │
└────────┬────────┘
│ HTTP
┌─────────────────┐
│ Router │──┐
│ (router.py) │ │ ZMQ (KV Events)
└────────┬────────┘ │
│ │
│ Select │
│ Worker │
▼ │
┌─────────────────┐ │
│ TrtllmWorkers │ │
│ (worker.py) │◄-┘
└─────────────────┘
│ │
▼ ▼
GPU 0 GPU 1
```
## Multimodal KV Cache Routing
When processing multimodal requests:
1. **API Layer** (`api.py`):
- Parses OpenAI-format messages with `image_url` content
- Uses `default_multimodal_input_loader` to process images
- Expands image placeholders to visual tokens via `AutoProcessor`
- Computes `mm_hash` using `apply_mm_hashes`
- Includes `mm_hash` in block hash computation for routing
2. **Worker Layer** (`worker.py`):
- Receives multimodal input and passes to TRTLLM
- Extracts `mm_hash` from TRTLLM's `mm_keys` in KV events
- Publishes KV events with `mm_extra_info` to router
3. **Router Layer** (`router.py`):
- RadixTree matches blocks including MM hash
- Same image content = same hash = cache hit on same worker
## Notes
- This is a standalone implementation for pedagogical purposes
- Production dynamo uses NATS for events and etcd for service discovery
- Each worker needs its own GPU
- TensorRT-LLM models may take time to compile on first run
## See Also
- [vLLM Router Standalone](../router_standalone/) - Original vLLM version
- [TensorRT-LLM KV Event Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/examples/llm_inference_kv_events.html)
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
This diff is collapsed.
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Simple health check - sends a basic chat request
# Model name should match what you started api.py with
curl -s -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": false,
"max_tokens": 50
}' | jq
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
import numpy as np
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError
from dynamo._core import RadixTree, ZmqKvEventListener
logger = logging.getLogger(__name__)
DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1"
def dump_kv_event(worker_id: int, event: dict):
"""Dump KV event to file for debugging (only when DYNAMO_DEBUG=1)."""
if not DEBUG_ENABLED:
return
import datetime
with open("/tmp/debug_kv_events.txt", "a") as f:
f.write(f"\n{'='*60}\n")
f.write(f"Timestamp: {datetime.datetime.now()}\n")
f.write(f"Worker ID: {worker_id}\n")
f.write(f"Event: {json.dumps(event, indent=2)}\n")
# -----------------------------------------------------------------------------
# Request/Response Models
# -----------------------------------------------------------------------------
class RouterRequest(BaseModel):
local_hashes: list[int]
num_tokens: int
class RouterResponse(BaseModel):
worker_id: int
overlap: float = 0.0
matched_blocks: int = 0
class InjectEventRequest(BaseModel):
"""For testing: inject a KV event directly into RadixTree."""
worker_id: int
tokens_hash: int
block_hash: int | None = None
mm_extra_info: dict | None = None
class LoadMetrics(BaseModel):
kv_cache_usage: float
num_waiting_reqs: int
# -----------------------------------------------------------------------------
# ZMQ Helpers
# -----------------------------------------------------------------------------
def create_zmq_subscriber(context: zmq.Context, endpoint: str) -> zmq.Socket[bytes]:
"""Create a ZMQ SUB socket with standard settings."""
socket = context.socket(zmq.SUB)
socket.connect(endpoint)
socket.setsockopt(zmq.SUBSCRIBE, b"")
socket.setsockopt(zmq.CONFLATE, 1)
socket.setsockopt(zmq.RCVTIMEO, 1)
return socket
# -----------------------------------------------------------------------------
# KvRouter Core
# -----------------------------------------------------------------------------
class KvRouter:
"""Router that uses RadixTree for KV cache-aware worker selection."""
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
):
self.num_workers = num_workers
self.block_size = block_size
self.radix_tree = RadixTree()
# Per-worker metrics
self.kv_usages = [0.0] * num_workers
self.waitings = [0] * num_workers
# ZMQ setup
self.context = zmq.Context()
self.load_listeners = [
create_zmq_subscriber(
self.context, f"tcp://localhost:{base_metrics_port + i}"
)
for i in range(num_workers)
]
self.kv_listeners = [
ZmqKvEventListener(
f"tcp://localhost:{base_kv_events_port + i}", "", block_size
)
for i in range(num_workers)
]
self.background_tasks: list[asyncio.Task] = []
logger.info("Router initialized")
# -------------------------------------------------------------------------
# Background Tasks
# -------------------------------------------------------------------------
async def start_background_tasks(self):
"""Start background tasks for load and tree updates."""
logger.info("Starting router background tasks...")
for worker_id in range(self.num_workers):
self.background_tasks.append(
asyncio.create_task(self._poll_worker_load(worker_id))
)
self.background_tasks.append(
asyncio.create_task(self._poll_worker_kv_events(worker_id))
)
async def _poll_worker_load(self, worker_id: int):
"""Poll load metrics for a single worker."""
while True:
try:
data = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
metrics = LoadMetrics.model_validate(data)
self.kv_usages[worker_id] = metrics.kv_cache_usage
self.waitings[worker_id] = metrics.num_waiting_reqs
except zmq.Again:
pass
except (zmq.ZMQError, ValidationError) as e:
logger.warning(f"Worker {worker_id} metrics error: {e}")
except Exception:
logger.exception(f"Worker {worker_id} unexpected metrics error")
await asyncio.sleep(0.1)
async def _poll_worker_kv_events(self, worker_id: int):
"""Poll KV events for a single worker and update RadixTree."""
while True:
try:
events: list[str] = await self.kv_listeners[worker_id].get_events()
for event_str in events:
event = json.loads(event_str)
dump_kv_event(worker_id, event)
self.radix_tree.apply_event(
worker_id, json.dumps(event).encode("utf-8")
)
except zmq.Again:
pass
except (zmq.ZMQError, json.JSONDecodeError) as e:
logger.warning(f"Worker {worker_id} KV events error: {e}")
except Exception:
logger.exception(f"Worker {worker_id} unexpected KV events error")
await asyncio.sleep(0.1)
# -------------------------------------------------------------------------
# Worker Selection
# -------------------------------------------------------------------------
async def get_best_worker(
self, local_hashes: list[int], num_tokens: int
) -> tuple[int, float, int]:
"""
Find best worker for request.
Returns: (worker_id, overlap_ratio, matched_blocks)
"""
if num_tokens <= 0:
raise ValueError("num_tokens must be positive")
# Get cache matches from RadixTree
matched_blocks = self._get_matched_blocks(local_hashes)
# Compute overlap scores
overlap_scores = {
wid: matched_blocks[wid] * self.block_size / num_tokens
for wid in range(self.num_workers)
}
# Compute routing logits
logits = self._compute_logits(overlap_scores)
# Select best worker (random tie-breaking)
best_id = self._select_best_worker(logits)
# Predictive update for burst handling
self.waitings[best_id] += 1
return best_id, overlap_scores[best_id], matched_blocks[best_id]
def _get_matched_blocks(self, local_hashes: list[int]) -> dict[int, int]:
"""Get matched block count per worker from RadixTree."""
result = self.radix_tree.find_matches(local_hashes)
raw_scores = result.scores
logger.info(f"Router: raw_scores={raw_scores}")
# raw_scores is keyed by (worker_id, dp_rank); assume dp_rank=0
return {wid: raw_scores.get((wid, 0), 0) for wid in range(self.num_workers)}
def _compute_logits(self, overlap_scores: dict[int, float]) -> list[float]:
"""Compute routing logits for each worker."""
max_waiting = max(self.waitings) if self.waitings else 0
logits = []
for wid in range(self.num_workers):
overlap = overlap_scores[wid]
usage = self.kv_usages[wid]
waiting_norm = self.waitings[wid] / max_waiting if max_waiting else 0.0
logit = 2 * overlap - usage - waiting_norm
logits.append(logit)
logger.info(
f"worker_id: {wid}, logit = 2 * {overlap:.3f} - {usage:.3f} - {waiting_norm:.3f} = {logit:.3f}"
)
return logits
def _select_best_worker(self, logits: list[float]) -> int:
"""Select worker with highest logit (random tie-breaking)."""
arr = np.array(logits)
return int(np.random.choice(np.flatnonzero(arr == arr.max())))
# -------------------------------------------------------------------------
# Shutdown
# -------------------------------------------------------------------------
async def shutdown(self):
"""Shutdown ZMQ listeners and background tasks."""
logger.info("Shutting down KvRouter...")
for task in self.background_tasks:
task.cancel()
if self.background_tasks:
await asyncio.gather(*self.background_tasks, return_exceptions=True)
for listener in self.load_listeners:
listener.close()
self.context.term()
logger.info("KvRouter shutdown completed")
# -----------------------------------------------------------------------------
# Router API Server
# -----------------------------------------------------------------------------
class RouterAPI:
"""FastAPI wrapper for KvRouter."""
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
port: int = 7000,
):
self.port = port
self.router_config = {
"block_size": block_size,
"num_workers": num_workers,
"base_kv_events_port": base_kv_events_port,
"base_metrics_port": base_metrics_port,
}
self.router: KvRouter | None = None
self.app = FastAPI(
title="KV Router API", version="0.0.1", lifespan=self.lifespan
)
self._setup_routes()
def _require_router(self) -> KvRouter:
"""Get router or raise 503 if not initialized."""
if self.router is None:
raise HTTPException(status_code=503, detail="Router not initialized")
return self.router
@asynccontextmanager
async def lifespan(self, app: FastAPI):
self.router = KvRouter(**self.router_config)
await self.router.start_background_tasks()
logger.info("Router API started")
yield
if self.router:
await self.router.shutdown()
def _setup_routes(self):
@self.app.post("/find_best_worker", response_model=RouterResponse)
async def find_best_worker(request: RouterRequest):
router = self._require_router()
try:
wid, overlap, matched = await router.get_best_worker(
request.local_hashes, request.num_tokens
)
return RouterResponse(
worker_id=wid, overlap=overlap, matched_blocks=matched
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.app.get("/debug/tree_info")
async def get_tree_info():
router = self._require_router()
events = router.radix_tree.dump_tree_as_events()
return {"num_blocks": len(events), "events": events[:20]}
@self.app.post("/debug/inject_event")
async def inject_event(request: InjectEventRequest):
router = self._require_router()
block_hash = request.block_hash or request.tokens_hash
event = {
"event_id": 99999,
"data": {
"stored": {
"parent_hash": None,
"blocks": [
{
"block_hash": block_hash,
"tokens_hash": request.tokens_hash,
"mm_extra_info": request.mm_extra_info,
}
],
}
},
}
router.radix_tree.apply_event(
request.worker_id, json.dumps(event).encode("utf-8")
)
return {
"status": "ok",
"tokens_hash": request.tokens_hash,
"worker_id": request.worker_id,
}
async def start(self):
"""Start the router API server."""
logger.info(f"Starting Router API on port {self.port}")
config = uvicorn.Config(
self.app, host="0.0.0.0", port=self.port, log_level="info"
)
await uvicorn.Server(config).serve()
def main():
parser = argparse.ArgumentParser(description="KV Router API Server")
parser.add_argument(
"--block-size", type=int, default=32, help="Block size (default: 32)"
)
parser.add_argument("--num-workers", type=int, default=2, help="Number of workers")
parser.add_argument(
"--base-kv-events-port", type=int, default=5557, help="Base KV events port"
)
parser.add_argument(
"--base-metrics-port", type=int, default=5657, help="Base metrics port"
)
parser.add_argument("--port", type=int, default=7000, help="Router API port")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
api = RouterAPI(
block_size=args.block_size,
num_workers=args.num_workers,
base_kv_events_port=args.base_kv_events_port,
base_metrics_port=args.base_metrics_port,
port=args.port,
)
asyncio.run(api.start())
if __name__ == "__main__":
main()
This diff is collapsed.
This diff is collapsed.
...@@ -170,10 +170,12 @@ fn kv_event_create_stored_block_from_parts( ...@@ -170,10 +170,12 @@ fn kv_event_create_stored_block_from_parts(
let tokens_hash = compute_block_hash_for_seq( let tokens_hash = compute_block_hash_for_seq(
unsafe { std::slice::from_raw_parts(token_ids, num_tokens) }, unsafe { std::slice::from_raw_parts(token_ids, num_tokens) },
kv_block_size, kv_block_size,
None,
)[0]; )[0];
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(block_hash), block_hash: ExternalSequenceBlockHash(block_hash),
tokens_hash, tokens_hash,
mm_extra_info: None,
} }
} }
static WARN_COUNT: AtomicU32 = AtomicU32::new(0); static WARN_COUNT: AtomicU32 = AtomicU32::new(0);
......
...@@ -27,12 +27,35 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions}; ...@@ -27,12 +27,35 @@ use llm_rs::protocols::common::{OutputOptions, SamplingOptions, StopConditions};
use serde_json::json; use serde_json::json;
#[pyfunction] #[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> { #[pyo3(signature = (tokens, kv_block_size, block_mm_infos=None))]
pub fn compute_block_hash_for_seq_py(
_py: Python,
tokens: Vec<u32>,
kv_block_size: usize,
block_mm_infos: Option<Bound<PyAny>>,
) -> PyResult<Vec<u64>> {
if kv_block_size == 0 { if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0"))); return Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(
"kv_block_size cannot be 0",
));
} }
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size as u32); // Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref()
.map(|infos_py| {
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.transpose()?;
let hashes =
compute_block_hash_for_seq(&tokens, kv_block_size as u32, mm_infos_rust.as_deref());
Ok(hashes.into_iter().map(|h| h.0).collect()) Ok(hashes.into_iter().map(|h| h.0).collect())
} }
...@@ -280,7 +303,7 @@ impl KvEventPublisher { ...@@ -280,7 +303,7 @@ impl KvEventPublisher {
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None))] #[pyo3(signature = (event_id, token_ids, num_block_tokens, block_hashes, lora_id, parent_hash=None, block_mm_infos=None))]
fn publish_stored( fn publish_stored(
&mut self, &mut self,
py: Python, py: Python,
...@@ -290,12 +313,26 @@ impl KvEventPublisher { ...@@ -290,12 +313,26 @@ impl KvEventPublisher {
block_hashes: Vec<i64>, block_hashes: Vec<i64>,
lora_id: u64, lora_id: u64,
parent_hash: Option<i64>, parent_hash: Option<i64>,
block_mm_infos: Option<Bound<PyAny>>,
) -> PyResult<()> { ) -> PyResult<()> {
let kv_block_size = self.kv_block_size as u32; let kv_block_size = self.kv_block_size as u32;
let dp_rank = self.dp_rank; let dp_rank = self.dp_rank;
let warning_count = self.warning_count.clone(); let warning_count = self.warning_count.clone();
let inner = self.inner.clone(); let inner = self.inner.clone();
// Convert Python block_mm_infos to Rust Vec<Option<BlockExtraInfo>>
let mm_infos_rust: Option<Vec<Option<BlockExtraInfo>>> = block_mm_infos
.as_ref()
.map(|infos_py| {
depythonize::<Vec<Option<BlockExtraInfo>>>(infos_py).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to convert block_mm_infos: {}",
e
))
})
})
.transpose()?;
py.allow_threads(|| { py.allow_threads(|| {
let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect(); let block_hashes_u64: Vec<u64> = block_hashes.iter().map(|&h| h as u64).collect();
let event = KvCacheEvent { let event = KvCacheEvent {
...@@ -309,6 +346,7 @@ impl KvEventPublisher { ...@@ -309,6 +346,7 @@ impl KvEventPublisher {
&block_hashes_u64, &block_hashes_u64,
lora_id, lora_id,
&warning_count, &warning_count,
mm_infos_rust.as_deref(),
), ),
}), }),
dp_rank, dp_rank,
......
...@@ -232,16 +232,42 @@ class Client: ...@@ -232,16 +232,42 @@ class Client:
... ...
def compute_block_hash_for_seq_py(tokens: List[int], kv_block_size: int) -> List[int]: def compute_block_hash_for_seq_py(
tokens: List[int],
kv_block_size: int,
block_mm_infos: Optional[List[Optional[Dict[str, Any]]]] = None
) -> List[int]:
""" """
Compute block hashes for a sequence of tokens Compute block hashes for a sequence of tokens, optionally including multimodal metadata.
When block_mm_infos is provided, the mm_hashes are included in the hash computation
to ensure that blocks with identical tokens but different multimodal objects produce
different hashes.
Args: Args:
tokens: List of token IDs tokens: List of token IDs
kv_block_size: Size of each KV cache block kv_block_size: Size of each block in tokens
block_mm_infos: Optional per-block multimodal metadata. Each element corresponds to a block
and should be None or a dict with structure:
{
"mm_objects": [
{
"mm_hash": int, # Hash of the MM object
}
]
}
Returns: Returns:
List of block hashes as integers List of block hashes (one per block)
Example:
>>> tokens = [1, 2, 3, 4] * 8 # 32 tokens = 1 block
>>> mm_info = {
... "mm_objects": [{
... "mm_hash": 0xDEADBEEF,
... }]
... }
>>> hashes = compute_block_hash_for_seq_py(tokens, 32, [mm_info])
""" """
... ...
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Tests for Multimodal KV Router functionality.
These tests verify that the KV router correctly handles multimodal content (images, videos)
by distinguishing between requests with identical token sequences but different MM objects.
Key Concepts:
- block_hash: External hash used to identify blocks uniquely (includes MM info)
- tokens_hash: Local hash based only on token content
- mm_hash: Hash of the multimodal object (image, video, etc.)
Test Strategy:
- Use RadixTree directly to avoid NATS/etcd infrastructure dependencies
- Simulate multiple workers caching same tokens with different MM content
- Verify that routing distinguishes between different MM objects
"""
import json
from typing import Any
import pytest
from dynamo.llm import RadixTree, compute_block_hash_for_seq_py
pytestmark = pytest.mark.pre_merge
# Constants for testing
DEFAULT_BLOCK_SIZE = 32
MM_HASH_1 = 0xDEADBEEF
MM_HASH_2 = 0xCAFEBABE
MM_HASH_3 = 0xFEEDFACE
def make_mm_info(mm_hash: int, offsets: list[list[int]] | None = None) -> dict:
"""Create a block's MM extra info structure."""
if offsets is None:
offsets = [[0, 10]]
return {"mm_objects": [{"mm_hash": mm_hash, "offsets": offsets}]}
def make_store_event(
event_id: int,
blocks: list[dict],
parent_hash: int | None = None,
) -> bytes:
"""Create a JSON-encoded store event for RadixTree."""
event = {
"event_id": event_id,
"data": {
"stored": {
"parent_hash": parent_hash,
"blocks": blocks,
}
},
}
return json.dumps(event).encode("utf-8")
def make_block(
block_hash: int,
tokens_hash: int | None = None,
mm_info: dict | None = None,
) -> dict:
"""Create a block structure for store events."""
block: dict[str, Any] = {
"block_hash": block_hash,
"tokens_hash": tokens_hash if tokens_hash is not None else block_hash,
}
if mm_info is not None:
block["mm_extra_info"] = mm_info
return block
# =============================================================================
# RadixTree MM Routing Tests
# =============================================================================
# # @pytest.mark.timeout(5)
def test_radix_tree_mm_routing_basic():
"""Test RadixTree correctly distinguishes blocks with same tokens but different MM content."""
radix_tree = RadixTree()
# Worker 0: Store block with MM Object 1
worker_0, block_hash_w0 = 0, 1000
event_w0 = make_store_event(
event_id=1,
blocks=[make_block(block_hash_w0, mm_info=make_mm_info(MM_HASH_1))],
)
radix_tree.apply_event(worker_0, event_w0)
# Worker 1: Store block with DIFFERENT MM Object (same tokens)
worker_1, block_hash_w1 = 1, 2000
event_w1 = make_store_event(
event_id=2,
blocks=[make_block(block_hash_w1, mm_info=make_mm_info(MM_HASH_2))],
)
radix_tree.apply_event(worker_1, event_w1)
# Verify both blocks are stored
all_blocks = radix_tree.dump_tree_as_events()
assert len(all_blocks) == 2
# Query for worker 0's block
scores_w0 = radix_tree.find_matches([block_hash_w0])
assert (worker_0, 0) in scores_w0.scores
assert scores_w0.scores[(worker_0, 0)] == 1
# Query for worker 1's block
scores_w1 = radix_tree.find_matches([block_hash_w1])
assert (worker_1, 0) in scores_w1.scores
assert scores_w1.scores[(worker_1, 0)] == 1
# Query with non-existent hash should return no matches
scores_none = radix_tree.find_matches([9999])
assert len(scores_none.scores) == 0
# @pytest.mark.timeout(5)
def test_radix_tree_mm_block_chaining():
"""Test block chaining with parent_hash for multi-block sequences with MM content."""
radix_tree = RadixTree()
worker_id = 0
parent_hash = 1000
child_hash = 2000
# Store parent block
parent_event = make_store_event(
event_id=1,
blocks=[make_block(parent_hash, mm_info=make_mm_info(MM_HASH_1))],
)
radix_tree.apply_event(worker_id, parent_event)
# Store child block that references parent
child_event = make_store_event(
event_id=2,
blocks=[make_block(child_hash, mm_info=make_mm_info(MM_HASH_1))],
parent_hash=parent_hash,
)
radix_tree.apply_event(worker_id, child_event)
# Verify chain exists
all_blocks = radix_tree.dump_tree_as_events()
assert len(all_blocks) == 2
# Query with both hashes should match the chain
scores = radix_tree.find_matches([parent_hash, child_hash])
assert (worker_id, 0) in scores.scores
assert scores.scores[(worker_id, 0)] == 2
# @pytest.mark.timeout(5)
def test_radix_tree_worker_removal():
"""Test worker removal clears all its blocks."""
radix_tree = RadixTree()
worker_0, worker_1 = 0, 1
# Add blocks for both workers
radix_tree.apply_event(
worker_0,
make_store_event(1, [make_block(1000, mm_info=make_mm_info(MM_HASH_1))]),
)
radix_tree.apply_event(
worker_1,
make_store_event(2, [make_block(2000, mm_info=make_mm_info(MM_HASH_2))]),
)
assert len(radix_tree.dump_tree_as_events()) == 2
# Remove worker 0
radix_tree.remove_worker(worker_0)
# Only worker 1's block should remain
remaining = radix_tree.dump_tree_as_events()
assert len(remaining) == 1
scores = radix_tree.find_matches([2000])
assert (worker_1, 0) in scores.scores
# @pytest.mark.timeout(5)
def test_radix_tree_clear_all_blocks():
"""Test clearing all blocks for a specific worker."""
radix_tree = RadixTree()
worker_id = 0
# Add multiple blocks
radix_tree.apply_event(
worker_id,
make_store_event(1, [make_block(1000), make_block(2000)]),
)
assert len(radix_tree.dump_tree_as_events()) == 2
# Clear all blocks for worker
radix_tree.clear_all_blocks(worker_id)
assert len(radix_tree.dump_tree_as_events()) == 0
# =============================================================================
# Block Hash Computation Tests
# =============================================================================
# @pytest.mark.timeout(5)
def test_mm_block_hash_computation_basic():
"""Test that same tokens with different MM content produce different hashes."""
tokens = [100] * DEFAULT_BLOCK_SIZE
# Without MM info
hashes_no_mm = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE)
assert len(hashes_no_mm) == 1
# With MM info 1
hashes_mm1 = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)
assert len(hashes_mm1) == 1
# With MM info 2
hashes_mm2 = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
)
assert len(hashes_mm2) == 1
# All three should be different
assert hashes_no_mm != hashes_mm1
assert hashes_no_mm != hashes_mm2
assert hashes_mm1 != hashes_mm2
# @pytest.mark.timeout(5)
def test_mm_block_hash_determinism():
"""Test that hash computation is deterministic."""
tokens = [100] * DEFAULT_BLOCK_SIZE
mm_info = [make_mm_info(MM_HASH_1)]
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info)
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_info)
assert hash1 == hash2
# @pytest.mark.timeout(5)
@pytest.mark.parametrize("block_size", [16, 32, 64])
def test_mm_block_hash_multiple_blocks(block_size: int):
"""Test hash computation for sequences spanning multiple blocks."""
num_blocks = 3
# Use different tokens per block to get unique hashes
tokens = []
for i in range(num_blocks):
tokens.extend([100 + i] * block_size)
# One MM info per block
mm_infos = [make_mm_info(MM_HASH_1) for _ in range(num_blocks)]
hashes = compute_block_hash_for_seq_py(tokens, block_size, mm_infos)
assert len(hashes) == num_blocks
# Each block should have a unique hash (due to different tokens)
assert len(set(hashes)) == num_blocks
# @pytest.mark.timeout(5)
def test_mm_block_hash_partial_block():
"""Test hash computation when tokens don't fill complete blocks."""
# 1.5 blocks worth of tokens
tokens = [100] * (DEFAULT_BLOCK_SIZE + DEFAULT_BLOCK_SIZE // 2)
# MM info for each block
mm_infos = [make_mm_info(MM_HASH_1), make_mm_info(MM_HASH_2)]
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, mm_infos)
# Only complete blocks get hashes - partial blocks are not hashed
assert len(hashes) == 1
# @pytest.mark.timeout(5)
def test_mm_block_hash_none_mm_info():
"""Test that None MM info is handled correctly."""
tokens = [100] * DEFAULT_BLOCK_SIZE
# Pass None for some blocks' MM info
mm_infos = [None]
hashes_with_none = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, mm_infos
)
hashes_without = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE)
# Both should produce the same result
assert hashes_with_none == hashes_without
# @pytest.mark.timeout(5)
def test_mm_block_hash_different_offsets():
"""Test that same mm_hash with different offsets produces same hash."""
tokens = [100] * DEFAULT_BLOCK_SIZE
# Same MM hash, different offsets
mm_info_1 = make_mm_info(MM_HASH_1, offsets=[[0, 10]])
mm_info_2 = make_mm_info(MM_HASH_1, offsets=[[5, 15]])
hash1 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_1])
hash2 = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info_2])
# Currently offsets are not included in hash computation - just mm_hash
# This behavior may change - update test if needed
assert hash1 == hash2
# @pytest.mark.timeout(5)
def test_mm_block_hash_multiple_mm_objects():
"""Test hash with multiple MM objects in a single block."""
tokens = [100] * DEFAULT_BLOCK_SIZE
# Multiple MM objects in one block
mm_info = {
"mm_objects": [
{"mm_hash": MM_HASH_1, "offsets": [[0, 5]]},
{"mm_hash": MM_HASH_2, "offsets": [[10, 15]]},
]
}
hashes = compute_block_hash_for_seq_py(tokens, DEFAULT_BLOCK_SIZE, [mm_info])
assert len(hashes) == 1
# Compare with single MM object
single_mm_hashes = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)
# Should be different due to additional MM object
assert hashes != single_mm_hashes
# @pytest.mark.timeout(5)
def test_mm_block_hash_error_zero_block_size():
"""Test that zero block size raises an error."""
tokens = [100] * 32
with pytest.raises(ValueError, match="kv_block_size cannot be 0"):
compute_block_hash_for_seq_py(tokens, 0)
# =============================================================================
# Integration Tests: RadixTree + Hash Computation
# =============================================================================
# @pytest.mark.timeout(5)
def test_integration_mm_hash_to_routing():
"""Test end-to-end: compute hash -> store in tree -> query matches correctly."""
radix_tree = RadixTree()
tokens = [100] * DEFAULT_BLOCK_SIZE
# Compute hashes for two different MM contents
hash_mm1 = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_1)]
)[0]
hash_mm2 = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(MM_HASH_2)]
)[0]
# Store each on different workers
worker_0, worker_1 = 0, 1
radix_tree.apply_event(
worker_0,
make_store_event(1, [make_block(hash_mm1, mm_info=make_mm_info(MM_HASH_1))]),
)
radix_tree.apply_event(
worker_1,
make_store_event(2, [make_block(hash_mm2, mm_info=make_mm_info(MM_HASH_2))]),
)
# Query with MM1's hash should match worker 0
scores_mm1 = radix_tree.find_matches([hash_mm1])
assert (worker_0, 0) in scores_mm1.scores
assert (worker_1, 0) not in scores_mm1.scores
# Query with MM2's hash should match worker 1
scores_mm2 = radix_tree.find_matches([hash_mm2])
assert (worker_1, 0) in scores_mm2.scores
assert (worker_0, 0) not in scores_mm2.scores
# @pytest.mark.timeout(5)
@pytest.mark.parametrize("num_workers", [2, 3, 5])
def test_integration_multiple_workers_same_tokens(num_workers: int):
"""Test routing with multiple workers caching same tokens but different MM content."""
radix_tree = RadixTree()
tokens = [100] * DEFAULT_BLOCK_SIZE
# Each worker has unique MM content
mm_hashes = [0x1000 + i for i in range(num_workers)]
# Store blocks for each worker
for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0]
radix_tree.apply_event(
worker_id,
make_store_event(
event_id=worker_id + 1,
blocks=[make_block(block_hash, mm_info=make_mm_info(mm_hash))],
),
)
# Verify all blocks stored
assert len(radix_tree.dump_tree_as_events()) == num_workers
# Query for each worker's block should match only that worker
for worker_id, mm_hash in enumerate(mm_hashes):
block_hash = compute_block_hash_for_seq_py(
tokens, DEFAULT_BLOCK_SIZE, [make_mm_info(mm_hash)]
)[0]
scores = radix_tree.find_matches([block_hash])
assert (worker_id, 0) in scores.scores
assert scores.scores[(worker_id, 0)] == 1
# No other workers should match
for other_id in range(num_workers):
if other_id != worker_id:
assert (other_id, 0) not in scores.scores
...@@ -475,7 +475,7 @@ impl KvRouter { ...@@ -475,7 +475,7 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let seq_hashes = compute_seq_hash_for_block(&block_hashes); let seq_hashes = compute_seq_hash_for_block(&block_hashes);
let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?; let overlap_scores = self.indexer.find_matches(block_hashes.clone()).await?;
...@@ -530,7 +530,7 @@ impl KvRouter { ...@@ -530,7 +530,7 @@ impl KvRouter {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| { let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
compute_seq_hash_for_block(&block_hashes) compute_seq_hash_for_block(&block_hashes)
}); });
...@@ -573,11 +573,11 @@ impl KvRouter { ...@@ -573,11 +573,11 @@ impl KvRouter {
/// Get potential prefill and decode loads for all workers /// Get potential prefill and decode loads for all workers
pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> { pub async fn get_potential_loads(&self, tokens: &[u32]) -> Result<Vec<PotentialLoad>> {
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
let overlap_scores = self.indexer.find_matches(block_hashes).await?; let overlap_scores = self.indexer.find_matches(block_hashes).await?;
let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| { let maybe_seq_hashes = self.kv_router_config.router_track_active_blocks.then(|| {
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size, None);
compute_seq_hash_for_block(&block_hashes) compute_seq_hash_for_block(&block_hashes)
}); });
...@@ -661,7 +661,10 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -661,7 +661,10 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let context_id = ctx.context().id().to_string(); let context_id = ctx.context().id().to_string();
// Handle different request types // Handle different request types
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New {
tokens,
request_extra_info: _,
} => {
let (best_worker, overlap_blocks) = self let (best_worker, overlap_blocks) = self
.find_best_match(Some(&context_id), &tokens, None, true) .find_best_match(Some(&context_id), &tokens, None, true)
.await?; .await?;
...@@ -761,7 +764,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -761,7 +764,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Compute actual overlap blocks by querying the indexer // Compute actual overlap blocks by querying the indexer
let block_hashes = let block_hashes =
compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size()); compute_block_hash_for_seq(&request.token_ids, self.chooser.block_size(), None);
let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?; let overlap_scores = self.chooser.indexer.find_matches(block_hashes).await?;
let worker = WorkerWithDpRank::new(id, dp_rank); let worker = WorkerWithDpRank::new(id, dp_rank);
let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0); let overlap_blocks = overlap_scores.scores.get(&worker).copied().unwrap_or(0);
......
...@@ -32,7 +32,6 @@ ...@@ -32,7 +32,6 @@
//! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance. //! This module provides a scalable and efficient way to manage and retrieve data blocks for LLM inference, leveraging a global KV cache to optimize performance.
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes;
use dynamo_runtime::{ use dynamo_runtime::{
component::Component, component::Component,
metrics::{MetricsHierarchy, prometheus_names::kvrouter}, metrics::{MetricsHierarchy, prometheus_names::kvrouter},
...@@ -55,7 +54,7 @@ use xxhash_rust::xxh3; ...@@ -55,7 +54,7 @@ use xxhash_rust::xxh3;
pub const XXH3_SEED: u64 = 1337; pub const XXH3_SEED: u64 = 1337;
use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager}; use crate::kv_router::approx::{BlockEntry, PruneConfig, PruneManager};
use crate::kv_router::protocols::*; use crate::kv_router::protocols::{BlockExtraInfo, *};
use crate::tokens::{SequenceHash, TokenBlockSequence}; use crate::tokens::{SequenceHash, TokenBlockSequence};
/// Errors that can occur in the KV Router. /// Errors that can occur in the KV Router.
...@@ -117,25 +116,54 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash { ...@@ -117,25 +116,54 @@ pub fn compute_block_hash(data: &[u8]) -> LocalBlockHash {
// let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED); // let hash = xxh3::xxh3_64_with_seed(&bytes, XXH3_SEED);
// } // }
/// Compute the hash for a sequence of tokens. /// Compute the hash for a sequence of tokens, optionally including multimodal metadata.
///
/// When multimodal extra info is provided, the mm_hashes are included in the hash computation
/// to ensure that blocks with identical tokens but different multimodal objects produce
/// different hashes.
/// ///
/// ### Arguments /// ### Arguments
/// ///
/// * `tokens` - A vector of `u32` tokens. /// * `tokens` - A vector of `u32` tokens.
/// * `kv_block_size` - The size of each block in tokens.
/// * `block_mm_infos` - Optional per-block multimodal metadata.
/// ///
/// ### Returns /// ### Returns
/// ///
/// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens. /// A vector of `LocalBlockHash` representing the computed hashes for each chunk of tokens.
pub fn compute_block_hash_for_seq(tokens: &[u32], kv_block_size: u32) -> Vec<LocalBlockHash> { pub fn compute_block_hash_for_seq(
tokens: &[u32],
kv_block_size: u32,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<LocalBlockHash> {
tokens tokens
.chunks_exact(kv_block_size as usize) // Split into chunks of kv_block_size elements .chunks_exact(kv_block_size as usize)
.map(|chunk| { .enumerate()
let bytes: Vec<u8> = chunk .map(|(block_idx, chunk)| {
let mut bytes: Vec<u8> = chunk.iter().flat_map(|&num| num.to_le_bytes()).collect();
// Include MM hashes in the block hash computation if present
if let Some(mm_infos) = block_mm_infos
&& let Some(Some(block_mm_info)) = mm_infos.get(block_idx)
{
// The order of different multimodal hashes does not matter.
// Only which multimodal infos are present in a block is important.
// The order may differ in different code paths, so the hashes are sorted
// to keep the block hash stable.
let mut mm_hashes: Vec<u64> = block_mm_info
.mm_objects
.iter() .iter()
.flat_map(|&num| num.to_le_bytes()) // Convert each i32 to its little-endian bytes .map(|obj| obj.mm_hash)
.collect(); .collect();
mm_hashes.sort_unstable();
// Append sorted mm_hashes to the byte array
for mm_hash in mm_hashes {
bytes.extend_from_slice(&mm_hash.to_le_bytes());
}
}
compute_block_hash(&Bytes::from(bytes)) // Convert the byte Vec to Bytes compute_block_hash(&bytes)
}) })
.collect() .collect()
} }
...@@ -610,6 +638,7 @@ impl RadixTree { ...@@ -610,6 +638,7 @@ impl RadixTree {
parent_hash, parent_hash,
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: *external_hash, block_hash: *external_hash,
mm_extra_info: None,
tokens_hash, tokens_hash,
}], }],
}), }),
...@@ -1076,6 +1105,7 @@ impl KvIndexer { ...@@ -1076,6 +1105,7 @@ impl KvIndexer {
blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData { blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
tokens_hash: *local_hash, tokens_hash: *local_hash,
block_hash: ExternalSequenceBlockHash(*sequence_hash), block_hash: ExternalSequenceBlockHash(*sequence_hash),
mm_extra_info: None,
}).collect(), }).collect(),
}); });
...@@ -1243,7 +1273,7 @@ impl KvIndexerInterface for KvIndexer { ...@@ -1243,7 +1273,7 @@ impl KvIndexerInterface for KvIndexer {
tokens, tokens,
tokens.len() tokens.len()
); );
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
tracing::debug!("Computed sequence: {:?}", sequence); tracing::debug!("Computed sequence: {:?}", sequence);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -1296,7 +1326,7 @@ impl KvIndexerInterface for KvIndexer { ...@@ -1296,7 +1326,7 @@ impl KvIndexerInterface for KvIndexer {
tokens: &[u32], tokens: &[u32],
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size); let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None); let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence let sequence_hashes = sequence
.blocks() .blocks()
...@@ -1813,6 +1843,7 @@ impl KvIndexerSharded { ...@@ -1813,6 +1843,7 @@ impl KvIndexerSharded {
blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData { blocks: hashes.map(|(local_hash, sequence_hash)| KvCacheStoredBlockData {
tokens_hash: *local_hash, tokens_hash: *local_hash,
block_hash: ExternalSequenceBlockHash(*sequence_hash), block_hash: ExternalSequenceBlockHash(*sequence_hash),
mm_extra_info: None,
}).collect(), }).collect(),
}); });
...@@ -1973,7 +2004,7 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -1973,7 +2004,7 @@ impl KvIndexerInterface for KvIndexerSharded {
&self, &self,
tokens: &[u32], tokens: &[u32],
) -> Result<OverlapScores, KvRouterError> { ) -> Result<OverlapScores, KvRouterError> {
let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size); let sequence = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
self.find_matches(sequence).await self.find_matches(sequence).await
} }
...@@ -2073,7 +2104,7 @@ impl KvIndexerInterface for KvIndexerSharded { ...@@ -2073,7 +2104,7 @@ impl KvIndexerInterface for KvIndexerSharded {
tokens: &[u32], tokens: &[u32],
worker: WorkerWithDpRank, worker: WorkerWithDpRank,
) -> Result<(), KvRouterError> { ) -> Result<(), KvRouterError> {
let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size); let local_hashes = compute_block_hash_for_seq(tokens, self.kv_block_size, None);
let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None); let sequence = TokenBlockSequence::new(tokens.into(), self.kv_block_size, None);
let sequence_hashes = sequence let sequence_hashes = sequence
.blocks() .blocks()
...@@ -2111,6 +2142,7 @@ mod tests { ...@@ -2111,6 +2142,7 @@ mod tests {
.map(|i| KvCacheStoredBlockData { .map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i), tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100), block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
}) })
.collect() .collect()
} }
...@@ -2714,17 +2746,17 @@ mod tests { ...@@ -2714,17 +2746,17 @@ mod tests {
setup(); setup();
// create a sequence of 64 elements // create a sequence of 64 elements
let sequence = (0..kv_block_size).collect::<Vec<u32>>(); let sequence = (0..kv_block_size).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 65 elements // create a sequence of 65 elements
let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
assert_eq!(hashes.len(), 1); assert_eq!(hashes.len(), 1);
// create a sequence of 129 elements // create a sequence of 129 elements
let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>(); let sequence = (0..(2 * kv_block_size + 1)).collect::<Vec<u32>>();
let hashes = compute_block_hash_for_seq(&sequence, kv_block_size); let hashes = compute_block_hash_for_seq(&sequence, kv_block_size, None);
assert_eq!(hashes.len(), 2); assert_eq!(hashes.len(), 2);
} }
...@@ -2929,6 +2961,7 @@ mod tests { ...@@ -2929,6 +2961,7 @@ mod tests {
parent_hash: None, parent_hash: None,
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(0), block_hash: ExternalSequenceBlockHash(0),
mm_extra_info: None,
tokens_hash: LocalBlockHash(13226331709069118873), tokens_hash: LocalBlockHash(13226331709069118873),
}], }],
}), }),
...@@ -3392,6 +3425,7 @@ mod tests { ...@@ -3392,6 +3425,7 @@ mod tests {
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(id * 100), block_hash: ExternalSequenceBlockHash(id * 100),
tokens_hash: LocalBlockHash(id * 200), tokens_hash: LocalBlockHash(id * 200),
mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
...@@ -3567,6 +3601,7 @@ mod tests { ...@@ -3567,6 +3601,7 @@ mod tests {
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
......
...@@ -40,6 +40,8 @@ pub enum RouterRequest { ...@@ -40,6 +40,8 @@ pub enum RouterRequest {
#[serde(rename = "new")] #[serde(rename = "new")]
New { New {
tokens: Vec<Token>, tokens: Vec<Token>,
#[serde(default, skip_serializing_if = "Option::is_none")]
request_extra_info: Option<RequestExtraInfo>,
}, },
MarkPrefill, MarkPrefill,
MarkFree, MarkFree,
...@@ -47,7 +49,10 @@ pub enum RouterRequest { ...@@ -47,7 +49,10 @@ pub enum RouterRequest {
impl Default for RouterRequest { impl Default for RouterRequest {
fn default() -> Self { fn default() -> Self {
RouterRequest::New { tokens: vec![] } RouterRequest::New {
tokens: vec![],
request_extra_info: None,
}
} }
} }
...@@ -276,6 +281,111 @@ pub struct KvCacheStoreData { ...@@ -276,6 +281,111 @@ pub struct KvCacheStoreData {
pub blocks: Vec<KvCacheStoredBlockData>, pub blocks: Vec<KvCacheStoredBlockData>,
} }
/// Multimodal object information within a block.
/// Offsets are relative to the block (0 to block_size-1).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BlockMmObjectInfo {
/// Hash identifying this multimodal object
pub mm_hash: u64,
/// Token offset ranges where this MM object's placeholders appear within THIS block
/// Each tuple is (start_offset, end_offset) relative to block start
pub offsets: Vec<(usize, usize)>,
}
/// Extra metadata for a block containing multimodal objects
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct BlockExtraInfo {
/// All multimodal objects referenced in this block
pub mm_objects: Vec<BlockMmObjectInfo>,
}
/// Request-level multimodal object information.
/// Offsets are relative to the entire request token sequence.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RequestMmObjectInfo {
/// Hash identifying this multimodal object
pub mm_hash: u64,
/// Token offset ranges where this MM object's placeholders appear in the ENTIRE request
/// Each tuple is (start_offset, end_offset) relative to request start
pub offsets: Vec<(usize, usize)>,
}
/// Request-level multimodal metadata
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RequestExtraInfo {
/// All multimodal objects in this request
pub mm_objects: Vec<RequestMmObjectInfo>,
}
impl RequestExtraInfo {
/// Convert request-level MM info to block-level MM info for a sequence of blocks.
///
/// This function splits request-level offsets (relative to the entire request token sequence)
/// into block-level offsets (relative to each block).
///
/// # Arguments
/// * `block_size` - The size of each block in tokens
/// * `total_tokens` - Total number of tokens in the request
///
/// # Returns
/// A vector of `Option<BlockExtraInfo>` where each element corresponds to a block.
/// `None` indicates a block with no multimodal objects.
pub fn to_block_level(
&self,
block_size: usize,
total_tokens: usize,
) -> Vec<Option<BlockExtraInfo>> {
let num_blocks = total_tokens.div_ceil(block_size);
let mut block_infos: Vec<Option<BlockExtraInfo>> = vec![None; num_blocks];
for req_mm_obj in &self.mm_objects {
for (req_start, req_end) in &req_mm_obj.offsets {
// Find which blocks this offset range spans
let start_block = req_start / block_size;
let end_block = (req_end.saturating_sub(1)) / block_size;
let upper_bound = end_block.min(num_blocks - 1) + 1;
for (block_idx, block_info_opt) in block_infos
.iter_mut()
.enumerate()
.take(upper_bound)
.skip(start_block)
{
let block_start_global = block_idx * block_size;
let block_end_global = ((block_idx + 1) * block_size).min(total_tokens);
// Calculate the intersection of this MM object's range with this block
let local_start = (*req_start).max(block_start_global) - block_start_global;
let local_end = (*req_end).min(block_end_global) - block_start_global;
if local_start < local_end {
let block_info = block_info_opt
.get_or_insert_with(|| BlockExtraInfo { mm_objects: vec![] });
// Check if we already have this mm_hash in this block
if let Some(existing) = block_info
.mm_objects
.iter_mut()
.find(|obj| obj.mm_hash == req_mm_obj.mm_hash)
{
// Add the offset range to existing object
existing.offsets.push((local_start, local_end));
} else {
// Create new MM object entry for this block
block_info.mm_objects.push(BlockMmObjectInfo {
mm_hash: req_mm_obj.mm_hash,
offsets: vec![(local_start, local_end)],
});
}
}
}
}
}
block_infos
}
}
/// Represents data for a stored block. /// Represents data for a stored block.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, PartialEq)]
pub struct KvCacheStoredBlockData { pub struct KvCacheStoredBlockData {
...@@ -283,6 +393,11 @@ pub struct KvCacheStoredBlockData { ...@@ -283,6 +393,11 @@ pub struct KvCacheStoredBlockData {
pub block_hash: ExternalSequenceBlockHash, pub block_hash: ExternalSequenceBlockHash,
/// The hash of the tokens in the block. /// The hash of the tokens in the block.
pub tokens_hash: LocalBlockHash, pub tokens_hash: LocalBlockHash,
/// Extra multimodal metadata for this block
/// Note: Do NOT use skip_serializing_if with bincode - it breaks deserialization
/// because bincode is positional and expects all fields to be present.
#[serde(default)]
pub mm_extra_info: Option<BlockExtraInfo>,
} }
/// Represents the data associated with a removed cache event. /// Represents the data associated with a removed cache event.
...@@ -365,6 +480,7 @@ mod tests { ...@@ -365,6 +480,7 @@ mod tests {
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(2), block_hash: ExternalSequenceBlockHash(2),
tokens_hash: LocalBlockHash(3), tokens_hash: LocalBlockHash(3),
mm_extra_info: None,
}], }],
}); });
......
...@@ -543,6 +543,7 @@ fn convert_event( ...@@ -543,6 +543,7 @@ fn convert_event(
token_ids, token_ids,
block_size, block_size,
lora_id, lora_id,
block_mm_infos,
.. ..
} => { } => {
let num_block_tokens = vec![block_size as u64; block_hashes.len()]; let num_block_tokens = vec![block_size as u64; block_hashes.len()];
...@@ -563,6 +564,7 @@ fn convert_event( ...@@ -563,6 +564,7 @@ fn convert_event(
&block_hashes_u64, &block_hashes_u64,
lora_id.unwrap_or(0), lora_id.unwrap_or(0),
warning_count, warning_count,
block_mm_infos.as_deref(),
), ),
}), }),
dp_rank, dp_rank,
...@@ -595,18 +597,25 @@ pub fn create_stored_block_from_parts( ...@@ -595,18 +597,25 @@ pub fn create_stored_block_from_parts(
block_hash: u64, block_hash: u64,
token_ids: &[u32], token_ids: &[u32],
_lora_id: u64, _lora_id: u64,
mm_extra_info: Option<BlockExtraInfo>,
) -> KvCacheStoredBlockData { ) -> KvCacheStoredBlockData {
let tokens_hash = compute_block_hash_for_seq(token_ids, kv_block_size)[0]; // Compute tokens_hash including MM info if present
let block_mm_infos = mm_extra_info.as_ref().map(|info| vec![Some(info.clone())]);
let tokens_hash =
compute_block_hash_for_seq(token_ids, kv_block_size, block_mm_infos.as_deref())[0];
tracing::trace!( tracing::trace!(
"Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}", "Creating stored block: external_block_hash={}, tokens_hash={}, token_ids={:?}, kv_block_size={}, mm_extra_info={:?}",
block_hash, block_hash,
tokens_hash.0, tokens_hash.0,
token_ids, token_ids,
kv_block_size kv_block_size,
mm_extra_info
); );
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash::from(block_hash), block_hash: ExternalSequenceBlockHash::from(block_hash),
tokens_hash, tokens_hash,
mm_extra_info,
} }
} }
...@@ -617,11 +626,14 @@ pub fn create_stored_blocks( ...@@ -617,11 +626,14 @@ pub fn create_stored_blocks(
block_hashes: &[u64], block_hashes: &[u64],
lora_id: u64, lora_id: u64,
warning_count: &Arc<AtomicU32>, warning_count: &Arc<AtomicU32>,
block_mm_infos: Option<&[Option<BlockExtraInfo>]>,
) -> Vec<KvCacheStoredBlockData> { ) -> Vec<KvCacheStoredBlockData> {
let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new(); let mut blocks: Vec<KvCacheStoredBlockData> = Vec::new();
let mut token_offset: usize = 0; let mut token_offset: usize = 0;
for (num_tokens_it, block_hash_it) in num_block_tokens.iter().zip(block_hashes.iter()) { for (block_idx, (num_tokens_it, block_hash_it)) in
num_block_tokens.iter().zip(block_hashes.iter()).enumerate()
{
if *num_tokens_it != kv_block_size as u64 { if *num_tokens_it != kv_block_size as u64 {
if warning_count.fetch_add(1, Ordering::Relaxed) < 3 { if warning_count.fetch_add(1, Ordering::Relaxed) < 3 {
tracing::warn!( tracing::warn!(
...@@ -634,11 +646,16 @@ pub fn create_stored_blocks( ...@@ -634,11 +646,16 @@ pub fn create_stored_blocks(
} }
let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)]; let tokens = &token_ids[token_offset..(token_offset + *num_tokens_it as usize)];
let mm_extra_info = block_mm_infos
.and_then(|infos| infos.get(block_idx))
.and_then(|opt| opt.clone());
blocks.push(create_stored_block_from_parts( blocks.push(create_stored_block_from_parts(
kv_block_size, kv_block_size,
*block_hash_it, *block_hash_it,
tokens, tokens,
lora_id, lora_id,
mm_extra_info,
)); ));
token_offset += *num_tokens_it as usize; token_offset += *num_tokens_it as usize;
} }
...@@ -702,6 +719,9 @@ enum RawKvEvent { ...@@ -702,6 +719,9 @@ enum RawKvEvent {
lora_id: Option<u64>, lora_id: Option<u64>,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
medium: Option<String>, medium: Option<String>,
/// Multimodal extra info for each block (length should match block_hashes)
#[serde(default, skip_serializing_if = "Option::is_none")]
block_mm_infos: Option<Vec<Option<BlockExtraInfo>>>,
}, },
BlockRemoved { BlockRemoved {
block_hashes: Vec<BlockHashValue>, block_hashes: Vec<BlockHashValue>,
...@@ -747,6 +767,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -747,6 +767,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
let mut block_size: Option<usize> = None; let mut block_size: Option<usize> = None;
let mut lora_id: Option<Option<u64>> = None; let mut lora_id: Option<Option<u64>> = None;
let mut medium: Option<Option<String>> = None; let mut medium: Option<Option<String>> = None;
let mut block_mm_infos: Option<Option<Vec<Option<BlockExtraInfo>>>> = None;
while let Some(key) = map.next_key::<String>()? { while let Some(key) = map.next_key::<String>()? {
match key.as_str() { match key.as_str() {
...@@ -771,6 +792,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -771,6 +792,9 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
"medium" => { "medium" => {
medium = Some(map.next_value()?); medium = Some(map.next_value()?);
} }
"block_mm_infos" => {
block_mm_infos = Some(map.next_value()?);
}
_ => { _ => {
map.next_value::<IgnoredAny>()?; map.next_value::<IgnoredAny>()?;
} }
...@@ -791,6 +815,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -791,6 +815,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
block_size, block_size,
lora_id: lora_id.unwrap_or(None), lora_id: lora_id.unwrap_or(None),
medium: medium.unwrap_or(None), medium: medium.unwrap_or(None),
block_mm_infos: block_mm_infos.unwrap_or(None),
}) })
} }
Some("BlockRemoved") => { Some("BlockRemoved") => {
...@@ -836,6 +861,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -836,6 +861,8 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
.ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?; .ok_or_else(|| de::Error::invalid_length(4, &"missing block_size"))?;
let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None); let lora_id: Option<u64> = seq.next_element()?.unwrap_or(None);
let medium: Option<String> = seq.next_element()?.unwrap_or(None); let medium: Option<String> = seq.next_element()?.unwrap_or(None);
let block_mm_infos: Option<Vec<Option<BlockExtraInfo>>> =
seq.next_element()?.unwrap_or(None);
while seq.next_element::<IgnoredAny>()?.is_some() {} while seq.next_element::<IgnoredAny>()?.is_some() {}
...@@ -846,6 +873,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor { ...@@ -846,6 +873,7 @@ impl<'de> Visitor<'de> for RawKvEventVisitor {
block_size, block_size,
lora_id, lora_id,
medium, medium,
block_mm_infos,
}) })
} }
"BlockRemoved" => { "BlockRemoved" => {
...@@ -1088,11 +1116,12 @@ mod test_event_processing { ...@@ -1088,11 +1116,12 @@ mod test_event_processing {
let token_ids = vec![10, 20, 30, 40]; let token_ids = vec![10, 20, 30, 40];
let blk_hash = 0xdead_beef; let blk_hash = 0xdead_beef;
let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0); let stored = create_stored_block_from_parts(kv_block_size, blk_hash, &token_ids, 0, None);
assert_eq!(stored.block_hash.0, blk_hash); assert_eq!(stored.block_hash.0, blk_hash);
let expected_hash = compute_block_hash_for_seq(&token_ids, 4)[0]; let expected_hash = compute_block_hash_for_seq(&token_ids, 4, None)[0];
assert_eq!(stored.tokens_hash, expected_hash); assert_eq!(stored.tokens_hash, expected_hash);
assert!(stored.mm_extra_info.is_none());
} }
// --------------------------------------------------------------------- // ---------------------------------------------------------------------
...@@ -1113,6 +1142,7 @@ mod test_event_processing { ...@@ -1113,6 +1142,7 @@ mod test_event_processing {
&block_hashes, &block_hashes,
/*lora_id=*/ 0, /*lora_id=*/ 0,
&Arc::new(AtomicU32::new(0)), &Arc::new(AtomicU32::new(0)),
None,
); );
assert_eq!(blocks.len(), 2); assert_eq!(blocks.len(), 2);
...@@ -1136,6 +1166,7 @@ mod test_event_processing { ...@@ -1136,6 +1166,7 @@ mod test_event_processing {
&block_hashes, &block_hashes,
/*lora_id=*/ 0, /*lora_id=*/ 0,
&warning_count, &warning_count,
None,
); );
// should early-exit as second has mismatch // should early-exit as second has mismatch
...@@ -1156,6 +1187,7 @@ mod test_event_processing { ...@@ -1156,6 +1187,7 @@ mod test_event_processing {
block_size: 4, block_size: 4,
lora_id: Some(0), lora_id: Some(0),
medium: None, medium: None,
block_mm_infos: None,
}; };
let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0))); let out = convert_event(raw_evt, 42, kv_block_size, 0, &Arc::new(AtomicU32::new(0)));
...@@ -1303,10 +1335,12 @@ mod tests_startup_helpers { ...@@ -1303,10 +1335,12 @@ mod tests_startup_helpers {
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}, },
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(101), block_hash: ExternalSequenceBlockHash(101),
tokens_hash: LocalBlockHash(201), tokens_hash: LocalBlockHash(201),
mm_extra_info: None,
}, },
], ],
}), }),
...@@ -1391,6 +1425,7 @@ mod tests_startup_helpers { ...@@ -1391,6 +1425,7 @@ mod tests_startup_helpers {
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
...@@ -1471,6 +1506,7 @@ mod tests_startup_helpers { ...@@ -1471,6 +1506,7 @@ mod tests_startup_helpers {
blocks: vec![KvCacheStoredBlockData { blocks: vec![KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}], }],
}), }),
dp_rank: 0, dp_rank: 0,
...@@ -1615,6 +1651,7 @@ mod tests_startup_helpers { ...@@ -1615,6 +1651,7 @@ mod tests_startup_helpers {
block_size: 4, block_size: 4,
lora_id: None, lora_id: None,
medium: None, medium: None,
block_mm_infos: None,
}]; }];
let batch = KvEventBatch { let batch = KvEventBatch {
...@@ -1705,10 +1742,12 @@ mod tests_startup_helpers { ...@@ -1705,10 +1742,12 @@ mod tests_startup_helpers {
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), block_hash: ExternalSequenceBlockHash(100),
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}, },
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(101), block_hash: ExternalSequenceBlockHash(101),
tokens_hash: LocalBlockHash(201), tokens_hash: LocalBlockHash(201),
mm_extra_info: None,
}, },
], ],
}), }),
...@@ -1769,10 +1808,12 @@ mod tests_startup_helpers { ...@@ -1769,10 +1808,12 @@ mod tests_startup_helpers {
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(100), // Shared prefix block_hash: ExternalSequenceBlockHash(100), // Shared prefix
tokens_hash: LocalBlockHash(200), tokens_hash: LocalBlockHash(200),
mm_extra_info: None,
}, },
KvCacheStoredBlockData { KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(102), // New block block_hash: ExternalSequenceBlockHash(102), // New block
tokens_hash: LocalBlockHash(202), tokens_hash: LocalBlockHash(202),
mm_extra_info: None,
}, },
], ],
}), }),
......
...@@ -24,6 +24,7 @@ mod tests { ...@@ -24,6 +24,7 @@ mod tests {
.map(|i| KvCacheStoredBlockData { .map(|i| KvCacheStoredBlockData {
tokens_hash: LocalBlockHash(*i), tokens_hash: LocalBlockHash(*i),
block_hash: ExternalSequenceBlockHash(*i * 100), block_hash: ExternalSequenceBlockHash(*i * 100),
mm_extra_info: None,
}) })
.collect() .collect()
} }
......
...@@ -139,6 +139,7 @@ impl KvManager { ...@@ -139,6 +139,7 @@ impl KvManager {
.map(|(global_hash, local_hash)| KvCacheStoredBlockData { .map(|(global_hash, local_hash)| KvCacheStoredBlockData {
block_hash: ExternalSequenceBlockHash(global_hash), block_hash: ExternalSequenceBlockHash(global_hash),
tokens_hash: LocalBlockHash(*local_hash), tokens_hash: LocalBlockHash(*local_hash),
mm_extra_info: None,
}) })
.collect(), .collect(),
}) })
......
...@@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize}; ...@@ -8,7 +8,7 @@ use serde::{Deserialize, Serialize};
use super::timing::RequestTracker; use super::timing::RequestTracker;
use super::{OutputOptions, SamplingOptions, StopConditions}; use super::{OutputOptions, SamplingOptions, StopConditions};
use crate::kv_router::RouterConfigOverride; use crate::kv_router::{RouterConfigOverride, protocols::RequestExtraInfo};
#[cfg(feature = "media-nixl")] #[cfg(feature = "media-nixl")]
use crate::preprocessor::media::RdmaMediaDataDescriptor; use crate::preprocessor::media::RdmaMediaDataDescriptor;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
...@@ -118,6 +118,10 @@ pub struct PreprocessedRequest { ...@@ -118,6 +118,10 @@ pub struct PreprocessedRequest {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub extra_fields: Option<Vec<String>>, pub extra_fields: Option<Vec<String>>,
/// Multimodal request-level metadata (mm_hash and token offsets)
#[builder(default)]
#[serde(default, skip_serializing_if = "Option::is_none")]
pub request_extra_info: Option<RequestExtraInfo>,
/// Optional request tracker for per-request metrics (shared with DeltaGenerator) /// Optional request tracker for per-request metrics (shared with DeltaGenerator)
#[builder(default)] #[builder(default)]
#[serde(skip)] #[serde(skip)]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment