Unverified commit bb43fada, authored by dagil-nvidia, committed by GitHub

chore: remove outdated router_standalone_trtllm example and add standalone router docs (#7278)


Signed-off-by: akshatha-k <akshutk@gmail.com>
Signed-off-by: Dan Gil <dagil@nvidia.com>
Co-authored-by: akshatha-k <akshutk@gmail.com>
parent bb07b2f4
@@ -23,6 +23,10 @@ For Kubernetes, set `DYN_ROUTER_MODE=kv` on the Frontend service. Workers automa
| `--no-router-kv-events` | enabled | Fall back to approximate routing (no event consumption from workers) |
| `--router-queue-threshold` | disabled | Enable backpressure queue under high concurrency; also enables priority scheduling via `nvext.agent_hints.latency_sensitivity` |
### Standalone Router
You can also run the KV router as a standalone service (without the Dynamo frontend). See the [Standalone Router component](../../../components/src/dynamo/router/) for more details.
For all CLI arguments, environment variables, K8s deployment examples, and tuning guidelines, see the [Router Guide](router-guide.md). For A/B benchmarking, see the [KV Router A/B Benchmarking Guide](../../benchmarks/kv-router-ab-testing.md).
## Prerequisites and Limitations
......
@@ -82,6 +82,10 @@ All CLI arguments can be configured via environment variables using the `DYN_` p
For complete K8s examples and advanced configuration, see [K8s Examples](router-examples.md#k8s-examples).
For A/B testing and advanced K8s setup, see the [KV Router A/B Benchmarking Guide](../../benchmarks/kv-router-ab-testing.md).
### Standalone Router
You can also run the KV router as a standalone service (without the Dynamo frontend) for disaggregated serving (e.g., routing to prefill workers), multi-tier architectures, or any scenario requiring intelligent KV cache-aware routing decisions. See the [Standalone Router component](../../../components/src/dynamo/router/) for more details.
## KV Cache Routing
KV cache routing optimizes large language model inference by intelligently directing requests to workers with the most relevant cached data. By maximizing cache reuse, it reduces redundant computation and improves both throughput and latency.
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Router Standalone - TensorRT-LLM
A standalone implementation of KvRouter that demonstrates usage with TensorRT-LLM workers, without depending on the Dynamo runtime, the etcd control plane, or the NATS event plane.
## Overview
This example shows how to use KvRouter with TensorRT-LLM workers to intelligently route requests across multiple GPUs based on KV cache overlap and load metrics. The router maintains a view of each worker's cached blocks and routes new requests to the worker with the best combination of cache overlap and available capacity.
Key features:
- **KV cache-aware routing**: Routes requests to workers with matching cached blocks
- **Multimodal support**: Handles vision-language models (e.g., Qwen2-VL) with image inputs
- **MM hash routing**: Identical images produce identical hashes for cache reuse
## How It Works
### Core Architecture
The router uses a **RadixTree** data structure (written in Rust) to efficiently track which blocks each worker has cached. When a new request arrives, the router:
1. Tokenizes the request and computes block hashes (including MM hashes for images)
2. Uses `find_matches` to calculate overlap scores between the request and each worker's cached blocks
3. Combines this with current load metrics to select the optimal worker
4. Routes the request to the chosen worker for processing
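
The selection steps above can be sketched in plain Python. This is a toy model under stated assumptions, not the Rust RadixTree or the actual `find_matches` implementation: `WorkerState`, `overlap_score`, `pick_worker`, and the `overlap_weight` scoring formula are hypothetical names chosen for illustration, and the real router's cost function may differ.

```python
from dataclasses import dataclass, field


@dataclass
class WorkerState:
    """Hypothetical per-worker view kept by the router."""
    cached_blocks: set[int] = field(default_factory=set)
    load: float = 0.0  # assumed: fraction of KV cache in use, 0.0 to 1.0


def overlap_score(request_hashes: list[int], cached: set[int]) -> int:
    """Count how many leading blocks of the request are already cached."""
    matched = 0
    for block_hash in request_hashes:
        if block_hash not in cached:
            break  # a prefix match ends at the first miss
        matched += 1
    return matched


def pick_worker(request_hashes, workers, overlap_weight=1.0):
    """Return the worker id with the best (overlap fraction - load) score."""
    def score(item):
        _, state = item
        overlap = overlap_score(request_hashes, state.cached_blocks)
        fraction = overlap / max(len(request_hashes), 1)
        return overlap_weight * fraction - state.load

    return max(workers.items(), key=score)[0]


workers = {
    0: WorkerState(cached_blocks={11, 22, 33}, load=0.5),
    1: WorkerState(cached_blocks={11}, load=0.1),
}
best = pick_worker([11, 22, 33, 44], workers)
```

In this sketch a busy worker with a long cached prefix can still win over an idle worker with a short one; raising or lowering `overlap_weight` shifts that trade-off.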
### Multimodal Routing
For vision-language models:
1. Images are processed using `default_multimodal_input_loader` from TensorRT-LLM
2. Image placeholders are expanded to visual tokens using HuggingFace `AutoProcessor`
3. `apply_mm_hashes` computes a content hash for each image
4. The MM hash is included in block hash computation, so identical images produce cache hits
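
To illustrate step 4, the sketch below mixes an image hash into the hash of every block that the image's tokens overlap. It is an assumption-laden stand-in: it uses `hashlib.sha256` rather than the hashing inside `compute_block_hash_for_seq`, hashes only full blocks, and the `image_spans` tuple layout and `PLACEHOLDER` token are hypothetical.

```python
import hashlib
import struct


def block_hashes(tokens, block_size, image_spans=None):
    """Hash each full block of tokens, mixing in overlapping image hashes.

    image_spans is a hypothetical list of (start, end, mm_hash) tuples,
    where [start, end) is the token range occupied by one image.
    """
    image_spans = image_spans or []
    hashes = []
    for start in range(0, len(tokens) - block_size + 1, block_size):
        end = start + block_size
        h = hashlib.sha256()
        h.update(struct.pack(f"<{block_size}I", *tokens[start:end]))
        for img_start, img_end, mm_hash in image_spans:
            # Same interval-overlap test for block [start, end) and the image
            if end > img_start and start < img_end:
                h.update(struct.pack("<Q", mm_hash))
        hashes.append(int.from_bytes(h.digest()[:8], "little"))
    return hashes


PLACEHOLDER = 7  # stand-in for an image placeholder token id
tokens = [1, 2, PLACEHOLDER, PLACEHOLDER, 3, 4, 5, 6]
same_a = block_hashes(tokens, 4, [(2, 4, 0xABCDEF)])
same_b = block_hashes(tokens, 4, [(2, 4, 0xABCDEF)])
other = block_hashes(tokens, 4, [(2, 4, 0x123456)])
```

Here `same_a` and `same_b` are identical, so a repeated image can hit the cache, while `other` differs only in the block that contains the image tokens.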
### Event-Driven Updates
The router receives two types of events from TensorRT-LLM engines:
1. **KV Events**: Emitted automatically when blocks are stored/removed from cache (includes `mm_keys` for multimodal)
2. **Load Metrics**: GPU cache usage and waiting request count
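
A minimal sketch of how a router might fold these two event streams into its state is shown below. The event dictionary shape (`type`, `block_hashes`) and the `BlockIndex` class are hypothetical; the actual TensorRT-LLM event schema and the RadixTree update path are not reproduced here.

```python
from collections import defaultdict


class BlockIndex:
    """Hypothetical in-memory view of cached blocks and load per worker."""

    def __init__(self):
        self.blocks = defaultdict(set)  # worker_id -> set of block hashes
        self.metrics = {}               # worker_id -> (cache usage, waiting)

    def apply_kv_event(self, worker_id, event):
        """Apply one KV event; unknown event types are ignored."""
        if event["type"] == "stored":
            self.blocks[worker_id].update(event["block_hashes"])
        elif event["type"] == "removed":
            self.blocks[worker_id].difference_update(event["block_hashes"])

    def apply_metrics(self, worker_id, gpu_cache_usage, num_waiting):
        """Record the latest load metrics, replacing any older sample."""
        self.metrics[worker_id] = (gpu_cache_usage, num_waiting)


index = BlockIndex()
index.apply_kv_event(0, {"type": "stored", "block_hashes": [1, 2, 3]})
index.apply_kv_event(0, {"type": "removed", "block_hashes": [2]})
index.apply_metrics(0, gpu_cache_usage=0.25, num_waiting=1)
```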
## Components
### `worker.py`
- **TrtllmWorkers**: Manages multiple TensorRT-LLM worker processes
- Each worker runs on a separate GPU with KV cache event emission enabled
- Publishes metrics and KV events over ZMQ
- Extracts `mm_hash` from TRTLLM's `mm_keys` field for multimodal routing
### `router.py`
- **KvRouter**: Core routing logic using RadixTree
- Subscribes to KV cache events and load metrics from workers
- Implements `get_best_worker()` to select optimal routing destination
### `api.py`
- **ServiceAPI**: FastAPI server providing OpenAI-compatible chat completions endpoint
- Handles multimodal inputs (images) via `default_multimodal_input_loader`
- Computes block hashes including MM hashes for routing decisions
- Streams responses in OpenAI format
### `test_router.py`
- Comprehensive test suite for router functionality
- Includes local hash computation tests and server-side multimodal tests
- Run with `--mm-only` for multimodal-specific tests
## Requirements
- **TensorRT-LLM >= 1.2.0rc6**: Version 1.2.0rc6 or later is required because MM hash-based routing depends on the multimodal information (`mm_keys`) that these versions include in KV cache events. See [PR #9604](https://github.com/NVIDIA/TensorRT-LLM/pull/9604) for details.
- TensorRT-LLM with the PyTorch backend
- Multiple GPUs (one per worker)
- Python 3.10+
- Required packages: fastapi, uvicorn, httpx, pyzmq (imported as `zmq`), tensorrt_llm, transformers
## Usage
### 1. Start the API Server
```bash
python api.py \
--model Qwen/Qwen2-VL-2B-Instruct \
--num-workers 2 \
--block-size 32 \
--base-kv-events-port 5557 \
--base-metrics-port 5657 \
--router-port 7000 \
--http-port 8000
```
This will:
- Initialize TensorRT-LLM engines on each GPU
- Start ZMQ publishers for metrics and KV events
- Start the router service
- Start the OpenAI-compatible API server
### 2. Test with curl
**Text-only request:**
```bash
curl -s http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{"role": "user", "content": "Hello, how are you?"}],
"max_tokens": 100,
"stream": false
}' | jq
```
**Multimodal request (with images):**
```bash
curl -s -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "Describe both images in detail."},
{"type": "image_url", "image_url": {"url": "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg"}},
{"type": "image_url", "image_url": {"url": "http://images.cocodataset.org/test2017/000000000001.jpg"}}
]
}],
"max_tokens": 500,
"stream": false
}' | jq
```
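
The same requests can be built from Python. The sketch below uses only the standard library; `build_chat_request` and `send` are hypothetical helper names, and the server is assumed to be listening on `localhost:8000` as in the commands above.

```python
import json
import urllib.request


def build_chat_request(model, text, image_urls=(), max_tokens=100):
    """Build an OpenAI-style chat payload with optional image_url parts."""
    content = [{"type": "text", "text": text}]
    content += [
        {"type": "image_url", "image_url": {"url": url}} for url in image_urls
    ]
    return {
        "model": model,
        "messages": [{"role": "user", "content": content}],
        "max_tokens": max_tokens,
        "stream": False,
    }


def send(payload, base_url="http://localhost:8000"):
    """POST the payload to the chat completions endpoint and parse the JSON."""
    request = urllib.request.Request(
        f"{base_url}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request) as response:
        return json.load(response)


payload = build_chat_request(
    "Qwen/Qwen2-VL-2B-Instruct",
    "Describe the image in detail.",
    image_urls=["http://images.cocodataset.org/test2017/000000000001.jpg"],
    max_tokens=500,
)
# send(payload) would issue the request once the server is running.
```

The example always sets `max_tokens` because `api.py` rejects requests that specify neither `max_tokens` nor `max_completion_tokens`.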
### 3. Run Tests
```bash
# Run all tests
python test_router.py
# Run multimodal tests only
python test_router.py --mm-only
# Verbose output
python test_router.py -v
```
### 4. Check endpoint health
```bash
./ping.sh
```
## Configuration
### Command-line Arguments
- `--model`: HuggingFace model name (default: Qwen/Qwen2-VL-2B-Instruct)
- `--num-workers`: Number of GPU workers (default: 2)
- `--block-size`: KV cache block size (default: 32, TensorRT-LLM's default)
- `--base-kv-events-port`: Base port for KV events ZMQ (default: 5557)
- `--base-metrics-port`: Base port for metrics ZMQ (default: 5657)
- `--router-port`: Router HTTP service port (default: 7000)
- `--http-port`: API server port (default: 8000)
### Environment Variables
- `DYNAMO_DEBUG=1`: Enable debug file dumps to `/tmp/debug_*.txt`
- `LOGLEVEL=DEBUG`: Set logging level (DEBUG, INFO, WARNING, ERROR)
- `TRANSFORMERS_ATTN_IMPLEMENTATION=eager`: Disable FlashAttention (set automatically)
- `TRTLLM_MAX_NUM_TOKENS`: Set max token length
### Port Assignment
Workers use sequential ports:
- Worker 0: KV events on 5557, metrics on 5657
- Worker 1: KV events on 5558, metrics on 5658
- Worker N: KV events on 5557+N, metrics on 5657+N
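
The port scheme is simple enough to state as a function. The helper name `worker_ports` is hypothetical; the defaults mirror the documented `--base-kv-events-port` and `--base-metrics-port` values.

```python
def worker_ports(worker_id, base_kv_events_port=5557, base_metrics_port=5657):
    """Return (kv_events_port, metrics_port) for the given worker index."""
    return base_kv_events_port + worker_id, base_metrics_port + worker_id


ports = [worker_ports(i) for i in range(2)]
```

With the default bases, the two ranges are 100 ports apart, so they would collide at 100 or more workers; pick bases further apart for larger deployments.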
## Architecture Diagram
```
  ┌─────────────┐
  │   Client    │
  └──────┬──────┘
         │ HTTP
┌─────────────────┐
│   API Server    │
│    (api.py)     │
└────────┬────────┘
         │ HTTP
┌─────────────────┐
│     Router      │──┐
│   (router.py)   │  │ ZMQ (KV Events)
└────────┬────────┘  │
         │           │
         │ Select    │
         │ Worker    │
         ▼           │
┌─────────────────┐  │
│  TrtllmWorkers  │  │
│   (worker.py)   │◄-┘
└─────────────────┘
     │       │
     ▼       ▼
   GPU 0   GPU 1
```
## Multimodal KV Cache Routing
When processing multimodal requests:
1. **API Layer** (`api.py`):
   - Parses OpenAI-format messages with `image_url` content
   - Uses `default_multimodal_input_loader` to process images
   - Expands image placeholders to visual tokens via `AutoProcessor`
   - Computes `mm_hash` using `apply_mm_hashes`
   - Includes `mm_hash` in block hash computation for routing
2. **Worker Layer** (`worker.py`):
   - Receives multimodal input and passes it to TRTLLM
   - Extracts `mm_hash` from TRTLLM's `mm_keys` in KV events
   - Publishes KV events with `mm_extra_info` to the router
3. **Router Layer** (`router.py`):
   - RadixTree matches blocks including MM hash
   - Same image content = same hash = cache hit on same worker
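
Two details of the API layer can be shown in isolation: truncating a hex digest to a 64-bit integer, and finding which blocks an image's token range touches. The first mirrors the `int(hex_digest[:16], 16)` conversion in `api.py`; the function names `mm_hash_from_hex` and `blocks_touched` are hypothetical, and `blocks_touched` is an equivalent closed form of the interval-overlap test rather than the code `api.py` uses.

```python
def mm_hash_from_hex(hex_digest: str) -> int:
    """Keep the first 16 hex characters (64 bits) of a digest as an int."""
    return int(hex_digest[:16], 16)


def blocks_touched(img_start: int, img_end: int, block_size: int) -> list[int]:
    """Indices of blocks that overlap the token range [img_start, img_end)."""
    if img_end <= img_start:
        return []
    first = img_start // block_size
    last = (img_end - 1) // block_size
    return list(range(first, last + 1))


digest = "00000000000000ff" + "0" * 48  # a made-up 256-bit hex digest
mm_hash = mm_hash_from_hex(digest)
touched = blocks_touched(30, 70, 32)
```

An image occupying tokens 30 through 69 with a block size of 32 touches blocks 0, 1, and 2, so all three block hashes depend on that image's `mm_hash`.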
## Notes
- This is a standalone implementation for pedagogical purposes
- Production Dynamo uses NATS for events and etcd for service discovery
- Each worker needs its own GPU
- TensorRT-LLM models may take time to compile on first run
## See Also
- [TensorRT-LLM KV Event Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/examples/llm_inference_kv_events.html)
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
# Fix protobuf version conflict with etcd3
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")
import argparse
import asyncio
import json
import logging
import time
import uuid
from dataclasses import dataclass
from typing import Optional
import httpx
import uvicorn
from fastapi import FastAPI
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from router import RouterAPI, RouterRequest, RouterResponse
from tensorrt_llm.inputs.multimodal import apply_mm_hashes
from tensorrt_llm.inputs.utils import default_multimodal_input_loader, load_image
from tensorrt_llm.llmapi.tokenizer import tokenizer_factory
from transformers import AutoProcessor
from worker import TrtllmWorkers
from dynamo._core import compute_block_hash_for_seq
logger = logging.getLogger(__name__)
# Debug flag: set DYNAMO_DEBUG=1 to enable debug file dumps
DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1"
DEBUG_API_FILE = "/tmp/debug_api_hashes.txt"
# Qwen2-VL specific token IDs
QWEN2_VL_IMAGE_TOKEN_ID = 151655
QWEN2_VL_REPLACEMENT_ID = 151937
def dump_api_debug(
    tokens: list[int],
    block_size: int,
    local_hashes: list[int],
    mm_hashes: list[int] | None,
    block_mm_infos: list | None,
    image_urls: list[str] | None,
):
    """Dump API-side hash computation to file for debugging."""
    if not DEBUG_ENABLED:
        return
    import datetime

    with open(DEBUG_API_FILE, "a") as f:
        f.write(f"\n{'='*60}\n")
        f.write(f"Timestamp: {datetime.datetime.now()}\n")
        f.write(f"Image URLs: {image_urls}\n")
        f.write(f"mm_hashes: {mm_hashes}\n")
        f.write(f"block_size: {block_size}\n")
        f.write(f"num_tokens: {len(tokens)}\n")
        f.write(f"tokens (first 50): {tokens[:50]}\n")
        f.write(f"tokens (last 50): {tokens[-50:]}\n")
        f.write(f"block_mm_infos: {block_mm_infos}\n")
        f.write(f"local_hashes ({len(local_hashes)}): {local_hashes}\n")
        f.write(f"{'='*60}\n")


def make_error(message: str, error_type: str, code: int) -> dict:
    """Create a standardized error response dict."""
    return {"message": message, "type": error_type, "code": code}
# Pydantic models for OpenAI-compatible API
class ImageUrl(BaseModel):
    url: str


class ContentPart(BaseModel):
    type: str  # "text" | "image_url"
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None


class Message(BaseModel):
    role: str
    content: str | list[ContentPart]


class ChatCompletionRequest(BaseModel):
    model: str
    messages: list[Message]
    max_tokens: Optional[int] = None
    max_completion_tokens: Optional[int] = None
    temperature: Optional[float] = 1.0
    top_p: Optional[float] = 1.0
    stream: bool = True


class ErrorResponse(BaseModel):
    error: dict


@dataclass(frozen=True)
class ServingParams:
    """Configuration parameters for the serving API."""

    model: str
    model_type: str  # e.g., "qwen2_vl", "llava"
    block_size: int
    num_workers: int
    base_kv_events_port: int
    base_metrics_port: int
    router_port: int
    http_port: int


@dataclass
class ParsedRequest:
    """Parsed and preprocessed request data."""

    messages_dict: list[dict]
    image_urls: list[str]
    max_tokens: int
    temperature: float
    top_p: float
    model: str


@dataclass
class ProcessedInput:
    """Processed input ready for routing and generation."""

    tokens: list[int]
    mm_input: dict | None  # For multimodal requests
    mm_hashes: list[int] | None  # List of mm_hash for each image
    image_offsets_list: list[list[int]] | None  # List of [start, end] for each image
class ServiceAPI:
    """Main API service handling chat completion requests with KV cache routing."""

    def __init__(self, init_params: ServingParams):
        self.init_params = init_params
        self.app = FastAPI(title="TensorRT-LLM Router API", version="0.0.1")
        self.workers: Optional[TrtllmWorkers] = None
        self.tokenizer = None
        self.processor = None
        self.http_client: Optional[httpx.AsyncClient] = None
        self._setup_routes()

    # -------------------------------------------------------------------------
    # Request Parsing Helpers
    # -------------------------------------------------------------------------
    def _parse_request(
        self, request: ChatCompletionRequest
    ) -> ParsedRequest | ErrorResponse:
        """Parse and validate the incoming request."""
        max_tokens = request.max_completion_tokens or request.max_tokens
        if max_tokens is None:
            return ErrorResponse(
                error=make_error(
                    "Either max_tokens or max_completion_tokens must be specified",
                    "invalid_request_error",
                    400,
                )
            )
        messages_dict, image_urls = self._extract_messages(request.messages)
        return ParsedRequest(
            messages_dict=messages_dict,
            image_urls=image_urls,
            max_tokens=max_tokens,
            temperature=request.temperature,
            top_p=request.top_p,
            model=request.model,
        )

    def _extract_messages(
        self, messages: list[Message]
    ) -> tuple[list[dict], list[str]]:
        """Extract text messages and image URLs from request messages."""
        messages_dict = []
        image_urls = []
        for msg in messages:
            if isinstance(msg.content, str):
                messages_dict.append({"role": msg.role, "content": msg.content})
            else:
                text_parts = []
                for part in msg.content:
                    if part.type == "text" and part.text:
                        text_parts.append(part.text)
                    elif part.type == "image_url" and part.image_url:
                        image_urls.append(part.image_url.url)
                messages_dict.append(
                    {"role": msg.role, "content": " ".join(text_parts)}
                )
        return messages_dict, image_urls

    def _build_prompt(self, messages_dict: list[dict]) -> str:
        """Build prompt text from messages using chat template."""
        try:
            return self.tokenizer.apply_chat_template(
                messages_dict, tokenize=False, add_generation_prompt=True
            )
        except Exception as e:
            logger.warning(f"Chat template failed: {e}, using simple format")
            return self._format_messages_simple(messages_dict)

    def _format_messages_simple(self, messages: list[dict]) -> str:
        """Simple fallback formatting when chat template is unavailable."""
        parts = []
        role_map = {"system": "System", "user": "User", "assistant": "Assistant"}
        for msg in messages:
            prefix = role_map.get(msg["role"], msg["role"].capitalize())
            parts.append(f"{prefix}: {msg['content']}\n")
        parts.append("Assistant: ")
        return "\n".join(parts)
    # -------------------------------------------------------------------------
    # Multimodal Processing Helpers
    # -------------------------------------------------------------------------
    def _process_multimodal(self, prompt: str, image_urls: list[str]) -> ProcessedInput:
        """Process multimodal request: load images, compute tokens and mm_hashes."""
        try:
            # Use "multiple_image" modality when there are multiple images
            modality = "multiple_image" if len(image_urls) > 1 else "image"
            inputs = default_multimodal_input_loader(
                tokenizer=self.tokenizer,
                model_dir=self.init_params.model,
                model_type=self.init_params.model_type,
                modality=modality,
                prompts=[prompt],
                media=[image_urls],
                # Align hash input type with backend multimodal processor path.
                image_data_format="pil",
                device="cuda",
            )
            mm_input = inputs[0]
            processed_prompt = mm_input.get("prompt", prompt)
            multi_modal_data = mm_input.get("multi_modal_data")
            tokens, image_offsets_list = self._get_mm_tokens(
                processed_prompt, image_urls
            )
            mm_hashes = self._compute_mm_hashes(multi_modal_data)
            return ProcessedInput(
                tokens=tokens,
                mm_input=mm_input,
                mm_hashes=mm_hashes,
                image_offsets_list=image_offsets_list,
            )
        except Exception as e:
            logger.warning(f"MM processing failed: {e}, falling back to text-only")
            return ProcessedInput(
                tokens=self.tokenizer.encode(prompt),
                mm_input=None,
                mm_hashes=None,
                image_offsets_list=None,
            )

    def _get_mm_tokens(
        self, prompt: str, image_urls: list[str]
    ) -> tuple[list[int], list[list[int]] | None]:
        """Get tokens with visual expansion and find image token positions."""
        if self.processor is None:
            return self.tokenizer.encode(prompt), None
        pil_images = [load_image(url, format="pil") for url in image_urls]
        processor_output = self.processor(
            text=[prompt], images=pil_images, return_tensors="pt", padding=True
        )
        tokens = processor_output["input_ids"][0].tolist()
        image_token_id = getattr(
            self.processor, "image_token_id", QWEN2_VL_IMAGE_TOKEN_ID
        )
        return self._replace_image_tokens(
            tokens, image_token_id, QWEN2_VL_REPLACEMENT_ID
        )

    def _replace_image_tokens(
        self, tokens: list[int], image_token_id: int, replacement_id: int
    ) -> tuple[list[int], list[list[int]] | None]:
        """Replace image tokens and return their positions as list of [start, end] per image.

        Finds contiguous regions of image tokens. Each contiguous region is assumed
        to be one image.
        """
        image_offsets_list: list[list[int]] = []
        current_start: int | None = None
        for i, t in enumerate(tokens):
            if t == image_token_id:
                if current_start is None:
                    current_start = i
                tokens[i] = replacement_id
            else:
                # End of a contiguous image token region
                if current_start is not None:
                    image_offsets_list.append([current_start, i])
                    current_start = None
        # Handle case where image tokens go to the end
        if current_start is not None:
            image_offsets_list.append([current_start, len(tokens)])
        if image_offsets_list:
            logger.debug(f"Image token regions: {image_offsets_list}")
            return tokens, image_offsets_list
        return tokens, None

    def _compute_mm_hashes(self, multi_modal_data: dict | None) -> list[int] | None:
        """Compute mm_hash for each image in multimodal data.

        Returns:
            List of mm_hash (one per image), or None if no images.
        """
        if not multi_modal_data:
            return None
        # TRT-LLM 1.3 returns Tuple[Dict[str, List[str]], Optional[List[Optional[str]]]].
        mm_hashes_dict = apply_mm_hashes(multi_modal_data)[0]
        if not isinstance(mm_hashes_dict, dict) or not mm_hashes_dict:
            return None
        # Prefer image modality for stable behavior, but fall back to flattening
        # all modality hashes to stay forward-compatible.
        hash_hexes = mm_hashes_dict.get("image")
        if not hash_hexes:
            hash_hexes = [
                h for hashes in mm_hashes_dict.values() for h in (hashes or [])
            ]
        if hash_hexes:
            # Convert each 256-bit hex digest to 64-bit int
            mm_hashes = [int(hex_digest[:16], 16) for hex_digest in hash_hexes]
            logger.debug(f"Computed mm_hashes for {len(mm_hashes)} images: {mm_hashes}")
            return mm_hashes
        return None
    # -------------------------------------------------------------------------
    # Routing Helpers
    # -------------------------------------------------------------------------
    def _build_block_mm_infos(
        self,
        num_tokens: int,
        mm_hashes: list[int] | None,
        image_offsets_list: list[list[int]] | None,
    ) -> list[dict | None] | None:
        """Build block_mm_infos for routing hash computation.

        For each block, includes mm_objects for all images that overlap with that block.

        Args:
            num_tokens: Total number of tokens
            mm_hashes: List of mm_hash, one per image
            image_offsets_list: List of [start, end] offsets, one per image

        Returns:
            List of mm_info dicts (one per block), with None for blocks without images.
        """
        if mm_hashes is None or image_offsets_list is None:
            return None
        if len(mm_hashes) != len(image_offsets_list):
            logger.warning(
                f"mm_hashes ({len(mm_hashes)}) and image_offsets_list "
                f"({len(image_offsets_list)}) length mismatch"
            )
            return None
        block_size = self.init_params.block_size
        num_blocks = (num_tokens + block_size - 1) // block_size
        result: list[dict | None] = []
        for block_idx in range(num_blocks):
            block_start = block_idx * block_size
            block_end = block_start + block_size
            # Find all images that overlap with this block
            mm_objects = []
            for mm_hash, offsets in zip(mm_hashes, image_offsets_list):
                img_start, img_end = offsets
                if block_end > img_start and block_start < img_end:
                    mm_objects.append({"mm_hash": mm_hash, "offsets": [offsets]})
            if mm_objects:
                result.append({"mm_objects": mm_objects})
            else:
                result.append(None)
        return result

    async def _route_request(
        self, local_hashes: list[int], num_tokens: int
    ) -> int | ErrorResponse:
        """Query router for best worker ID."""
        try:
            router_request = RouterRequest(
                local_hashes=local_hashes, num_tokens=num_tokens
            )
            response = await self.http_client.post(
                f"http://localhost:{self.init_params.router_port}/find_best_worker",
                json=router_request.model_dump(),
                timeout=1,
            )
            response.raise_for_status()
            return RouterResponse.model_validate(response.json()).worker_id
        except (httpx.RequestError, httpx.HTTPStatusError) as e:
            logger.error(f"Router request failed: {e}")
            return ErrorResponse(
                error=make_error(
                    "Router service unavailable", "service_unavailable", 503
                )
            )
    # -------------------------------------------------------------------------
    # Response Streaming
    # -------------------------------------------------------------------------
    async def _stream_response(
        self, request: ChatCompletionRequest, result_generator, request_id: str
    ):
        """Generate SSE formatted streaming responses."""
        created = int(time.time())
        first_chunk = True
        try:
            async for output in result_generator:
                # Handle both dict (from worker) and object responses
                if isinstance(output, dict):
                    text = output.get("text_diff") or output.get("text", "")
                else:
                    text = getattr(output, "text_diff", None) or getattr(
                        output, "text", ""
                    )
                if not text and not first_chunk:
                    continue
                delta = (
                    {"role": "assistant", "content": text}
                    if first_chunk
                    else {"content": text}
                )
                yield self._format_chunk(
                    request_id, created, request.model, delta, None
                )
                first_chunk = False
            # Final chunk
            yield self._format_chunk(request_id, created, request.model, {}, "stop")
            yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Streaming error: {e}")
            yield f"data: {json.dumps({'error': make_error(str(e), 'internal_error', 500)})}\n\n"

    def _format_chunk(
        self,
        request_id: str,
        created: int,
        model: str,
        delta: dict,
        finish_reason: str | None,
    ) -> str:
        """Format a single SSE chunk."""
        chunk = {
            "id": request_id,
            "object": "chat.completion.chunk",
            "created": created,
            "model": model,
            "choices": [{"index": 0, "delta": delta, "finish_reason": finish_reason}],
        }
        return f"data: {json.dumps(chunk)}\n\n"

    async def _generate_full_response(
        self, request: ChatCompletionRequest, result_generator, request_id: str
    ) -> dict:
        """Collect all outputs and generate a complete (non-streaming) response."""
        created = int(time.time())
        full_text = ""
        try:
            async for output in result_generator:
                if isinstance(output, dict):
                    text = output.get("text_diff") or output.get("text", "")
                else:
                    text = getattr(output, "text_diff", None) or getattr(
                        output, "text", ""
                    )
                full_text += text
            return {
                "id": request_id,
                "object": "chat.completion",
                "created": created,
                "model": request.model,
                "choices": [
                    {
                        "index": 0,
                        "message": {"role": "assistant", "content": full_text},
                        "finish_reason": "stop",
                    }
                ],
                "usage": {
                    "prompt_tokens": 0,  # Not tracked in this implementation
                    "completion_tokens": 0,
                    "total_tokens": 0,
                },
            }
        except Exception as e:
            logger.error(f"Generation error: {e}")
            return {"error": make_error(str(e), "internal_error", 500)}
    # -------------------------------------------------------------------------
    # Main Request Handler
    # -------------------------------------------------------------------------
    def _setup_routes(self):
        @self.app.post("/v1/chat/completions")
        async def chat_completions(request: ChatCompletionRequest):
            # Check service readiness
            if (
                self.workers is None
                or self.tokenizer is None
                or self.http_client is None
            ):
                return ErrorResponse(
                    error=make_error("Service not ready", "service_unavailable", 503)
                )
            try:
                # Parse request
                parsed = self._parse_request(request)
                if isinstance(parsed, ErrorResponse):
                    return parsed
                # Process input (multimodal or text-only)
                if parsed.image_urls:
                    # For multimodal: pass raw text, let default_multimodal_input_loader apply chat template
                    raw_text = " ".join(
                        msg["content"]
                        for msg in parsed.messages_dict
                        if msg.get("content")
                    )
                    processed = self._process_multimodal(raw_text, parsed.image_urls)
                else:
                    # For text-only: apply chat template ourselves
                    prompt = self._build_prompt(parsed.messages_dict)
                    processed = ProcessedInput(
                        tokens=self.tokenizer.encode(prompt),
                        mm_input=None,
                        mm_hashes=None,
                        image_offsets_list=None,
                    )
                # Validate tokens
                if not processed.tokens:
                    return ErrorResponse(
                        error=make_error(
                            "Input prompt is empty", "invalid_request_error", 400
                        )
                    )
                # Compute block hashes for routing
                block_mm_infos = self._build_block_mm_infos(
                    len(processed.tokens),
                    processed.mm_hashes,
                    processed.image_offsets_list,
                )
                logger.debug(f"block_mm_infos: {block_mm_infos}")
                local_hashes = compute_block_hash_for_seq(
                    processed.tokens, self.init_params.block_size, block_mm_infos
                )
                # Debug dump
                dump_api_debug(
                    tokens=processed.tokens,
                    block_size=self.init_params.block_size,
                    local_hashes=local_hashes,
                    mm_hashes=processed.mm_hashes,
                    block_mm_infos=block_mm_infos,
                    image_urls=parsed.image_urls,
                )
                # Route to best worker
                worker_id = await self._route_request(
                    local_hashes, len(processed.tokens)
                )
                if isinstance(worker_id, ErrorResponse):
                    return worker_id
                # Generate response
                request_id = f"chatcmpl-{uuid.uuid4()}"
                sampling_params = {
                    "max_tokens": parsed.max_tokens,
                    "temperature": parsed.temperature,
                    "top_p": parsed.top_p,
                }
                prompt_input = processed.mm_input or processed.tokens
                logger.debug(f"Sending to worker {worker_id}")
                result_generator = self.workers.direct(
                    prompt_input, worker_id, sampling_params
                )
                if request.stream:
                    return StreamingResponse(
                        self._stream_response(request, result_generator, request_id),
                        media_type="text/event-stream",
                        headers={
                            "Cache-Control": "no-cache",
                            "Connection": "keep-alive",
                        },
                    )
                else:
                    # Non-streaming: collect all outputs and return complete response
                    response_data = await self._generate_full_response(
                        request, result_generator, request_id
                    )
                    return JSONResponse(content=response_data)
            except Exception as e:
                logger.error(f"Request processing error: {e}")
                return ErrorResponse(error=make_error(str(e), "internal_error", 500))
    # -------------------------------------------------------------------------
    # Lifecycle Management
    # -------------------------------------------------------------------------
    async def initialize_services(self):
        """Initialize workers, HTTP client, and tokenizer."""
        logger.info(
            f"Initializing services: model={self.init_params.model}, "
            f"workers={self.init_params.num_workers}, block_size={self.init_params.block_size}"
        )
        self.workers = TrtllmWorkers(
            model=self.init_params.model,
            block_size=self.init_params.block_size,
            base_kv_events_port=self.init_params.base_kv_events_port,
            base_metrics_port=self.init_params.base_metrics_port,
            num_workers=self.init_params.num_workers,
        )
        await self.workers.start_all()
        self.http_client = httpx.AsyncClient()
        self.tokenizer = tokenizer_factory(self.init_params.model)
        try:
            self.processor = AutoProcessor.from_pretrained(
                self.init_params.model, trust_remote_code=True
            )
        except Exception as e:
            logger.warning(f"Failed to initialize HF processor: {e}")
            self.processor = None
        await asyncio.sleep(2)
        logger.info("All services initialized")

    async def start(self):
        """Start the API server."""
        await self.initialize_services()
        logger.info(f"Starting API server on port {self.init_params.http_port}")
        config = uvicorn.Config(
            self.app, host="0.0.0.0", port=self.init_params.http_port, log_level="info"
        )
        server = uvicorn.Server(config)
        await server.serve()

    async def shutdown(self):
        """Proper shutdown handler."""
        logger.info("Shutting down API...")
        if self.http_client:
            await self.http_client.aclose()
        if self.workers:
            self.workers.shutdown_all()
        logger.info("API shutdown completed")
def main():
    parser = argparse.ArgumentParser(description="TensorRT-LLM Router API Server")
    parser.add_argument(
        "--model",
        type=str,
        default="Qwen/Qwen2-VL-2B-Instruct",
        help="Model name to use (VLM for multimodal support)",
    )
    parser.add_argument(
        "--model-type",
        type=str,
        default="qwen2_vl",
        help="Model type for TRTLLM (e.g., qwen2_vl, llava, phi3_v)",
    )
    parser.add_argument(
        "--block-size",
        type=int,
        default=32,
        help="Block size for caching (TensorRT-LLM uses 32)",
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="Number of worker processes"
    )
    parser.add_argument(
        "--base-kv-events-port", type=int, default=5557, help="Base port for KV events"
    )
    parser.add_argument(
        "--base-metrics-port", type=int, default=5657, help="Base port for metrics"
    )
    parser.add_argument(
        "--router-port", type=int, default=7000, help="Port for router service"
    )
    parser.add_argument(
        "--http-port", type=int, default=8000, help="Port to serve the API on"
    )
    args = parser.parse_args()
    logging.basicConfig(level=logging.INFO)
    init_params = ServingParams(
        model=args.model,
        model_type=args.model_type,
        block_size=args.block_size,
        num_workers=args.num_workers,
        base_kv_events_port=args.base_kv_events_port,
        base_metrics_port=args.base_metrics_port,
        router_port=args.router_port,
        http_port=args.http_port,
    )
    api = ServiceAPI(init_params=init_params)
    router_api = RouterAPI(
        block_size=args.block_size,
        num_workers=args.num_workers,
        base_kv_events_port=args.base_kv_events_port,
        base_metrics_port=args.base_metrics_port,
        port=args.router_port,
    )

    async def run_with_shutdown():
        try:
            router_task = asyncio.create_task(router_api.start())
            await asyncio.sleep(0.5)
            api_task = asyncio.create_task(api.start())
            await asyncio.gather(router_task, api_task)
        except KeyboardInterrupt:
            logger.info("Shutting down services...")
        except Exception as e:
            logger.exception(f"Unhandled exception: {e}")
        finally:
            await api.shutdown()

    try:
        asyncio.run(run_with_shutdown())
    except KeyboardInterrupt:
        logger.info("Force shutdown via KeyboardInterrupt.")


if __name__ == "__main__":
    main()
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Simple health check - sends a basic chat request
# Model name should match what you started api.py with
curl -s -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "Qwen/Qwen2-VL-2B-Instruct",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": false,
"max_tokens": 50
}' | jq
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import argparse
import asyncio
import json
import logging
import os
from contextlib import asynccontextmanager
import numpy as np
import uvicorn
import zmq
import zmq.asyncio
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, ValidationError
from dynamo._core import RadixTree
logger = logging.getLogger(__name__)
DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1"
def dump_kv_event(worker_id: int, event: dict):
"""Dump KV event to file for debugging (only when DYNAMO_DEBUG=1)."""
if not DEBUG_ENABLED:
return
import datetime
with open("/tmp/debug_kv_events.txt", "a") as f:
f.write(f"\n{'='*60}\n")
f.write(f"Timestamp: {datetime.datetime.now()}\n")
f.write(f"Worker ID: {worker_id}\n")
f.write(f"Event: {json.dumps(event, indent=2)}\n")
# -----------------------------------------------------------------------------
# Request/Response Models
# -----------------------------------------------------------------------------
class RouterRequest(BaseModel):
local_hashes: list[int]
num_tokens: int
class RouterResponse(BaseModel):
worker_id: int
overlap: float = 0.0
matched_blocks: int = 0
class InjectEventRequest(BaseModel):
"""For testing: inject a KV event directly into RadixTree."""
worker_id: int
tokens_hash: int
block_hash: int | None = None
mm_extra_info: dict | None = None
class LoadMetrics(BaseModel):
kv_cache_usage: float
num_waiting_reqs: int
# -----------------------------------------------------------------------------
# ZMQ Helpers
# -----------------------------------------------------------------------------
def create_zmq_subscriber(context: zmq.Context, endpoint: str) -> zmq.Socket[bytes]:
"""Create a ZMQ SUB socket with standard settings."""
socket = context.socket(zmq.SUB)
# CONFLATE only applies to connections made after it is set, so set it before connect().
socket.setsockopt(zmq.CONFLATE, 1)
socket.connect(endpoint)
socket.setsockopt(zmq.SUBSCRIBE, b"")
socket.setsockopt(zmq.RCVTIMEO, 1)
return socket
# -----------------------------------------------------------------------------
# KvRouter Core
# -----------------------------------------------------------------------------
class KvRouter:
"""Router that uses RadixTree for KV cache-aware worker selection."""
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
):
self.num_workers = num_workers
self.block_size = block_size
self.radix_tree = RadixTree()
# Per-worker metrics
self.kv_usages = [0.0] * num_workers
self.waitings = [0] * num_workers
# ZMQ setup
self.context = zmq.Context()
self.load_listeners = [
create_zmq_subscriber(
self.context, f"tcp://localhost:{base_metrics_port + i}"
)
for i in range(num_workers)
]
self.async_context = zmq.asyncio.Context()
self.kv_listeners = [
self._create_kv_listener(base_kv_events_port + i)
for i in range(num_workers)
]
self.background_tasks: list[asyncio.Task] = []
logger.info("Router initialized")
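The constructor above gives each worker its own pair of endpoints by adding the worker index to a base port. A minimal sketch of that layout follows; `worker_endpoints` and its `host` parameter are hypothetical names used only for illustration and are not part of the example.

```python
def worker_endpoints(
    num_workers: int,
    base_kv_events_port: int = 5557,
    base_metrics_port: int = 5657,
    host: str = "localhost",  # assumption: workers run on the same machine
) -> dict[int, tuple[str, str]]:
    """Return {worker_id: (kv_events_endpoint, metrics_endpoint)}."""
    return {
        i: (
            f"tcp://{host}:{base_kv_events_port + i}",
            f"tcp://{host}:{base_metrics_port + i}",
        )
        for i in range(num_workers)
    }


for wid, (kv_endpoint, metrics_endpoint) in worker_endpoints(2).items():
    print(wid, kv_endpoint, metrics_endpoint)
```

With the default base ports, worker 1 would publish KV events on port 5558 and metrics on port 5658, so the two port ranges must not overlap once `num_workers` is added.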
def _create_kv_listener(self, port: int) -> zmq.asyncio.Socket:
"""Create an async ZMQ SUB socket for receiving KV cache events."""
sock = self.async_context.socket(zmq.SUB)
sock.connect(f"tcp://localhost:{port}")
sock.setsockopt(zmq.SUBSCRIBE, b"")
sock.setsockopt(zmq.RCVTIMEO, 1)
return sock
# -------------------------------------------------------------------------
# Background Tasks
# -------------------------------------------------------------------------
async def start_background_tasks(self):
"""Start background tasks for load and tree updates."""
logger.info("Starting router background tasks...")
for worker_id in range(self.num_workers):
self.background_tasks.append(
asyncio.create_task(self._poll_worker_load(worker_id))
)
self.background_tasks.append(
asyncio.create_task(self._poll_worker_kv_events(worker_id))
)
async def _poll_worker_load(self, worker_id: int):
"""Poll load metrics for a single worker."""
while True:
try:
data = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
metrics = LoadMetrics.model_validate(data)
self.kv_usages[worker_id] = metrics.kv_cache_usage
self.waitings[worker_id] = metrics.num_waiting_reqs
except zmq.Again:
pass
except (zmq.ZMQError, ValidationError) as e:
logger.warning(f"Worker {worker_id} metrics error: {e}")
except Exception:
logger.exception(f"Worker {worker_id} unexpected metrics error")
await asyncio.sleep(0.1)
async def _poll_worker_kv_events(self, worker_id: int):
"""Poll KV events for a single worker and update RadixTree."""
sock = self.kv_listeners[worker_id]
while True:
try:
event_bytes = await sock.recv(zmq.NOBLOCK)
event = json.loads(event_bytes)
dump_kv_event(worker_id, event)
self.radix_tree.apply_event(
worker_id, json.dumps(event).encode("utf-8")
)
except zmq.Again:
pass
except (zmq.ZMQError, json.JSONDecodeError) as e:
logger.warning(f"Worker {worker_id} KV events error: {e}")
except Exception:
logger.exception(f"Worker {worker_id} unexpected KV events error")
await asyncio.sleep(0.1)
# -------------------------------------------------------------------------
# Worker Selection
# -------------------------------------------------------------------------
async def get_best_worker(
self, local_hashes: list[int], num_tokens: int
) -> tuple[int, float, int]:
"""
Find best worker for request.
Returns: (worker_id, overlap_ratio, matched_blocks)
"""
if num_tokens <= 0:
raise ValueError("num_tokens must be positive")
# Get cache matches from RadixTree
matched_blocks = self._get_matched_blocks(local_hashes)
# Compute overlap scores
overlap_scores = {
wid: matched_blocks[wid] * self.block_size / num_tokens
for wid in range(self.num_workers)
}
# Compute routing logits
logits = self._compute_logits(overlap_scores)
# Select best worker (random tie-breaking)
best_id = self._select_best_worker(logits)
# Predictive update for burst handling
self.waitings[best_id] += 1
return best_id, overlap_scores[best_id], matched_blocks[best_id]
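The overlap score used above is the fraction of the request's tokens already covered by cached blocks on a worker. A standalone sketch of that arithmetic, assuming the per-worker matched block counts are already known (the function name is illustrative, not part of the example):

```python
def overlap_scores(
    matched_blocks: dict[int, int], block_size: int, num_tokens: int
) -> dict[int, float]:
    """Convert matched block counts into the fraction of tokens covered."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    return {wid: n * block_size / num_tokens for wid, n in matched_blocks.items()}


# Worker 0 holds 2 matching blocks of 32 tokens for a 100-token request.
scores = overlap_scores({0: 2, 1: 0}, block_size=32, num_tokens=100)
print(scores)  # {0: 0.64, 1: 0.0}
```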
def _get_matched_blocks(self, local_hashes: list[int]) -> dict[int, int]:
"""Get matched block count per worker from RadixTree."""
result = self.radix_tree.find_matches(local_hashes)
raw_scores = result.scores
logger.info(f"Router: raw_scores={raw_scores}")
# raw_scores is keyed by (worker_id, dp_rank); assume dp_rank=0
return {wid: raw_scores.get((wid, 0), 0) for wid in range(self.num_workers)}
def _compute_logits(self, overlap_scores: dict[int, float]) -> list[float]:
"""Compute routing logits for each worker."""
max_waiting = max(self.waitings) if self.waitings else 0
logits = []
for wid in range(self.num_workers):
overlap = overlap_scores[wid]
usage = self.kv_usages[wid]
waiting_norm = self.waitings[wid] / max_waiting if max_waiting else 0.0
logit = 2 * overlap - usage - waiting_norm
logits.append(logit)
logger.info(
f"worker_id: {wid}, logit = 2 * {overlap:.3f} - {usage:.3f} - {waiting_norm:.3f} = {logit:.3f}"
)
return logits
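The scoring rule above rewards cache overlap and penalizes KV cache usage and queue depth. A self-contained sketch of the same formula, with plain lists standing in for the router's per-worker state (an assumption made so it runs without the router):

```python
def compute_logits(
    overlaps: list[float], kv_usages: list[float], waitings: list[int]
) -> list[float]:
    """logit = 2 * overlap - kv_usage - waiting / max(waiting)."""
    max_waiting = max(waitings) if waitings else 0
    logits = []
    for overlap, usage, waiting in zip(overlaps, kv_usages, waitings):
        # Normalize the queue depth so the busiest worker gets a penalty of 1.0.
        waiting_norm = waiting / max_waiting if max_waiting else 0.0
        logits.append(2 * overlap - usage - waiting_norm)
    return logits


# Worker 0 has half the prompt cached but the longer queue.
print(compute_logits(overlaps=[0.5, 0.0], kv_usages=[0.25, 0.25], waitings=[4, 2]))
# [-0.25, -0.75]
```

In this example worker 0 still wins: its overlap bonus (1.0) outweighs the extra 0.5 queue penalty it pays relative to worker 1.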
def _select_best_worker(self, logits: list[float]) -> int:
"""Select worker with highest logit (random tie-breaking)."""
arr = np.array(logits)
return int(np.random.choice(np.flatnonzero(arr == arr.max())))
# -------------------------------------------------------------------------
# Shutdown
# -------------------------------------------------------------------------
async def shutdown(self):
"""Shutdown ZMQ listeners and background tasks."""
logger.info("Shutting down KvRouter...")
for task in self.background_tasks:
task.cancel()
if self.background_tasks:
await asyncio.gather(*self.background_tasks, return_exceptions=True)
for listener in self.load_listeners:
listener.close()
for listener in self.kv_listeners:
listener.close()
self.context.term()
self.async_context.term()
logger.info("KvRouter shutdown completed")
# -----------------------------------------------------------------------------
# Router API Server
# -----------------------------------------------------------------------------
class RouterAPI:
"""FastAPI wrapper for KvRouter."""
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
port: int = 7000,
):
self.port = port
self.router_config = {
"block_size": block_size,
"num_workers": num_workers,
"base_kv_events_port": base_kv_events_port,
"base_metrics_port": base_metrics_port,
}
self.router: KvRouter | None = None
self.app = FastAPI(
title="KV Router API", version="0.0.1", lifespan=self.lifespan
)
self._setup_routes()
def _require_router(self) -> KvRouter:
"""Get router or raise 503 if not initialized."""
if self.router is None:
raise HTTPException(status_code=503, detail="Router not initialized")
return self.router
@asynccontextmanager
async def lifespan(self, app: FastAPI):
self.router = KvRouter(**self.router_config)
await self.router.start_background_tasks()
logger.info("Router API started")
yield
if self.router:
await self.router.shutdown()
def _setup_routes(self):
@self.app.post("/find_best_worker", response_model=RouterResponse)
async def find_best_worker(request: RouterRequest):
router = self._require_router()
try:
wid, overlap, matched = await router.get_best_worker(
request.local_hashes, request.num_tokens
)
return RouterResponse(
worker_id=wid, overlap=overlap, matched_blocks=matched
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@self.app.get("/debug/tree_info")
async def get_tree_info():
router = self._require_router()
events = router.radix_tree.dump_tree_as_events()
return {"num_blocks": len(events), "events": events[:20]}
@self.app.post("/debug/inject_event")
async def inject_event(request: InjectEventRequest):
router = self._require_router()
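# Compare against None so that an explicit block_hash of 0 is not replaced.
block_hash = request.tokens_hash if request.block_hash is None else request.block_hash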
event = {
"event_id": 99999,
"data": {
"stored": {
"parent_hash": None,
"blocks": [
{
"block_hash": block_hash,
"tokens_hash": request.tokens_hash,
"mm_extra_info": request.mm_extra_info,
}
],
}
},
}
router.radix_tree.apply_event(
request.worker_id, json.dumps(event).encode("utf-8")
)
return {
"status": "ok",
"tokens_hash": request.tokens_hash,
"worker_id": request.worker_id,
}
async def start(self):
"""Start the router API server."""
logger.info(f"Starting Router API on port {self.port}")
config = uvicorn.Config(
self.app, host="0.0.0.0", port=self.port, log_level="info"
)
await uvicorn.Server(config).serve()
def main():
parser = argparse.ArgumentParser(description="KV Router API Server")
parser.add_argument(
"--block-size", type=int, default=32, help="Block size (default: 32)"
)
parser.add_argument("--num-workers", type=int, default=2, help="Number of workers")
parser.add_argument(
"--base-kv-events-port", type=int, default=5557, help="Base KV events port"
)
parser.add_argument(
"--base-metrics-port", type=int, default=5657, help="Base metrics port"
)
parser.add_argument("--port", type=int, default=7000, help="Router API port")
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
api = RouterAPI(
block_size=args.block_size,
num_workers=args.num_workers,
base_kv_events_port=args.base_kv_events_port,
base_metrics_port=args.base_metrics_port,
port=args.port,
)
asyncio.run(api.start())
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Test suite for TensorRT-LLM KV Router.
Usage:
python test_router.py # Run text-only tests (requires server)
python test_router.py --verbose # Show detailed logs
python test_router.py --mm-only # Run multimodal hash tests (no server needed)
python test_router.py --mm-server # Run multimodal server tests (requires VLM)
python test_router.py --all # Run all tests
"""
import argparse
import sys
import time
from dataclasses import dataclass
import httpx
from dynamo.llm import compute_block_hash_for_seq
# Sample test images from COCO dataset
TEST_IMAGE_1 = "http://images.cocodataset.org/test2017/000000155781.jpg"
TEST_IMAGE_2 = "http://images.cocodataset.org/test2017/000000000001.jpg"
TEST_IMAGE_3 = "http://images.cocodataset.org/test2017/000000155721.jpg"
TEST_IMAGE_4 = "https://huggingface.co/datasets/Sayali9141/traffic_signal_images/resolve/main/61.jpg"
@dataclass
class RouterTestConfig:
api_url: str = "http://localhost:8000"
router_url: str = "http://localhost:7000"
timeout: int = 30
kv_settle_time: float = 3.0 # Time to wait for KV events to propagate
@dataclass
class RouterTestResult:
name: str
passed: bool
message: str
overlap: float = 0.0
def make_request(content: str, max_tokens: int = 10) -> dict:
"""Create a text-only chat completion request."""
return {
"model": "test",
"messages": [{"role": "user", "content": content}],
"stream": True,
"max_tokens": max_tokens,
}
def make_mm_request(text: str, image_url: str, max_tokens: int = 10) -> dict:
"""Create a multimodal chat completion request with image."""
return {
"model": "test",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": text},
{"type": "image_url", "image_url": {"url": image_url}},
],
}
],
"stream": True,
"max_tokens": max_tokens,
}
def make_multi_image_request(
text: str, image_urls: list[str], max_tokens: int = 10
) -> dict:
"""Create a multimodal chat completion request with multiple images."""
content: list[dict] = [{"type": "text", "text": text}]
for url in image_urls:
content.append({"type": "image_url", "image_url": {"url": url}})
return {
"model": "test",
"messages": [{"role": "user", "content": content}],
"stream": True,
"max_tokens": max_tokens,
}
def send_request(client: httpx.Client, url: str, payload: dict) -> bool:
"""Send a chat completion request and consume the stream."""
try:
resp = client.post(f"{url}/v1/chat/completions", json=payload)
if resp.status_code != 200:
return False
for _ in resp.iter_lines():
pass
return True
except Exception:
return False
def get_tree_info(client: httpx.Client, url: str) -> dict:
"""Get radix tree debug info."""
try:
resp = client.get(f"{url}/debug/tree_info")
return resp.json()
except Exception:
return {"num_blocks": -1, "events": []}
class KvRouterTests:
"""Test cases for KV cache routing."""
def __init__(self, config: RouterTestConfig, verbose: bool = False):
self.config = config
self.verbose = verbose
self.client = httpx.Client(timeout=config.timeout)
self.results: list[RouterTestResult] = []
# Test messages designed for block_size=32
# "Are you ok? Hello! Thank you! Thank you very much! " is ~12 tokens
# Chat template adds ~4 tokens
self.base_phrase = "Are you ok? Hello! Thank you! Thank you very much! "
def log(self, msg: str):
if self.verbose:
print(f" {msg}")
def run_all(self) -> bool:
"""Run all test cases."""
print("\nKV Router Test Suite")
print("=" * 50)
# Check server connectivity first
if not self._check_servers():
print("\nFATAL: Cannot connect to servers")
return False
# Run test cases
self._test_full_match()
self._test_partial_match()
self._test_no_match()
# Print summary
return self._print_summary()
def run_mm_tests(self) -> bool:
"""Run multimodal tests (local hash computation, no server needed)."""
print("\nMultimodal KV Router Tests (Local)")
print("=" * 50)
print("(These tests verify hash computation without server)")
self._test_mm_hash_computation()
self._test_mm_routing_distinction()
self._test_mm_hash_consistency()
self._test_mm_offset_affects_hash()
self._test_mm_block_boundary()
self._test_mm_multi_image_partial_match()
return self._print_summary()
def run_mm_server_tests(self) -> bool:
"""Run multimodal tests that require server."""
print("\nMultimodal KV Router Tests (Server)")
print("=" * 50)
if not self._check_servers():
print("\nFATAL: Cannot connect to servers")
return False
self._test_mm_same_image_cache_hit()
self._test_mm_different_images_no_cache_hit()
self._test_text_cache_hit_with_overlap()
self._test_mm_multi_image_partial_match()
return self._print_summary()
def _check_servers(self) -> bool:
"""Verify both API and Router servers are reachable."""
print("\nChecking server connectivity...")
try:
# Check router
resp = self.client.get(f"{self.config.router_url}/debug/tree_info")
if resp.status_code != 200:
print(f" Router not responding: {resp.status_code}")
return False
print(f" Router OK (blocks in tree: {resp.json().get('num_blocks', '?')})")
# The API server is not probed here; the first test request reports a
# failure if it is unreachable.
return True
except Exception as e:
print(f" Connection error: {e}")
return False
def _test_full_match(self):
"""
Test: Send identical request twice.
Expected: Second request should have overlap > 0.
"""
print("\n[1] Full Match Test")
print(" Sending same request twice, expecting cache hit on second...")
# Create a request with enough tokens for multiple full blocks
# 5 repetitions ≈ 64 tokens ≈ 2 full blocks
content = (self.base_phrase * 5).strip()
payload = make_request(content)
# Get initial state
initial = get_tree_info(self.client, self.config.router_url)
initial_blocks = initial["num_blocks"]
self.log(f"Initial blocks: {initial_blocks}")
# First request - should populate cache (or hit existing cache)
self.log("Sending first request...")
if not send_request(self.client, self.config.api_url, payload):
self.results.append(
RouterTestResult("full_match", False, "First request failed")
)
return
# Wait for KV events
self.log(f"Waiting {self.config.kv_settle_time}s for KV events...")
time.sleep(self.config.kv_settle_time)
# Check blocks after first request
after_first = get_tree_info(self.client, self.config.router_url)
blocks_added = after_first["num_blocks"] - initial_blocks
self.log(
f"Blocks after first: {after_first['num_blocks']} (added {blocks_added})"
)
# Second request - should hit cache
self.log("Sending second request (should hit cache)...")
if not send_request(self.client, self.config.api_url, payload):
self.results.append(
RouterTestResult("full_match", False, "Second request failed")
)
return
# Success: either new blocks were added, or blocks already existed (from previous runs)
# Either way, the second request should show overlap > 0 in server logs
total_blocks = after_first["num_blocks"]
self.results.append(
RouterTestResult(
"full_match",
True,
f"OK - Tree has {total_blocks} blocks. Check server logs for 'overlap > 0'.",
)
)
def _test_partial_match(self):
"""
Test: Send request A, then request B that shares same prefix but is longer.
Expected: Request B should have partial overlap (matching the shared prefix blocks).
"""
print("\n[2] Partial Match Test")
print(" Request B shares prefix with cached request A...")
# Request A: 5 repetitions (~64 tokens, ~2 full blocks)
content_a = (self.base_phrase * 5).strip()
# Request B: 8 repetitions (~100 tokens, ~3 full blocks)
# First 2 blocks should match A, third block is new
content_b = (self.base_phrase * 8).strip()
payload_a = make_request(content_a)
payload_b = make_request(content_b)
# Ensure A is cached (might already be from previous test)
self.log("Ensuring request A is cached...")
send_request(self.client, self.config.api_url, payload_a)
time.sleep(self.config.kv_settle_time)
before = get_tree_info(self.client, self.config.router_url)
self.log(f"Blocks before B: {before['num_blocks']}")
# Send request B
self.log("Sending request B (longer, shares prefix)...")
if not send_request(self.client, self.config.api_url, payload_b):
self.results.append(
RouterTestResult("partial_match", False, "Request B failed")
)
return
time.sleep(self.config.kv_settle_time)
after = get_tree_info(self.client, self.config.router_url)
new_blocks = after["num_blocks"] - before["num_blocks"]
self.log(f"New blocks from B: {new_blocks}")
# B should add new blocks (the non-matching suffix)
# The matching prefix blocks already exist
self.results.append(
RouterTestResult(
"partial_match",
True,
f"OK - Request B added {new_blocks} new blocks. "
f"Check server logs for partial overlap (0 < overlap < 1).",
)
)
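The block counts quoted in the partial match test follow from integer division by the block size. A small sketch of that arithmetic, assuming request A's full blocks form an exact prefix of request B's blocks (the helper name is illustrative):

```python
def expected_prefix_overlap(
    tokens_a: int, tokens_b: int, block_size: int = 32
) -> tuple[int, float]:
    """Return (shared_full_blocks, expected_overlap_for_b).

    Only full blocks are cached, so a trailing partial block never matches.
    """
    full_a = tokens_a // block_size
    full_b = tokens_b // block_size
    shared = min(full_a, full_b)
    return shared, shared * block_size / tokens_b


# Request A: ~64 tokens (2 full blocks); request B: ~100 tokens (3 full blocks).
shared, overlap = expected_prefix_overlap(64, 100)
print(shared, overlap)  # 2 0.64
```

Under these assumptions request B should report an overlap of roughly 0.64 and add one new full block to the tree.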
def _test_no_match(self):
"""
Test: Send completely different content.
Expected: No cache hit (overlap = 0).
"""
print("\n[3] No Match Test")
print(" Sending completely different content...")
# Content that's very different from previous tests
# ~80 tokens, sharing no prefix with self.base_phrase used in the earlier tests
content = (
"The quick brown fox jumps over the lazy dog. "
"Pack my box with five dozen liquor jugs. "
"How vexingly quick daft zebras jump. "
"The five boxing wizards jump quickly. "
"Sphinx of black quartz, judge my vow."
)
payload = make_request(content)
before = get_tree_info(self.client, self.config.router_url)
self.log(f"Blocks before: {before['num_blocks']}")
# Send the different request
self.log("Sending unrelated request...")
if not send_request(self.client, self.config.api_url, payload):
self.results.append(RouterTestResult("no_match", False, "Request failed"))
return
# No need to wait - we're checking overlap on this request, not the next
self.results.append(
RouterTestResult(
"no_match",
True,
"OK - Check server logs for 'overlap = 0.000' (no cache hit expected).",
)
)
def _test_mm_hash_computation(self):
"""
Test: Verify that compute_block_hash_for_seq produces different hashes
for same tokens with different mm_hash values.
"""
print("\n[MM-1] MM Hash Computation Test")
print(" Verifying same tokens + different mm_hash = different block_hash...")
# Simulated tokens (32 tokens = 1 block)
tokens = [100] * 32
block_size = 32
# Hash without MM info
hash_no_mm = compute_block_hash_for_seq(tokens, block_size)
# Hash with MM info (simulated mm_hash)
mm_info_1 = {"mm_objects": [{"mm_hash": 0xDEADBEEF, "offsets": [[0, 32]]}]}
hash_with_mm1 = compute_block_hash_for_seq(tokens, block_size, [mm_info_1])
# Hash with different MM info
mm_info_2 = {"mm_objects": [{"mm_hash": 0xCAFEBABE, "offsets": [[0, 32]]}]}
hash_with_mm2 = compute_block_hash_for_seq(tokens, block_size, [mm_info_2])
self.log(f"Hash without MM: {hash_no_mm}")
self.log(f"Hash with MM 1: {hash_with_mm1}")
self.log(f"Hash with MM 2: {hash_with_mm2}")
# Verify all hashes are different
if hash_no_mm == hash_with_mm1:
self.results.append(
RouterTestResult(
"mm_hash_computation",
False,
"FAIL - Hash without MM equals hash with MM",
)
)
return
if hash_with_mm1 == hash_with_mm2:
self.results.append(
RouterTestResult(
"mm_hash_computation",
False,
"FAIL - Different mm_hash produced same block_hash",
)
)
return
self.results.append(
RouterTestResult(
"mm_hash_computation",
True,
"OK - Different mm_hash values produce different block hashes",
)
)
def _test_mm_routing_distinction(self):
"""
Test: Verify that the routing logic can distinguish between
requests with same text but different images.
"""
print("\n[MM-2] MM Routing Distinction Test")
print(" Verifying routing can distinguish same text + different images...")
# This test simulates what the router would see
tokens = [100] * 64 # 2 blocks
block_size = 32
# Simulate Image A cached on worker 0
mm_info_a = {
"mm_objects": [{"mm_hash": 0x1111111111111111, "offsets": [[0, 64]]}]
}
hashes_a = compute_block_hash_for_seq(
tokens, block_size, [mm_info_a, mm_info_a]
)
# Simulate Image B cached on worker 1
mm_info_b = {
"mm_objects": [{"mm_hash": 0x2222222222222222, "offsets": [[0, 64]]}]
}
hashes_b = compute_block_hash_for_seq(
tokens, block_size, [mm_info_b, mm_info_b]
)
self.log(f"Hashes for Image A: {hashes_a}")
self.log(f"Hashes for Image B: {hashes_b}")
# Verify hashes are different
if hashes_a == hashes_b:
self.results.append(
RouterTestResult(
"mm_routing_distinction",
False,
"FAIL - Same tokens with different images produced same hashes",
)
)
return
self.results.append(
RouterTestResult(
"mm_routing_distinction",
True,
"OK - Router can distinguish requests with different images",
)
)
def _test_mm_hash_consistency(self):
"""
Test: Verify that the same mm_hash + tokens produce the same block_hash
regardless of when computed (idempotency).
"""
print("\n[MM-3] MM Hash Consistency Test")
print(" Verifying same inputs produce same hash (idempotent)...")
tokens = [151937] * 32 # Image token placeholder
block_size = 32
mm_hash = 0xDEADBEEFCAFEBABE
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
# Compute hash multiple times
hash1 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash2 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
hash3 = compute_block_hash_for_seq(tokens, block_size, [mm_info])
self.log(f"Hash 1: {hash1}")
self.log(f"Hash 2: {hash2}")
self.log(f"Hash 3: {hash3}")
if hash1 != hash2 or hash2 != hash3:
self.results.append(
RouterTestResult(
"mm_hash_consistency",
False,
f"FAIL - Same inputs produced different hashes: {hash1}, {hash2}, {hash3}",
)
)
return
self.results.append(
RouterTestResult(
"mm_hash_consistency",
True,
f"OK - Hash computation is idempotent: {hash1[0]}",
)
)
def _test_mm_offset_affects_hash(self):
"""
Test: Verify that different offsets produce different hashes,
even with same mm_hash and tokens.
"""
print("\n[MM-4] MM Offset Affects Hash Test")
print(" Verifying different offsets produce different hashes...")
tokens = [151937] * 64 # 2 blocks of image tokens
block_size = 32
mm_hash = 0x123456789ABCDEF0
# Image covers first block only
mm_info_first = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 32]]}]}
hash_first = compute_block_hash_for_seq(
tokens, block_size, [mm_info_first, None]
)
# Image covers second block only
mm_info_second = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hash_second = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info_second]
)
# Image covers both blocks
mm_info_both = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[0, 64]]}]}
hash_both = compute_block_hash_for_seq(
tokens, block_size, [mm_info_both, mm_info_both]
)
self.log(f"Hash (first block MM): {hash_first}")
self.log(f"Hash (second block MM): {hash_second}")
self.log(f"Hash (both blocks MM): {hash_both}")
# Block 0 with mm_info should differ from block 0 without mm_info
# Block 1 with mm_info should differ from block 1 without mm_info
if hash_first[0] == hash_second[0]:
self.results.append(
RouterTestResult(
"mm_offset_affects_hash",
False,
"FAIL - First block hash should differ based on MM presence",
)
)
return
self.results.append(
RouterTestResult(
"mm_offset_affects_hash",
True,
"OK - Different MM offsets produce different block hashes",
)
)
def _test_mm_block_boundary(self):
"""
Test: Verify that MM info correctly applies at block boundaries.
"""
print("\n[MM-5] MM Block Boundary Test")
print(" Verifying MM info applies correctly at block boundaries...")
block_size = 32
mm_hash = 0xFEDCBA9876543210
# 96 tokens = 3 blocks
# Image tokens in the middle block (32-64)
tokens = [100] * 32 + [151937] * 32 + [200] * 32
# MM info only applies to middle block
mm_info = {"mm_objects": [{"mm_hash": mm_hash, "offsets": [[32, 64]]}]}
hashes_with_mm = compute_block_hash_for_seq(
tokens, block_size, [None, mm_info, None]
)
# No MM info
hashes_without_mm = compute_block_hash_for_seq(tokens, block_size, None)
self.log(f"Hashes with MM: {hashes_with_mm}")
self.log(f"Hashes without MM: {hashes_without_mm}")
# Block 0 and 2 should be the same (no image tokens)
# Block 1 should be different (has image tokens + mm_hash)
if hashes_with_mm[0] != hashes_without_mm[0]:
self.results.append(
RouterTestResult(
"mm_block_boundary", False, "FAIL - Block 0 should be same (no MM)"
)
)
return
if hashes_with_mm[1] == hashes_without_mm[1]:
self.results.append(
RouterTestResult(
"mm_block_boundary", False, "FAIL - Block 1 should differ (has MM)"
)
)
return
if hashes_with_mm[2] != hashes_without_mm[2]:
self.results.append(
RouterTestResult(
"mm_block_boundary", False, "FAIL - Block 2 should be same (no MM)"
)
)
return
self.results.append(
RouterTestResult(
"mm_block_boundary",
True,
"OK - MM info correctly applies only to relevant blocks",
)
)
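The property checked above is that multimodal metadata changes only the hashes of the blocks it is attached to. The toy hasher below illustrates that property with SHA-256; it is not the algorithm behind `compute_block_hash_for_seq`, it ignores the `offsets` field, and `toy_block_hashes` is a hypothetical name.

```python
import hashlib
import struct


def toy_block_hashes(tokens, block_size, block_mm_infos=None):
    """Hash each full block; mix in mm_hash values for blocks that carry MM info."""
    hashes = []
    for b in range(len(tokens) // block_size):
        block = tokens[b * block_size : (b + 1) * block_size]
        h = hashlib.sha256(struct.pack(f"<{len(block)}I", *block))
        info = block_mm_infos[b] if block_mm_infos else None
        if info:
            for obj in info["mm_objects"]:
                h.update(struct.pack("<Q", obj["mm_hash"]))
        hashes.append(int.from_bytes(h.digest()[:8], "little"))
    return hashes


tokens = [100] * 32 + [151937] * 32 + [200] * 32
mm_info = {"mm_objects": [{"mm_hash": 0xFEDCBA9876543210, "offsets": [[32, 64]]}]}
with_mm = toy_block_hashes(tokens, 32, [None, mm_info, None])
without_mm = toy_block_hashes(tokens, 32)
print(with_mm[0] == without_mm[0], with_mm[1] != without_mm[1], with_mm[2] == without_mm[2])
```

Because each block is hashed independently here, text-only blocks on either side of the image keep their hashes and can still match cached entries.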
def _test_mm_same_image_cache_hit(self):
"""
Test: Send same text + same image twice.
Expected: Second request should have cache hit (overlap > 0).
"""
print("\n[MM-S1] Same Image Cache Hit Test")
print(" Sending same text + same image twice...")
payload = make_mm_request("Describe this image", TEST_IMAGE_1)
# Get initial state
initial = get_tree_info(self.client, self.config.router_url)
self.log(f"Initial blocks: {initial['num_blocks']}")
# First request - populates the cache
self.log("Sending first MM request...")
if not send_request(self.client, self.config.api_url, payload):
self.results.append(
RouterTestResult("mm_same_image", False, "First MM request failed")
)
return
# Wait for KV events to propagate
self.log(f"Waiting {self.config.kv_settle_time}s for KV events...")
time.sleep(self.config.kv_settle_time)
after_first = get_tree_info(self.client, self.config.router_url)
blocks_added = after_first["num_blocks"] - initial["num_blocks"]
self.log(
f"Blocks after first: {after_first['num_blocks']} (added {blocks_added})"
)
if blocks_added == 0:
self.results.append(
RouterTestResult(
"mm_same_image", False, "FAIL - No blocks added after first request"
)
)
return
# Second identical request - should hit cache
self.log("Sending second MM request (same image)...")
if not send_request(self.client, self.config.api_url, payload):
self.results.append(
RouterTestResult("mm_same_image", False, "Second MM request failed")
)
return
# Query router to check overlap (simulating what the second request saw)
# We need to compute the same hashes that the API computed
# For now, only record whether the tree grew or stayed the same (cache reuse)
after_second = get_tree_info(self.client, self.config.router_url)
self.log(f"Blocks after second: {after_second['num_blocks']}")
# The second request should reuse cached blocks, so it should add few or no new blocks
new_blocks_second = after_second["num_blocks"] - after_first["num_blocks"]
self.log(f"New blocks from second request: {new_blocks_second}")
self.results.append(
RouterTestResult(
"mm_same_image",
True,
f"OK - First added {blocks_added} blocks, second added {new_blocks_second}. "
f"Check logs for 'overlap > 0' on second request.",
)
)

    def _test_mm_different_images_no_cache_hit(self):
        """
        Test: Send same text but different images.

        Expected: No cache hit (overlap ≈ 0) because mm_hash differs.
        Image blocks should not match, only text prefix might match.
        """
        print("\n[MM-S2] Different Images No Cache Hit Test")
        print(" Sending same text + different images...")

        # First image
        payload_1 = make_mm_request("Describe this image in detail", TEST_IMAGE_2)
        initial = get_tree_info(self.client, self.config.router_url)
        self.log(f"Initial blocks: {initial['num_blocks']}")
        self.log(f"Sending request with image 1: {TEST_IMAGE_2}")
        if not send_request(self.client, self.config.api_url, payload_1):
            self.results.append(
                RouterTestResult("mm_different_images", False, "Image 1 request failed")
            )
            return
        time.sleep(self.config.kv_settle_time)
        after_img1 = get_tree_info(self.client, self.config.router_url)
        blocks_img1 = after_img1["num_blocks"] - initial["num_blocks"]
        self.log(
            f"Blocks after image 1: {after_img1['num_blocks']} (added {blocks_img1})"
        )

        # Second image (same text, different image)
        payload_2 = make_mm_request("Describe this image in detail", TEST_IMAGE_3)
        self.log(f"Sending request with image 2: {TEST_IMAGE_3}")
        if not send_request(self.client, self.config.api_url, payload_2):
            self.results.append(
                RouterTestResult("mm_different_images", False, "Image 2 request failed")
            )
            return
        time.sleep(self.config.kv_settle_time)
        after_img2 = get_tree_info(self.client, self.config.router_url)
        blocks_img2 = after_img2["num_blocks"] - after_img1["num_blocks"]
        self.log(
            f"Blocks after image 2: {after_img2['num_blocks']} (added {blocks_img2})"
        )

        # Different images should add similar number of blocks
        # If image 2 had cache hit, it would add fewer blocks
        if blocks_img2 == 0:
            self.results.append(
                RouterTestResult(
                    "mm_different_images",
                    False,
                    "FAIL - Image 2 added 0 blocks (unexpected full cache hit)",
                )
            )
            return

        # Image 2 should add approximately same number of blocks as image 1
        # (since different mm_hash means image blocks don't match)
        self.results.append(
            RouterTestResult(
                "mm_different_images",
                True,
                f"OK - Image 1 added {blocks_img1} blocks, image 2 added {blocks_img2} blocks. "
                f"Different images = different block hashes.",
            )
        )

    def _test_text_cache_hit_with_overlap(self):
        """
        Test: Send same text request twice and verify overlap via router API.

        Expected: Second request should show overlap > 0 in router response.
        """
        print("\n[MM-S3] Text Cache Hit with Overlap Verification")
        print(" Sending same text twice and verifying overlap value...")

        # Use a unique prompt to avoid interference from other tests.
        # The adjacent literals are joined first, so the whole text repeats 3 times.
        unique_text = (
            "This is a unique test prompt for cache hit verification. "
            "We need enough tokens to fill at least one block. "
            "The quick brown fox jumps over the lazy dog repeatedly. " * 3
        )
        payload = make_request(unique_text, max_tokens=5)

        # First request
        self.log("Sending first text request...")
        if not send_request(self.client, self.config.api_url, payload):
            self.results.append(
                RouterTestResult(
                    "text_cache_hit_overlap", False, "First request failed"
                )
            )
            return

        # Wait for KV events
        self.log(f"Waiting {self.config.kv_settle_time}s for KV events...")
        time.sleep(self.config.kv_settle_time)

        # Get tree info to see blocks
        tree_info = get_tree_info(self.client, self.config.router_url)
        self.log(f"Blocks in tree: {tree_info['num_blocks']}")

        # Second request - should see cache hit
        self.log("Sending second text request (should hit cache)...")
        if not send_request(self.client, self.config.api_url, payload):
            self.results.append(
                RouterTestResult(
                    "text_cache_hit_overlap", False, "Second request failed"
                )
            )
            return

        # For a true verification, we'd need to intercept the router response
        # or add an endpoint that returns the last routing decision
        # For now, we verify by checking if blocks increased (they shouldn't much)
        tree_info_after = get_tree_info(self.client, self.config.router_url)
        new_blocks = tree_info_after["num_blocks"] - tree_info["num_blocks"]
        self.log(f"New blocks after second request: {new_blocks}")
        self.results.append(
            RouterTestResult(
                "text_cache_hit_overlap",
                True,
                f"OK - Second request added {new_blocks} new blocks. "
                f"Check logs for 'overlap > 0' (cache hit).",
            )
        )

    def _test_mm_multi_image_partial_match(self):
        """
        Test: Verify partial cache match with multi-image requests.

        Scenario:
            Step 1: Send Request A = text + [Image_1, Image_4]
            Step 2: Send Request A again (identical) - verify full cache hit (0 new blocks)
            Step 3: Send Request B = text + [Image_1, Image_3] - verify partial match
                    (Image_3 is different, should add new blocks)

        Expected:
            - Identical request = no new blocks (full cache hit)
            - Different second image = new blocks added (partial match)
        """
        print("\n[MM-S4] Multi-Image Partial Match Test")
        print(" Verifying cache behavior with multi-image requests...")

        # Use longer settle time for this test
        settle_time = self.config.kv_settle_time * 2

        # Request A: text + Image_1 + Image_4
        payload_a = make_multi_image_request(
            "Describe these images in detail", [TEST_IMAGE_1, TEST_IMAGE_4]
        )
        initial = get_tree_info(self.client, self.config.router_url)
        self.log(f"Initial blocks: {initial['num_blocks']}")

        # Step 1: Send Request A first time
        self.log("Step 1: Sending Request A (text + Image_1 + Image_4)...")
        if not send_request(self.client, self.config.api_url, payload_a):
            self.results.append(
                RouterTestResult("mm_multi_image_partial", False, "Request A failed")
            )
            return
        time.sleep(settle_time)
        after_a1 = get_tree_info(self.client, self.config.router_url)
        blocks_a1 = after_a1["num_blocks"] - initial["num_blocks"]
        self.log(
            f"Blocks after Request A: {after_a1['num_blocks']} (added {blocks_a1})"
        )
        if blocks_a1 == 0:
            self.results.append(
                RouterTestResult(
                    "mm_multi_image_partial",
                    False,
                    "FAIL - Request A added 0 blocks (should populate cache)",
                )
            )
            return

        # Step 2: Send Request A again (identical) - should be full cache hit
        self.log(
            "Step 2: Sending Request A again (identical, expect full cache hit)..."
        )
        if not send_request(self.client, self.config.api_url, payload_a):
            self.results.append(
                RouterTestResult(
                    "mm_multi_image_partial", False, "Request A (repeat) failed"
                )
            )
            return
        time.sleep(settle_time)
        after_a2 = get_tree_info(self.client, self.config.router_url)
        blocks_a2 = after_a2["num_blocks"] - after_a1["num_blocks"]
        self.log(
            f"Blocks after Request A repeat: {after_a2['num_blocks']} (added {blocks_a2})"
        )

        # Identical request should add 0 new blocks (full cache hit)
        if blocks_a2 != 0:
            self.log(
                f"WARNING: Identical request added {blocks_a2} blocks (expected 0)"
            )

        # Step 3: Send Request B with different second image
        payload_b = make_multi_image_request(
            "Describe these images in detail", [TEST_IMAGE_1, TEST_IMAGE_3]
        )
        self.log(
            "Step 3: Sending Request B (text + Image_1 + Image_3, different 2nd image)..."
        )
        if not send_request(self.client, self.config.api_url, payload_b):
            self.results.append(
                RouterTestResult("mm_multi_image_partial", False, "Request B failed")
            )
            return
        time.sleep(settle_time)
        after_b = get_tree_info(self.client, self.config.router_url)
        blocks_b = after_b["num_blocks"] - after_a2["num_blocks"]
        self.log(f"Blocks after Request B: {after_b['num_blocks']} (added {blocks_b})")

        # Analysis:
        # - If blocks_b > 0: Image_3 created new blocks (correct - different image)
        # - If blocks_b == 0: Full cache hit (wrong - Image_3 should be different)
        #
        # Note: We can't easily verify partial match vs full cache miss because
        # the tree growth depends on whether routing hit the cached worker.
        # What we CAN verify is that different images should NOT fully cache hit.
        if blocks_b == 0 and blocks_a2 == 0:
            # Both identical and different requests added 0 blocks
            # This suggests Image_3's mm_hash is incorrectly matching Image_4
            self.results.append(
                RouterTestResult(
                    "mm_multi_image_partial",
                    False,
                    "FAIL - Request B (different image) added 0 blocks. "
                    "Image_3 should have different mm_hash than Image_4. "
                    "Check if mm_hash computation is correct.",
                )
            )
            return
        if blocks_b == 0:
            # Different image but 0 new blocks - might be timing or routing issue
            self.results.append(
                RouterTestResult(
                    "mm_multi_image_partial",
                    False,
                    f"FAIL - Request B added 0 blocks. "
                    f"Identical request added {blocks_a2}. "
                    f"This is unexpected - different images should not fully cache hit.",
                )
            )
            return

        # Success: different image added new blocks
        self.results.append(
            RouterTestResult(
                "mm_multi_image_partial",
                True,
                f"OK - Request A: {blocks_a1} blocks, A repeat: {blocks_a2}, "
                f"Request B (diff image): {blocks_b}. "
                f"Different images correctly create distinct cache entries.",
            )
        )

    def _print_summary(self) -> bool:
        """Print test results summary and return True if every test passed."""
        print("\n" + "=" * 50)
        print("Results")
        print("=" * 50)
        all_passed = True
        for r in self.results:
            symbol = "[OK]" if r.passed else "[X]"
            print(f" {symbol} {r.name}: {r.message}")
            if not r.passed:
                all_passed = False
        print("\n" + "-" * 50)
        if all_passed:
            print("All tests passed.")
            print("\nTo fully verify, check server logs for:")
            print(" - Full match: overlap > 0.5")
            print(" - Partial match: 0 < overlap < 0.5")
            print(" - No match: overlap = 0.000")
        else:
            print("Some tests failed. Check the messages above.")
        return all_passed

    def cleanup(self):
        self.client.close()

def main():
    parser = argparse.ArgumentParser(description="KV Router Test Suite")
    parser.add_argument(
        "--verbose", "-v", action="store_true", help="Show detailed logs"
    )
    parser.add_argument(
        "--api-url", default="http://localhost:8000", help="API server URL"
    )
    parser.add_argument(
        "--router-url", default="http://localhost:7000", help="Router URL"
    )
    parser.add_argument(
        "--mm-only",
        action="store_true",
        help="Run only multimodal local tests (no server needed)",
    )
    parser.add_argument(
        "--mm-server",
        action="store_true",
        help="Run multimodal server tests (requires VLM model)",
    )
    parser.add_argument(
        "--all", action="store_true", help="Run all tests including multimodal"
    )
    args = parser.parse_args()

    config = RouterTestConfig(api_url=args.api_url, router_url=args.router_url)
    tests = KvRouterTests(config, verbose=args.verbose)
    try:
        if args.mm_only:
            # Local MM tests only (no server)
            success = tests.run_mm_tests()
        elif args.mm_server:
            # MM server tests (requires VLM)
            success = tests.run_mm_server_tests()
        elif args.all:
            # Run all tests
            success = tests.run_all()
            if success:
                success = tests.run_mm_tests()
            if success:
                success = tests.run_mm_server_tests()
        else:
            # Default: text-only tests
            success = tests.run_all()
        sys.exit(0 if success else 1)
    finally:
        tests.cleanup()


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
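import os

# Note: PYTHONHASHSEED is read at interpreter startup, so setting it here
# only affects child processes started later, not the running interpreter.
if "PYTHONHASHSEED" not in os.environ:
    os.environ["PYTHONHASHSEED"] = "0"

# Fix protobuf version conflict with etcd3
os.environ.setdefault("PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION", "python")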
import asyncio
import json
import logging
from typing import AsyncGenerator, Optional
import zmq
from tensorrt_llm import LLM
from tensorrt_llm.llmapi import KvCacheConfig
from dynamo.llm import compute_block_hash_for_seq_py
logger = logging.getLogger(__name__)
DEFAULT_KV_EVENT_BUFFER_MAX_SIZE = 1024
llm_max_num_tokens = int(os.getenv("TRTLLM_MAX_NUM_TOKENS", "8192"))
# Debug flag: set DYNAMO_DEBUG=1 to enable debug file dumps
DEBUG_ENABLED = os.environ.get("DYNAMO_DEBUG", "0") == "1"
DEBUG_WORKER_KV_FILE = "/tmp/debug_worker_kv.txt"
# As api.py spins up 2 workers by default, we split the single GPU memory between 2
# workers. Hence, 0.4.
# TODO: allow memory args passing so that the caller can decide the best way to
# allocate memory.
kv_cache_free_gpu_memory_fraction = float(os.getenv("TRTLLM_FREE_GPU_FRAC", "0.4"))
# Qwen2-VL specific token ID for image placeholders
IMAGE_TOKEN_ID = 151937

def dump_worker_kv_event(worker_id: int, event: dict, token_ids: list[int]):
    """Dump worker-side KV event to file for debugging."""
    if not DEBUG_ENABLED:
        return
    import datetime

    with open(DEBUG_WORKER_KV_FILE, "a") as f:
        f.write(f"\n{'='*60}\n")
        f.write(f"Timestamp: {datetime.datetime.now()}\n")
        f.write(f"Worker ID: {worker_id}\n")
        f.write(f"Event: {event}\n")
        f.write(f"Tokens ({len(token_ids)}): {token_ids[:50]}...\n")
        f.write(f"{'='*60}\n")

def to_unsigned_u64(value: int | None) -> int | None:
    """Ensure value is in unsigned 64-bit range for Rust/msgpack."""
    if value is None:
        return None
    # Handle negative values (two's complement)
    return (1 << 64) + value if value < 0 else value
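
The two's-complement mapping used by `to_unsigned_u64` can be checked against a bit mask. The following is a minimal sketch that restates the helper so it runs on its own; the `U64_MASK` constant and the sample values are illustrative assumptions, not part of the original module.

```python
from typing import Optional

U64_MASK = (1 << 64) - 1  # illustrative constant, not defined in the module


def to_unsigned_u64(value: Optional[int]) -> Optional[int]:
    """Map a signed 64-bit integer onto the unsigned 64-bit range."""
    if value is None:
        return None
    return (1 << 64) + value if value < 0 else value


# For inputs inside the signed 64-bit range, the result equals value & U64_MASK.
for signed in (-1, -(1 << 63), 0, 42, (1 << 63) - 1):
    unsigned = to_unsigned_u64(signed)
    assert unsigned == signed & U64_MASK
    assert 0 <= unsigned <= U64_MASK
    print(f"{signed:>21} -> {unsigned}")
```

Values outside the signed 64-bit range are not handled by either form; the sketch assumes the engine only emits hashes that fit in 64 bits.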
# -----------------------------------------------------------------------------
# ZMQ Publishers
# -----------------------------------------------------------------------------

class MetricsPublisher:
    """Publishes worker metrics over ZMQ."""

    def __init__(self, port: int):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.bind(f"tcp://*:{port}")

    def publish(self, num_waiting_reqs: int, kv_cache_usage: float):
        self.socket.send_json(
            {
                "num_waiting_reqs": num_waiting_reqs,
                "kv_cache_usage": kv_cache_usage,
            }
        )

    def close(self):
        self.socket.close()
        self.context.term()

class KvEventsPublisher:
    """Publishes KV cache events as KvCacheEvent JSON over ZMQ."""

    def __init__(self, port: int, block_size: int):
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.bind(f"tcp://*:{port}")
        self.block_size = block_size
        self.partial_block_hashes: set[int] = set()
        self.next_event_id = 0

    def publish_stored(
        self,
        block_hashes: list[int],
        token_ids: list[int],
        parent_hash: int | None,
        block_mm_infos: list[dict | None] | None,
    ):
        """Publish a KvCacheEvent with stored blocks.

        Computes tokens_hash per block using compute_block_hash_for_seq_py
        (including MM info when present) and publishes as KvCacheEvent JSON.
        """
        # Compute tokens_hash per block (MM-aware when block_mm_infos provided)
        tokens_hashes = compute_block_hash_for_seq_py(
            token_ids, self.block_size, block_mm_infos
        )
        blocks = []
        for i, ext_hash in enumerate(block_hashes):
            block_data = {
                "block_hash": to_unsigned_u64(ext_hash),
                "tokens_hash": tokens_hashes[i],
            }
            mm_info = block_mm_infos[i] if block_mm_infos else None
            if mm_info is not None:
                block_data["mm_extra_info"] = mm_info
            blocks.append(block_data)
        event = {
            "event_id": self.next_event_id,
            "data": {
                "stored": {
                    "parent_hash": (
                        to_unsigned_u64(parent_hash)
                        if parent_hash is not None
                        else None
                    ),
                    "blocks": blocks,
                }
            },
            "dp_rank": 0,
        }
        self.next_event_id += 1
        self._send(event)

    def publish_removed(self, block_hashes: list[int]):
        """Publish a KvCacheEvent with removed blocks."""
        filtered = []
        for h in block_hashes:
            if h in self.partial_block_hashes:
                self.partial_block_hashes.remove(h)
            else:
                filtered.append(to_unsigned_u64(h))
        if not filtered:
            return
        event = {
            "event_id": self.next_event_id,
            "data": {
                "removed": {
                    "block_hashes": filtered,
                }
            },
            "dp_rank": 0,
        }
        self.next_event_id += 1
        self._send(event)

    def _send(self, event: dict):
        """Send a single KvCacheEvent as JSON over ZMQ."""
        try:
            payload = json.dumps(event).encode("utf-8")
        except Exception as e:
            logger.error(f"JSON encode error: {e}")
            return
        self.socket.send(payload)

    def close(self):
        self.socket.close()
        self.context.term()
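
The removal path in `KvEventsPublisher` skips hashes that were tracked as partial blocks, since those were never announced as stored. A minimal sketch of that filtering and of the resulting JSON payload follows; `filter_removed` and `removed_event` are hypothetical stand-alone helpers written for illustration, the hash values are made up, and the unsigned conversion and ZMQ send from the class are left out.

```python
import json


def filter_removed(block_hashes, partial_hashes):
    """Keep only hashes that were published; forget partial ones once seen."""
    kept = []
    for h in block_hashes:
        if h in partial_hashes:
            partial_hashes.remove(h)  # partial block: never published, so drop it
        else:
            kept.append(h)
    return kept


def removed_event(event_id, block_hashes):
    """Build a dict with the same layout as the 'removed' event in the class."""
    return {
        "event_id": event_id,
        "data": {"removed": {"block_hashes": block_hashes}},
        "dp_rank": 0,
    }


partial = {11}  # hash 11 stands for a block that was only partially filled
kept = filter_removed([10, 11, 12], partial)
wire = json.dumps(removed_event(0, kept)).encode("utf-8")
print(wire.decode("utf-8"))
```

Because Python integers are arbitrary precision, 64-bit unsigned hashes pass through `json.dumps` without loss; whether the receiving side parses them as `u64` is outside the scope of this sketch.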
# -----------------------------------------------------------------------------
# KV Event Processing Helpers
# -----------------------------------------------------------------------------

def extract_mm_info(
    blocks_data: list[dict], all_token_ids: list[int]
) -> tuple[list[int] | None, list[list[int]] | None]:
    """Extract multimodal hash info from TRTLLM block data.

    Handles multiple images by extracting all mm_hashes and matching them
    to their corresponding image token ranges.

    Returns:
        Tuple of (list of mm_hashes, list of offsets) or (None, None).
        Each offset is [start, end) for one image's token range.
    """
    # Collect all mm_hashes from blocks
    mm_hashes: list[int] = []
    for block in blocks_data:
        mm_keys = block.get("mm_keys", [])
        for mm_key in mm_keys:
            if mm_key.get("type") != "mm_key":
                continue
            hash_hex = mm_key.get("hash", "")
            if hash_hex:
                mm_hash = int(hash_hex[:16], 16)
                if mm_hash not in mm_hashes:  # Avoid duplicates
                    mm_hashes.append(mm_hash)
    if not mm_hashes:
        return None, None

    # Find all image token ranges
    image_offsets_list = find_all_image_token_ranges(all_token_ids)
    if not image_offsets_list:
        return None, None

    # Match mm_hashes to image_offsets by order
    # (assumes mm_hashes appear in same order as images in token sequence)
    return mm_hashes, image_offsets_list
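
`extract_mm_info` keeps only the first 16 hex digits of each `mm_key` hash, which is the leading 64 bits of the digest. The sketch below illustrates that truncation; `mm_hash_from_hex` is a hypothetical helper, and a SHA-256 hex digest is used purely as a stand-in because the digest format emitted by the engine is not shown here.

```python
import hashlib


def mm_hash_from_hex(hash_hex: str) -> int:
    """Interpret the first 16 hex digits (64 bits) of a digest as an integer."""
    return int(hash_hex[:16], 16)


# Stand-in digest: any hex string with at least 16 digits behaves the same way.
digest = hashlib.sha256(b"example image bytes").hexdigest()
mm_hash = mm_hash_from_hex(digest)

# The result always fits in an unsigned 64-bit integer ...
assert 0 <= mm_hash < (1 << 64)
# ... and equals the first 8 digest bytes read as a big-endian integer.
assert mm_hash == int.from_bytes(bytes.fromhex(digest)[:8], "big")
print(hex(mm_hash))
```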

def find_all_image_token_ranges(token_ids: list[int]) -> list[list[int]] | None:
    """Find all [start, end) ranges of contiguous image tokens.

    Returns:
        List of [start, end) ranges, one per contiguous image token sequence.
        Returns None if no image tokens found.
    """
    ranges: list[list[int]] = []
    current_start: int | None = None
    for i, tid in enumerate(token_ids):
        if tid == IMAGE_TOKEN_ID:
            if current_start is None:
                current_start = i
        elif current_start is not None:
            # End of contiguous sequence
            ranges.append([current_start, i])
            current_start = None
    # Handle sequence ending with image tokens
    if current_start is not None:
        ranges.append([current_start, len(token_ids)])
    return ranges if ranges else None
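
For comparison, the same run detection can be written with `itertools.groupby`. This is an alternative formulation sketched for illustration, not the module's implementation: `image_token_ranges` is a hypothetical name, and it returns an empty list rather than `None` when no image tokens are present.

```python
from itertools import groupby

IMAGE_TOKEN_ID = 151937  # same placeholder token ID as in the module


def image_token_ranges(token_ids):
    """Return [start, end) for each contiguous run of IMAGE_TOKEN_ID."""
    ranges = []
    pos = 0
    for is_image, run in groupby(token_ids, key=lambda t: t == IMAGE_TOKEN_ID):
        length = sum(1 for _ in run)  # consume the run to get its length
        if is_image:
            ranges.append([pos, pos + length])
        pos += length
    return ranges


tokens = [1, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 2, 3, IMAGE_TOKEN_ID]
print(image_token_ranges(tokens))  # [[1, 3], [5, 6]]
```

Two images whose placeholder tokens sit directly next to each other would be merged into a single range by either version, so the order-based pairing with `mm_hashes` relies on images being separated by at least one non-image token.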

def build_per_block_mm_infos(
    num_blocks: int,
    block_size: int,
    mm_hashes: list[int] | None,
    image_offsets_list: list[list[int]] | None,
) -> list[dict | None] | None:
    """Build per-block mm_infos list for multiple images.

    Each block that overlaps with an image's token range gets the corresponding
    mm_info with that image's mm_hash.

    Args:
        num_blocks: Number of blocks in the stored event.
        block_size: Number of tokens per block.
        mm_hashes: List of mm_hash values, one per image.
        image_offsets_list: List of [start, end) token ranges, one per image.

    Returns:
        List of mm_info (one per block), with None for blocks without image tokens.
        Returns None if no mm_info is provided.
    """
    if mm_hashes is None or image_offsets_list is None:
        return None
    if not mm_hashes or not image_offsets_list:
        return None

    # Initialize result with None for all blocks
    result: list[dict | None] = [None] * num_blocks

    # Process each image
    for mm_hash, offsets in zip(mm_hashes, image_offsets_list):
        img_start, img_end = offsets
        for block_idx in range(num_blocks):
            block_start = block_idx * block_size
            block_end = block_start + block_size
            # Check if this block overlaps with this image's token range
            if block_end > img_start and block_start < img_end:
                if result[block_idx] is None:
                    result[block_idx] = {"mm_objects": []}
                # Add this image's mm_object to the block
                result[block_idx]["mm_objects"].append(
                    {"mm_hash": mm_hash, "offsets": [offsets]}
                )
    return result
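
The overlap test in `build_per_block_mm_infos` is the usual half-open interval check. The sketch below isolates it and compares it with a direct index computation; both helper names are hypothetical, and the block size and image ranges are made-up values chosen to keep the arithmetic easy to follow.

```python
def blocks_overlapping(img_start, img_end, block_size, num_blocks):
    """Scan every block and keep those whose [start, end) overlaps the image."""
    return [
        idx
        for idx in range(num_blocks)
        if (idx + 1) * block_size > img_start and idx * block_size < img_end
    ]


def blocks_overlapping_direct(img_start, img_end, block_size, num_blocks):
    """Same result from integer division, assuming img_start < img_end."""
    first = img_start // block_size
    last = min((img_end - 1) // block_size, num_blocks - 1)
    return list(range(first, last + 1))


# With block_size=4 and 3 blocks, tokens 2..5 touch blocks 0 and 1,
# while tokens 9..10 touch only block 2.
for start, end in ([2, 6], [9, 11], [4, 8], [0, 12]):
    scanned = blocks_overlapping(start, end, block_size=4, num_blocks=3)
    assert scanned == blocks_overlapping_direct(start, end, 4, 3)
    print((start, end), "->", scanned)
```

The scan mirrors the loop in the function above and costs one comparison per block per image; the direct form would avoid the inner loop if the number of blocks per event grew large.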

def parse_stored_blocks(
    blocks_data: list[dict], block_size: int, partial_hashes: set[int]
) -> tuple[list[dict], list[int]]:
    """Parse stored blocks from TRTLLM event data.

    Returns:
        Tuple of (blocks list, all token_ids)
    """
    blocks = []
    all_token_ids = []
    for block in blocks_data:
        tokens = block["tokens"]
        num_tokens = len(tokens)
        block_hash = block["block_hash"]
        if num_tokens == block_size:
            token_ids = [int(t["token_id"]) for t in tokens]
            blocks.append(
                {
                    "block_hash": block_hash,
                    "token_ids": token_ids,
                    "num_tokens": num_tokens,
                }
            )
            all_token_ids.extend(token_ids)
        elif num_tokens < block_size:
            # Partial block - track but don't publish
            partial_hashes.add(block_hash)
            break
        else:
            logger.error(f"Block too large: {num_tokens} > {block_size}")
            break
    return blocks, all_token_ids
# -----------------------------------------------------------------------------
# TRT-LLM Worker
# -----------------------------------------------------------------------------

class TrtllmWorker:
    """Manages a single TensorRT-LLM worker with event/metrics publishing."""

    def __init__(
        self,
        worker_id: int,
        model: str,
        block_size: int,
        kv_events_port: int,
        metrics_port: int,
    ):
        self.worker_id = worker_id
        self.model = model
        self.block_size = block_size
        self.llm: Optional[LLM] = None
        self.metrics_publisher: Optional[MetricsPublisher] = None
        self.kv_events_publisher: Optional[KvEventsPublisher] = None
        self.background_tasks: list[asyncio.Task] = []
        self.max_window_size: int | None = None
        self.processing_initial_events = True
        self.kv_events_started = False
        self._initialize(kv_events_port, metrics_port)

    def _initialize(self, kv_events_port: int, metrics_port: int):
        """Initialize TensorRT-LLM engine and publishers."""
        logger.info(f"Worker {self.worker_id}: Initializing")
        self.llm = LLM(
            model=self.model,
            kv_cache_config=KvCacheConfig(
                enable_block_reuse=True,
                event_buffer_max_size=DEFAULT_KV_EVENT_BUFFER_MAX_SIZE,
                free_gpu_memory_fraction=kv_cache_free_gpu_memory_fraction,
            ),
            max_num_tokens=llm_max_num_tokens,
        )
        self.metrics_publisher = MetricsPublisher(metrics_port)
        self.kv_events_publisher = KvEventsPublisher(kv_events_port, self.block_size)
        logger.info(f"Worker {self.worker_id}: Initialized")

    # -------------------------------------------------------------------------
    # Background Tasks
    # -------------------------------------------------------------------------
    async def start_background_tasks(self):
        """Start metrics publishing task."""
        self.background_tasks.append(asyncio.create_task(self._metrics_loop()))

    def _start_kv_events_task(self):
        """Lazily start KV events task on first request."""
        if self.kv_events_started:
            return
        self.kv_events_started = True
        logger.info(f"Worker {self.worker_id}: Starting KV events monitoring")
        self.background_tasks.append(asyncio.create_task(self._kv_events_loop()))

    async def _metrics_loop(self):
        """Continuously publish worker metrics."""
        await asyncio.sleep(1)
        try:
            async for stat in self.llm.get_stats_async(timeout=5):
                if not isinstance(stat, dict):
                    continue
                num_waiting = (
                    stat["numQueuedRequests"]
                    + stat["inflightBatchingStats"]["numPausedRequests"]
                )
                kv_stats = stat["kvCacheStats"]
                usage = (
                    kv_stats["allocTotalBlocks"] / kv_stats["maxNumBlocks"]
                    if kv_stats["maxNumBlocks"] > 0
                    else 0.0
                )
                self.metrics_publisher.publish(num_waiting, usage)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"Worker {self.worker_id} metrics error: {e}")

    async def _kv_events_loop(self):
        """Continuously process and publish KV cache events."""
        await asyncio.sleep(2)
        try:
            events = self.llm.get_kv_cache_events_async(timeout=5)
            logger.info(f"Worker {self.worker_id}: KV events iterator obtained")
            async for event in events:
                self._process_kv_event(event)
        except asyncio.CancelledError:
            pass
        except RuntimeError as e:
            if "IterationResult is not properly instantiated" in str(e):
                logger.warning(f"Worker {self.worker_id}: KV events not available")
            else:
                logger.error(f"Worker {self.worker_id} KV events error: {e}")
        except Exception as e:
            logger.error(f"Worker {self.worker_id} KV events error: {e}")
        logger.warning(f"Worker {self.worker_id}: KV events loop exited")

    def _process_kv_event(self, event: dict):
        """Process a single KV cache event."""
        if not isinstance(event, dict):
            return
        if "event_id" not in event or "data" not in event:
            return
        data = event["data"]
        event_type = data.get("type")
        if self._should_drop_event(event):
            return
        if event_type == "stored":
            self._handle_stored_event(data)
        elif event_type == "removed":
            self._handle_removed_event(data)
        elif event_type == "created" and self.processing_initial_events:
            self._update_window_size(event)

    def _should_drop_event(self, event: dict) -> bool:
        """Check if event should be dropped (non-global attention)."""
        if self.processing_initial_events:
            return False
        window_size = event.get("window_size")
        if window_size is None:
            return False
        return window_size != self.max_window_size

    def _update_window_size(self, event: dict):
        """Update max window size from created events."""
        window_size = event.get("window_size")
        if window_size and (
            self.max_window_size is None or window_size > self.max_window_size
        ):
            self.max_window_size = window_size

    def _handle_stored_event(self, data: dict):
        """Handle a stored block event."""
        self.processing_initial_events = False
        blocks, all_token_ids = parse_stored_blocks(
            data["blocks"],
            self.block_size,
            self.kv_events_publisher.partial_block_hashes,
        )
        if not blocks:
            return
        parent_hash = data.get("parent_hash")
        mm_hashes, image_offsets_list = extract_mm_info(data["blocks"], all_token_ids)
        block_hashes = [b["block_hash"] for b in blocks]

        # Build per-block mm_infos (only blocks with image tokens get mm_info)
        block_mm_infos = build_per_block_mm_infos(
            len(blocks), self.block_size, mm_hashes, image_offsets_list
        )

        # Debug dump
        dump_worker_kv_event(
            self.worker_id,
            {"type": "stored", "blocks": len(blocks), "mm_hashes": mm_hashes},
            all_token_ids,
        )
        self.kv_events_publisher.publish_stored(
            block_hashes, all_token_ids, parent_hash, block_mm_infos
        )

    def _handle_removed_event(self, data: dict):
        """Handle a removed block event."""
        self.processing_initial_events = False
        block_hashes = data.get("block_hashes", [])
        self.kv_events_publisher.publish_removed(block_hashes)

    # -------------------------------------------------------------------------
    # Generation
    # -------------------------------------------------------------------------
    async def generate(
        self,
        prompt_input,  # list[int] (tokens) or dict (MM input)
        sampling_params: dict,
    ) -> AsyncGenerator[dict, None]:
        """Generate tokens for a request."""
        from tensorrt_llm.llmapi.llm import SamplingParams

        # Start KV events on first request
        self._start_kv_events_task()
        trtllm_params = SamplingParams(
            max_tokens=sampling_params.get("max_tokens", 100),
            temperature=sampling_params.get("temperature", 1.0),
            top_p=sampling_params.get("top_p", 1.0),
            top_k=max(0, sampling_params.get("top_k", 0)),
        )
        outputs = self.llm.generate_async(
            prompt_input, sampling_params=trtllm_params, streaming=False
        )
        async for output in outputs:
            yield self._format_output(output)

    def _format_output(self, request_output) -> dict:
        """Format TRTLLM output to standard response dict."""
        if not hasattr(request_output, "outputs") or not request_output.outputs:
            return {"text": "", "text_diff": "", "token_ids": [], "finish_reason": None}
        completion = request_output.outputs[0]
        text = getattr(completion, "text_diff", None) or getattr(completion, "text", "")
        return {
            "text": text,
            "text_diff": getattr(completion, "text_diff", text),
            "token_ids": getattr(completion, "token_ids", []),
            "finish_reason": getattr(completion, "finish_reason", None),
        }

    # -------------------------------------------------------------------------
    # Lifecycle
    # -------------------------------------------------------------------------
    def shutdown(self):
        """Shutdown worker and cleanup resources."""
        logger.info(f"Worker {self.worker_id}: Shutting down")
        for task in self.background_tasks:
            task.cancel()
        if self.llm:
            self.llm.shutdown()
        if self.metrics_publisher:
            self.metrics_publisher.close()
        if self.kv_events_publisher:
            self.kv_events_publisher.close()
# -----------------------------------------------------------------------------
# Worker Manager
# -----------------------------------------------------------------------------

class TrtllmWorkers:
    """Manages multiple TensorRT-LLM workers.

    Warning: Creating multiple workers in the same process causes them to share
    the same GPU(s).
    """

    def __init__(
        self,
        model: str = "Qwen/Qwen2-VL-2B-Instruct",
        block_size: int = 32,
        base_kv_events_port: int = 5557,
        base_metrics_port: int = 5657,
        num_workers: int = 1,
    ):
        self.workers = []
        if num_workers > 1:
            logger.warning(
                f"Creating {num_workers} workers in the same process. "
                "All workers will share the same GPU(s). For multi-GPU isolation, "
                "start each worker in a separate process with CUDA_VISIBLE_DEVICES set."
            )
        logger.info(f"Initializing {num_workers} workers for {model}")
        for i in range(num_workers):
            self.workers.append(
                TrtllmWorker(
                    worker_id=i,
                    model=model,
                    block_size=block_size,
                    kv_events_port=base_kv_events_port + i,
                    metrics_port=base_metrics_port + i,
                )
            )
        logger.info(f"All {num_workers} workers initialized")

    async def start_all(self):
        """Start background tasks for all workers."""
        for worker in self.workers:
            await worker.start_background_tasks()

    async def direct(
        self, prompt_input, worker_id: int, sampling_params: dict
    ) -> AsyncGenerator[dict, None]:
        """Send request to a specific worker."""
        async for output in self.workers[worker_id].generate(
            prompt_input, sampling_params
        ):
            yield output

    def shutdown_all(self):
        """Shutdown all workers."""
        logger.info("Shutting down all workers")
        for worker in self.workers:
            worker.shutdown()