Unverified Commit 13a99b7f authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: Standalone Router (#1409)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
Signed-off-by: default avatarYan Ru Pei <yanrpei@gmail.com>
Signed-off-by: default avatarjain-ria <riajain@NVIDIA.com>
Co-authored-by: default avatarcoderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
Co-authored-by: default avatarjain-ria <riajain@NVIDIA.com>
parent 1906b702
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Router Standalone
A toy implementation of KvRouter that demonstrates standalone usage without dependency on the dynamo runtime, etcd control plane, or nats event plane.
## Overview
This example shows how to use KvRouter in a standalone fashion to intelligently route requests across multiple vLLM workers based on KV cache overlap and load metrics. The router maintains a view of each worker's cached blocks and routes new requests to the worker with the best combination of cache overlap and available capacity.
> [!Tip]
> The main focus should be put on `router.py` as it contains the bulk of the non-boilerplate code and core routing logic.
## How It Works
### Core Architecture
The router uses a **RadixTree** data structure (written in Rust) to efficiently track which blocks each worker has cached. When a new request arrives, the router:
1. Uses `find_matches` to calculate overlap scores (number of matching blocks) between the request and each worker's cached blocks
2. Combines this with current load metrics to select the optimal worker
3. Routes the request to the chosen worker for processing
### Event-Driven Updates
The router receives two types of events from vLLM engines:
1. **KV Events**: Emitted automatically by vLLM engines when blocks are cached/evicted
2. **Load Metrics**: GPU usage percentage and waiting request count via custom callbacks
These events keep the router's view of worker state up-to-date in real-time.
### Alternative: Pure Predictive Routing
While not implemented in this example, the router can also operate in a pure predictive mode, estimating the radix tree state and loads based solely on the requests it receives, without relying on backend events. This requires simulating / mocking the block managing (e.g. eviction) and the scheduling policies of the backend engine. This is not recommended as there is no real-time feedback from the engines, and the router state may drift out of sync with the engine states. Nevertheless, this is WIP and can be supported in the future via our mocker engines.
## Components
> [!Note]
> This is a standalone toy implementation created for pedagogical purposes to demonstrate the core KvRouter concepts in isolation.
> Our default dynamo router is already very efficient and uses NATS for event communication and etcd for endpoint registration.
> This example intentionally avoids these production components to provide a simpler, self-contained demonstration of the routing logic and cache overlap mechanics.
>
> The toy communication pattern is as follows:
> - **OpenAI Compatible Frontend** – FastAPI application serving OpenAI compatible HTTP API.
> - **Router** – Standalone FastAPI endpoint for best worker selection, with core routines implemented in Rust exposed via Python bindings.
> - **Workers** – Served in-process within the frontend application to reduce complexity and boilerplate, rather than as separate endpoints.
### `router.py`
- **KvRouter**: Core routing logic using RadixTree
- Subscribes to KV cache events and load metrics from workers
- Implements `get_best_worker()` to select optimal routing destination
- Runs background tasks to periodically update worker states
### `worker.py`
- **VllmWorkers**: Manages multiple vLLM worker processes
- Each worker runs on a separate port with KV cache event emission enabled
- Provides `direct()` method for sending requests to specific workers
- Handles worker lifecycle and configuration
### `api.py`
- **RouterAPI**: Minimal FastAPI server providing OpenAI-compatible chat completions endpoint
- Enables in-process communication between router and workers
- Can be easily modified to use external communication (FastAPI clients, dynamo endpoints, etc.)
- Integrates with vLLM's OpenAI serving components for request preprocessing and response formatting
### `perf.sh`
- Benchmarking script using `genai-perf` to test the router setup
- Configured for streaming chat completions with synthetic workloads
- Tests concurrent requests to evaluate routing performance
## Usage
1. **Install latest vLLM**:
```bash
uv pip uninstall ai-dynamo-vllm
uv pip install vllm==0.9.0
```
*Note: This uninstalls the local vLLM patch (`ai-dynamo-vllm`) and replaces it with the latest standard vLLM package.*
2. **Start the router API**:
For example:
```bash
python api.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--num-workers 4 \
--block-size 64 \
--base-kv-events-port 5557 \
--base-metrics-port 5657 \
--router-port 7000 \
--http-port 8000
```
3. **Ping the endpoint (optional)**:
```bash
./ping.sh
```
4. **Run performance benchmark**:
```bash
./perf.sh
```
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import uuid
from dataclasses import dataclass
from typing import Optional
import httpx
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from router import RouterAPI, RouterRequest, RouterResponse # Add this import
from transformers import PreTrainedTokenizerBase
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (
ChatCompletionRequest,
ErrorResponse,
RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.transformers_utils.tokenizer import get_tokenizer
from worker import VllmWorkers
from dynamo._core import compute_block_hash_for_seq_py
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class ServingParams:
model: str
block_size: int
num_workers: int
base_kv_events_port: int
base_metrics_port: int
router_port: int
http_port: int
class ServiceAPI:
def __init__(self, init_params: ServingParams):
self.init_params = init_params
self.app = FastAPI(title="Router API", version="0.0.1")
# These will be initialized in start()
self.workers: Optional[VllmWorkers] = None
self.tokenizer: Optional[PreTrainedTokenizerBase] = None
self.openai_serving_chat: Optional[OpenAIServingChat] = None
self.model_config: Optional[ModelConfig] = None
self.http_client: Optional[httpx.AsyncClient] = None
self.setup_routes()
def setup_routes(self):
@self.app.post("/v1/chat/completions")
async def chat_completions(request: ChatCompletionRequest):
if (
self.workers is None
or self.tokenizer is None
or self.openai_serving_chat is None
or self.http_client is None
):
return ErrorResponse(
message="Service not ready",
type="service_unavailable",
code=503,
)
try:
# Determine max_tokens: use max_completion_tokens first, then max_tokens, or error
max_tokens_value = None
if (
hasattr(request, "max_completion_tokens")
and request.max_completion_tokens is not None
):
max_tokens_value = request.max_completion_tokens
elif hasattr(request, "max_tokens") and request.max_tokens is not None:
max_tokens_value = request.max_tokens
else:
return ErrorResponse(
message="Either max_tokens or max_completion_tokens must be specified",
type="invalid_request_error",
code=400,
)
# Use vLLM's preprocessing to convert chat to prompt
(
conversation,
request_prompts,
engine_prompts,
) = await self.openai_serving_chat._preprocess_chat(
request,
self.tokenizer,
request.messages,
chat_template=request.chat_template or self.tokenizer.chat_template,
chat_template_content_format=self.openai_serving_chat.chat_template_content_format,
add_generation_prompt=request.add_generation_prompt,
add_special_tokens=False,
)
engine_prompt = engine_prompts[0]
# Convert request to sampling parameters with our determined max_tokens
sampling_params = request.to_sampling_params(
default_max_tokens=max_tokens_value,
logits_processor_pattern=None,
default_sampling_params=None,
)
# Get best worker using HTTP request to router
tokens: list[int] = engine_prompt["prompt_token_ids"]
num_tokens = len(tokens)
if num_tokens == 0:
return ErrorResponse(
message="Input prompt is empty",
type="invalid_request_error",
code=400,
)
# It is much preferred to communicate block hashes to the router instead of
# raw text prompts or tokens, especially when over network using pydantic validation,
# as block hashes can be orders of magnitude smaller.
# Note that the hashing function needs to be deterministic (across processes),
# and has to be consistent with the hashing function used to send KV Events to the Router.
local_hashes = compute_block_hash_for_seq_py(
tokens, self.init_params.block_size
)
# Call router via HTTP
try:
router_request = RouterRequest(
local_hashes=local_hashes, num_tokens=num_tokens
)
router_response = await self.http_client.post(
f"http://localhost:{self.init_params.router_port}/find_best_worker",
json=router_request.model_dump(),
timeout=1,
)
router_response.raise_for_status()
router_data = RouterResponse.model_validate(router_response.json())
best_worker_id = router_data.worker_id
except (httpx.RequestError, httpx.HTTPStatusError) as e:
logger.error(f"Router request failed: {e}")
return ErrorResponse(
message="Router service unavailable",
type="service_unavailable",
code=503,
)
logger.info(f"Selected worker {best_worker_id} for request")
# Generate request ID
request_id = f"chatcmpl-{uuid.uuid4()}"
request_metadata = RequestResponseMetadata(request_id=request_id)
# Get the generator from the selected worker with sampling params
result_generator = self.workers.direct(
engine_prompt, best_worker_id, sampling_params
)
assert request.stream
# Use vLLM's streaming response generator
return StreamingResponse(
self.openai_serving_chat.chat_completion_stream_generator(
request,
result_generator,
request_id,
self.init_params.model,
conversation,
self.tokenizer,
request_metadata,
),
media_type="text/event-stream",
headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
)
except Exception as e:
logger.error(f"Error processing request: {e}")
return ErrorResponse(message=str(e), type="internal_error", code=500)
async def initialize_services(self):
"""Initialize workers, HTTP client, and OpenAI serving components"""
logger.info("Initializing VllmWorkers...")
self.workers = VllmWorkers(
model=self.init_params.model,
block_size=self.init_params.block_size,
base_kv_events_port=self.init_params.base_kv_events_port,
base_metrics_port=self.init_params.base_metrics_port,
num_workers=self.init_params.num_workers,
)
# Initialize HTTP client for router communication
self.http_client = httpx.AsyncClient()
logger.info("Initializing OpenAI serving components...")
# Initialize tokenizer and model config
self.tokenizer = get_tokenizer(self.init_params.model)
# Create a mock model config
self.model_config = ModelConfig(
model=self.init_params.model,
enforce_eager=True,
)
# Initialize OpenAI serving models
base_model_paths = [
BaseModelPath(
name=self.init_params.model, model_path=self.init_params.model
)
]
openai_serving_models = OpenAIServingModels(
engine_client=None,
model_config=self.model_config,
base_model_paths=base_model_paths,
)
# Initialize OpenAI serving chat
self.openai_serving_chat = OpenAIServingChat(
engine_client=None,
model_config=self.model_config,
models=openai_serving_models,
response_role="assistant",
request_logger=None,
chat_template=None,
chat_template_content_format="auto",
)
logger.info("Waiting 2 seconds for services to initialize...")
await asyncio.sleep(2)
logger.info("Services initialized successfully!")
async def start(self):
"""Start the API server"""
# Initialize services first
await self.initialize_services()
# Start the API server
logger.info(f"Starting API server on port {self.init_params.http_port}")
config = uvicorn.Config(
self.app, host="0.0.0.0", port=self.init_params.http_port, log_level="info"
)
server = uvicorn.Server(config)
await server.serve()
async def shutdown(self):
"""Proper shutdown handler"""
logger.info("Shutting down API...")
if self.http_client:
await self.http_client.aclose()
logger.info("API shutdown completed")
def main():
parser = argparse.ArgumentParser(description="Router API Server")
# Arguments from worker.py
parser.add_argument(
"--model",
type=str,
default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
help="Model name to use",
)
# Common arguments
parser.add_argument(
"--block-size", type=int, default=64, help="Block size for caching"
)
parser.add_argument(
"--num-workers", type=int, default=2, help="Number of worker processes"
)
parser.add_argument(
"--base-kv-events-port", type=int, default=5557, help="Base port for KV events"
)
parser.add_argument(
"--base-metrics-port", type=int, default=5657, help="Base port for metrics"
)
parser.add_argument(
"--router-port",
type=int,
default=7000,
help="Port for router service",
)
parser.add_argument(
"--http-port", type=int, default=8000, help="Port to serve the API on"
)
args = parser.parse_args()
# Setup logging
logging.basicConfig(level=logging.INFO)
init_params = ServingParams(
model=args.model,
block_size=args.block_size,
num_workers=args.num_workers,
base_kv_events_port=args.base_kv_events_port,
base_metrics_port=args.base_metrics_port,
router_port=args.router_port,
http_port=args.http_port,
)
# Create both services
api = ServiceAPI(init_params=init_params)
router_api = RouterAPI(
block_size=args.block_size,
num_workers=args.num_workers,
base_kv_events_port=args.base_kv_events_port,
base_metrics_port=args.base_metrics_port,
port=args.router_port,
)
async def run_with_shutdown():
try:
# Start both services concurrently
await asyncio.gather(
api.start(), router_api.start(), return_exceptions=True
)
except KeyboardInterrupt:
logger.info("Received KeyboardInterrupt, shutting down services...")
except Exception as e:
logger.exception(f"Unhandled exception: {e}")
finally:
await api.shutdown()
try:
asyncio.run(run_with_shutdown())
except KeyboardInterrupt:
# Just in case KeyboardInterrupt happens outside of the event loop
logger.info("Force shutdown via KeyboardInterrupt.")
if __name__ == "__main__":
main()
#/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B
type=chat
endpoint=/v1/chat/completions
port=8000
isl=3000
osl=100
concurrency=25
num_requests=100
num_unique_prompts=10
seed=42
genai-perf profile \
--model ${model} \
--tokenizer ${model} \
--endpoint-type ${type} \
--endpoint ${endpoint} \
--streaming \
--url http://localhost:${port} \
--synthetic-input-tokens-mean ${isl} \
--synthetic-input-tokens-stddev 0 \
--output-tokens-mean ${osl} \
--output-tokens-stddev 0 \
--extra-inputs max_tokens:${osl} \
--extra-inputs min_tokens:${osl} \
--extra-inputs ignore_eos:true \
--extra-inputs "{\"nvext\":{\"ignore_eos\":true}}" \
--concurrency ${concurrency} \
--request-count ${num_requests} \
--num-dataset-entries ${num_unique_prompts} \
--random-seed ${seed} \
-- \
-v \
--max-threads 256 \
-H 'Authorization: Bearer NOT USED' \
-H 'Accept: text/event-stream'
#/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true,
"max_tokens": 100
}'
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from typing import List
import numpy as np
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from dynamo._core import RadixTree, ZmqKvEventListener
logger = logging.getLogger(__name__)
class RouterRequest(BaseModel):
local_hashes: List[int]
num_tokens: int
class RouterResponse(BaseModel):
worker_id: int
class LoadMetrics(BaseModel):
gpu_cache_usage: float
num_waiting_reqs: int
def setup_zmq_subscriber(context: zmq.Context, endpoint: str) -> zmq.Socket[bytes]:
socket = context.socket(zmq.SUB)
socket.connect(endpoint)
socket.setsockopt(zmq.SUBSCRIBE, b"") # Subscribe to all messages
socket.setsockopt(zmq.CONFLATE, 1) # Only keep latest message
socket.setsockopt(zmq.RCVTIMEO, 1) # 1ms timeout (very short)
return socket
class KvRouter:
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
):
self.num_workers = num_workers
self.block_size = block_size
self.radix_tree = RadixTree()
self.kv_usages = [0.0] * num_workers
self.waitings = [0] * num_workers
self.context = zmq.Context()
self.load_listeners = [
setup_zmq_subscriber(
self.context, f"tcp://localhost:{base_metrics_port + worker_id}"
)
for worker_id in range(num_workers)
]
self.kv_listeners = [
ZmqKvEventListener(
f"tcp://localhost:{base_kv_events_port + worker_id}", "", block_size
)
for worker_id in range(num_workers)
]
self.background_tasks: list[asyncio.Task] = []
logger.info("Router initialized")
async def start_background_tasks(self):
"""Start background tasks for load and indexer updates"""
logger.info("Starting router background tasks...")
self.background_tasks.append(asyncio.create_task(self.periodic_update_load()))
self.background_tasks.append(
asyncio.create_task(self.periodic_update_indexer())
)
async def periodic_update_load(self):
async def update_load(worker_id: int):
while True:
try:
metrics_dict = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
metrics = LoadMetrics.model_validate(metrics_dict)
self.kv_usages[worker_id] = metrics.gpu_cache_usage
self.waitings[worker_id] = metrics.num_waiting_reqs
except zmq.Again:
pass
except Exception as e:
logger.warning(
f"Error receiving metrics for worker {worker_id}: {e}"
)
await asyncio.sleep(0.1)
for worker_id in range(self.num_workers):
asyncio.create_task(update_load(worker_id))
async def periodic_update_indexer(self):
async def update_tree(worker_id: int):
while True:
try:
kv_events: list[str] = await self.kv_listeners[
worker_id
].get_events()
for event in kv_events:
event = json.loads(event)
self.radix_tree.apply_event(
worker_id, json.dumps(event).encode("utf-8")
)
except zmq.Again:
pass
except Exception as e:
logger.warning(
f"Error receiving KV events for worker {worker_id}: {e}"
)
await asyncio.sleep(0.1)
for worker_id in range(self.num_workers):
asyncio.create_task(update_tree(worker_id))
async def get_best_worker(self, local_hashes: list[int], num_tokens: int) -> int:
try:
if num_tokens <= 0:
raise ValueError("num_tokens must be positive")
# local_hashes can be empty
raw_scores = self.radix_tree.find_matches(local_hashes).scores
overlap_scores = {
worker_id: raw_scores.get(worker_id, 0) * self.block_size / num_tokens
for worker_id in range(self.num_workers)
}
kv_usages = self.kv_usages[:]
waitings = self.waitings[:]
max_waiting = max(waitings) if waitings else 0
waitings_normalized = [
waiting / max_waiting if max_waiting else 0.0 for waiting in waitings
]
logits = []
for worker_id in range(self.num_workers):
overlap = overlap_scores[worker_id]
usage = kv_usages[worker_id]
waiting = waitings_normalized[worker_id]
logit = 2 * overlap - usage - waiting
logits.append(logit)
logger.info(
f"worker_id: {worker_id}, logit = 2 * {overlap:.3f} - {usage:.3f} - {waiting:.3f} = {logit:.3f}"
)
logits_array = np.array(logits)
best_worker_id = int(
np.random.choice(np.flatnonzero(logits_array == logits_array.max()))
)
# this is a predictive update which will be reset as new metrics are polled
# but it is helpful for handling short bursts of highly concurrent requests
# we omit updating the gpu_usage_perc as done in the rusty router for simplicity
# as this requires obtaining num_gpu_blocks from the engines and can be intrusive
# no need for async lock here, as the state is intended to be continuously overwritten
self.waitings[best_worker_id] += 1
return best_worker_id
except Exception as e:
logger.error(f"Error in get_best_worker: {e}")
raise
async def shutdown(self):
"""Shutdown ZMQ listeners, context, and background tasks"""
logger.info("Shutting down KvRouter...")
# Cancel background tasks
for task in self.background_tasks:
task.cancel()
if self.background_tasks:
await asyncio.gather(*self.background_tasks, return_exceptions=True)
# Close load listeners (ZMQ sockets)
for listener in self.load_listeners:
try:
listener.close()
except Exception as e:
logger.error(f"Error closing load listener: {e}")
# Terminate ZMQ context
try:
self.context.term()
logger.info("ZMQ context terminated successfully")
except Exception as e:
logger.error(f"Error terminating ZMQ context: {e}")
logger.info("KvRouter shutdown completed")
class RouterAPI:
def __init__(
self,
block_size: int = 64,
num_workers: int = 4,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
port: int = 7000,
):
self.port = port
self.block_size = block_size
self.num_workers = num_workers
self.base_kv_events_port = base_kv_events_port
self.base_metrics_port = base_metrics_port
self.router = None
self.app = FastAPI(
title="KV Router API", version="0.0.1", lifespan=self.lifespan
)
self.setup_routes()
@asynccontextmanager
async def lifespan(self, app: FastAPI):
# Startup
self.router = KvRouter(
block_size=self.block_size,
num_workers=self.num_workers,
base_kv_events_port=self.base_kv_events_port,
base_metrics_port=self.base_metrics_port,
)
await self.router.start_background_tasks()
logger.info("Router API started successfully")
yield
# Shutdown
if self.router:
await self.router.shutdown()
def setup_routes(self):
@self.app.post("/find_best_worker", response_model=RouterResponse)
async def find_best_worker(request: RouterRequest):
if self.router is None:
raise HTTPException(status_code=503, detail="Router not initialized")
try:
worker_id = await self.router.get_best_worker(
request.local_hashes, request.num_tokens
)
return RouterResponse(worker_id=worker_id)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Error finding best worker: {e}")
raise HTTPException(status_code=500, detail="Internal server error")
async def start(self):
"""Start the router API server"""
logger.info(f"Starting Router API server on port {self.port}")
config = uvicorn.Config(
self.app, host="0.0.0.0", port=self.port, log_level="info"
)
server = uvicorn.Server(config)
await server.serve()
def main():
parser = argparse.ArgumentParser(description="KV Router API Server")
parser.add_argument(
"--block-size", type=int, default=64, help="Block size for caching"
)
parser.add_argument(
"--num-workers", type=int, default=2, help="Number of worker processes"
)
parser.add_argument(
"--base-kv-events-port", type=int, default=5557, help="Base port for KV events"
)
parser.add_argument(
"--base-metrics-port", type=int, default=5657, help="Base port for metrics"
)
parser.add_argument(
"--port", type=int, default=7000, help="Port to serve the Router API on"
)
args = parser.parse_args()
# Setup logging
logging.basicConfig(level=logging.INFO)
api = RouterAPI(
block_size=args.block_size,
num_workers=args.num_workers,
base_kv_events_port=args.base_kv_events_port,
base_metrics_port=args.base_metrics_port,
port=args.port,
)
async def run_with_shutdown():
try:
await api.start()
except KeyboardInterrupt:
logger.info(
"Received KeyboardInterrupt, shutting down Router API server..."
)
except Exception as e:
logger.exception(f"Unhandled exception: {e}")
try:
asyncio.run(run_with_shutdown())
except KeyboardInterrupt:
logger.info("Force shutdown via KeyboardInterrupt.")
if __name__ == "__main__":
main()
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import uuid
from typing import AsyncGenerator, Optional
import zmq
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig, VllmConfig
from vllm.distributed.kv_events import KVEventsConfig
from vllm.inputs.data import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats
logger = logging.getLogger(__name__)
class MetricsPublisher(StatLoggerBase):
"""Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""
def __init__(self, port: int) -> None:
self.context = zmq.Context()
self.socket = self.context.socket(zmq.PUB)
self.socket.bind(f"tcp://*:{port}")
logger.info(f"ZMQ publisher initialized on port {port}")
def record(
self, scheduler_stats: SchedulerStats, iteration_stats: Optional[IterationStats]
):
# Send metrics over ZMQ
metrics_data = {
"num_waiting_reqs": scheduler_stats.num_waiting_reqs,
"gpu_cache_usage": scheduler_stats.gpu_cache_usage,
}
self.socket.send_json(metrics_data)
def log_engine_initialized(self) -> None:
pass
class LoggerFactory:
"""Factory for creating stat logger publishers. Required by vLLM."""
def __init__(self, port: int) -> None:
self.port = port
def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
return MetricsPublisher(port=self.port)
class VllmWorkers:
def __init__(
self,
model: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
block_size: int = 64,
base_kv_events_port: int = 5557,
base_metrics_port: int = 5657,
num_workers: int = 1,
):
os.environ["VLLM_NO_USAGE_STATS"] = "1"
self.num_workers = num_workers
self.llms: list[AsyncLLM] = []
for worker_id in range(num_workers):
os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_id)
zmq_port = base_kv_events_port + worker_id
metrics_port = base_metrics_port + worker_id
model_config = ModelConfig(
model=model,
enforce_eager=True,
)
cache_config = CacheConfig(
block_size=block_size,
enable_prefix_caching=True,
)
kv_events_config = KVEventsConfig(
enable_kv_cache_events=True,
publisher="zmq",
endpoint=f"tcp://*:{zmq_port}",
)
scheduler_config = SchedulerConfig(
scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler"
)
vllm_config = VllmConfig(
model_config=model_config,
cache_config=cache_config,
kv_events_config=kv_events_config,
scheduler_config=scheduler_config,
)
self.llms.append(
AsyncLLM.from_vllm_config(
vllm_config=vllm_config,
stat_loggers=[LoggerFactory(port=metrics_port)],
)
)
async def direct(
self, prompt: TokensPrompt, worker_id: int, sampling_params: SamplingParams
) -> AsyncGenerator[RequestOutput, None]:
outputs = self.llms[worker_id].generate(
prompt,
sampling_params=sampling_params,
request_id=str(uuid.uuid4()),
)
async for output in outputs:
yield output
......@@ -1174,6 +1174,7 @@ dependencies = [
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tokio-util",
"tracing",
"tracing-subscriber",
]
......
......@@ -52,6 +52,7 @@ serde_json = { version = "1.0.138" }
thiserror = { version = "2.0" }
tokio = { version = "1", features = ["full"] }
tokio-stream = { version = "0" }
tokio-util = { version = "0.7" }
tracing = { version = "0" }
tracing-subscriber = { version = "0.3", features = ["fmt", "env-filter"] }
......
......@@ -38,6 +38,7 @@ const DEFAULT_ANNOTATED_SETTING: Option<bool> = Some(true);
#[pymodule]
fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
logging::init();
m.add_function(wrap_pyfunction!(llm::kv::compute_block_hash_for_seq_py, m)?)?;
m.add_function(wrap_pyfunction!(log_message, m)?)?;
m.add_function(wrap_pyfunction!(register_llm, m)?)?;
......@@ -61,6 +62,8 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::AggregatedMetrics>()?;
m.add_class::<llm::kv::KvMetricsAggregator>()?;
m.add_class::<llm::kv::KvEventPublisher>()?;
m.add_class::<llm::kv::RadixTree>()?;
m.add_class::<llm::kv::ZmqKvEventListener>()?;
m.add_class::<llm::kv::ZmqKvEventPublisher>()?;
m.add_class::<llm::kv::ZmqKvEventPublisherConfig>()?;
m.add_class::<llm::kv::KvRecorder>()?;
......
......@@ -17,6 +17,7 @@ use std::collections::HashMap;
use std::sync::atomic::AtomicU32;
use super::*;
use llm_rs::kv_router::indexer::compute_block_hash_for_seq;
use llm_rs::kv_router::indexer::KvIndexerInterface;
use rs::traits::events::EventSubscriber;
use tracing;
......@@ -33,6 +34,10 @@ pub(crate) struct KvRouter {
impl KvRouter {
#[new]
fn new(component: Component, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
};
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let inner =
......@@ -62,6 +67,16 @@ impl KvRouter {
}
}
#[pyfunction]
pub fn compute_block_hash_for_seq_py(tokens: Vec<u32>, kv_block_size: usize) -> PyResult<Vec<u64>> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let hashes = compute_block_hash_for_seq(&tokens, kv_block_size);
Ok(hashes.into_iter().map(|h| h.0).collect())
}
#[pyclass]
pub(crate) struct WorkerMetricsPublisher {
inner: Arc<llm_rs::kv_router::publisher::WorkerMetricsPublisher>,
......@@ -191,6 +206,75 @@ impl ZmqKvEventPublisher {
}
}
/// A ZMQ-based key-value cache event listener that operates independently
/// of the dynamo runtime or event plane infrastructure.
#[pyclass]
pub(crate) struct ZmqKvEventListener {
event_receiver: Arc<tokio::sync::Mutex<tokio::sync::mpsc::UnboundedReceiver<KvCacheEvent>>>,
shutdown_token: tokio_util::sync::CancellationToken,
}
#[pymethods]
impl ZmqKvEventListener {
#[new]
fn new(zmq_endpoint: String, zmq_topic: String, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<KvCacheEvent>();
let shutdown_token = tokio_util::sync::CancellationToken::new();
tokio::spawn(llm_rs::kv_router::publisher::start_zmq_listener(
zmq_endpoint,
zmq_topic,
tx,
shutdown_token.clone(),
kv_block_size,
));
Ok(Self {
event_receiver: Arc::new(tokio::sync::Mutex::new(rx)),
shutdown_token,
})
})
}
fn get_events<'p>(&self, py: Python<'p>) -> PyResult<Bound<'p, PyAny>> {
let receiver = self.event_receiver.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let mut rx = receiver.lock().await;
let mut events = Vec::new();
// Drain all available events
while let Ok(event) = rx.try_recv() {
events.push(event);
}
// Convert events to JSON strings
let json_events: Result<Vec<String>, _> =
events.iter().map(serde_json::to_string).collect();
match json_events {
Ok(json_strings) => Ok(json_strings),
Err(e) => Err(PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to serialize events to JSON: {}",
e
))),
}
})
}
}
// manual shutdown needed as it's not tied to the dynamo DRT
impl Drop for ZmqKvEventListener {
fn drop(&mut self) {
self.shutdown_token.cancel();
}
}
#[pyclass]
pub(crate) struct KvEventPublisher {
inner: Arc<llm_rs::kv_router::publisher::KvEventPublisher>,
......@@ -202,6 +286,10 @@ pub(crate) struct KvEventPublisher {
impl KvEventPublisher {
#[new]
fn new(component: Component, worker_id: i64, kv_block_size: usize) -> PyResult<Self> {
if kv_block_size == 0 {
return Err(to_pyerr(anyhow::anyhow!("kv_block_size cannot be 0")));
}
let inner = llm_rs::kv_router::publisher::KvEventPublisher::new(
component.inner,
worker_id,
......@@ -280,6 +368,70 @@ impl OverlapScores {
}
}
// NOTE: the user needs to guarantee that this stays single threaded in Python land
#[pyclass(unsendable)]
pub(crate) struct RadixTree {
inner: llm_rs::kv_router::indexer::RadixTree,
}
#[pymethods]
impl RadixTree {
#[new]
#[pyo3(signature = (expiration_duration_secs=None))]
fn new(expiration_duration_secs: Option<f64>) -> PyResult<Self> {
let expiration_duration = expiration_duration_secs.map(std::time::Duration::from_secs_f64);
let inner = llm_rs::kv_router::indexer::RadixTree::new_with_frequency(expiration_duration);
Ok(Self { inner })
}
#[pyo3(signature = (sequence, early_exit=false))]
fn find_matches(
&self,
_py: Python,
sequence: Vec<u64>,
early_exit: bool,
) -> PyResult<OverlapScores> {
let local_block_hashes: Vec<llm_rs::kv_router::protocols::LocalBlockHash> = sequence
.into_iter()
.map(llm_rs::kv_router::protocols::LocalBlockHash)
.collect();
let rs_overlap_scores = self.inner.find_matches(local_block_hashes, early_exit);
Ok(OverlapScores {
inner: rs_overlap_scores,
})
}
fn apply_event(
&mut self,
_py: Python,
worker_id: i64,
kv_cache_event_bytes: &[u8],
) -> PyResult<()> {
let kv_cache_event: llm_rs::kv_router::protocols::KvCacheEvent =
serde_json::from_slice(kv_cache_event_bytes).map_err(|e| {
PyErr::new::<pyo3::exceptions::PyValueError, _>(format!(
"Failed to deserialize KvCacheEvent: {}",
e
))
})?;
let router_event = llm_rs::kv_router::indexer::RouterEvent::new(worker_id, kv_cache_event);
self.inner.apply_event(router_event);
Ok(())
}
fn remove_worker(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
self.inner.remove_worker(worker_id);
Ok(())
}
fn clear_all_blocks(&mut self, _py: Python, worker_id: i64) -> PyResult<()> {
self.inner.clear_all_blocks(worker_id);
Ok(())
}
}
#[pyclass]
pub(crate) struct KvIndexer {
inner: Arc<llm_rs::kv_router::indexer::KvIndexer>,
......
......@@ -344,6 +344,19 @@ class DisaggregatedRouter:
"""
...
def compute_block_hash_for_seq_py(tokens: List[int], kv_block_size: int) -> List[int]:
"""
Compute block hashes for a sequence of tokens
Args:
tokens: List of token IDs
kv_block_size: Size of each KV cache block
Returns:
List of block hashes as integers
"""
...
class WorkerMetricsPublisher:
"""
A metrics publisher will provide metrics to the router.
......@@ -418,7 +431,89 @@ class OverlapScores:
'scores' is a map of worker id to the score which is the number of matching blocks.
"""
...
@property
def scores(self) -> Dict[int, int]:
"""
Map of worker_id to the score which is the number of matching blocks.
Returns:
Dictionary mapping worker IDs to their overlap scores
"""
...
@property
def frequencies(self) -> List[int]:
"""
List of frequencies that the blocks have been accessed.
Entries with value 0 are omitted.
Returns:
List of access frequencies for each block
"""
...
class RadixTree:
"""
A RadixTree that tracks KV cache blocks and can find prefix matches for sequences.
NOTE: This class is not thread-safe and should only be used from a single thread in Python.
"""
def __init__(self, expiration_duration_secs: Optional[float] = None) -> None:
"""
Create a new RadixTree instance.
Args:
expiration_duration_secs: Optional expiration duration in seconds for cached blocks.
If None, blocks never expire.
"""
...
def find_matches(
self, sequence: List[int], early_exit: bool = False
) -> OverlapScores:
"""
Find prefix matches for the given sequence of block hashes.
Args:
sequence: List of block hashes to find matches for
early_exit: If True, stop searching after finding the first match
Returns:
OverlapScores containing worker matching scores and frequencies
"""
...
def apply_event(self, worker_id: int, kv_cache_event_bytes: bytes) -> None:
"""
Apply a KV cache event to update the RadixTree state.
Args:
worker_id: ID of the worker that generated the event
kv_cache_event_bytes: Serialized KV cache event as bytes
Raises:
ValueError: If the event bytes cannot be deserialized
"""
...
def remove_worker(self, worker_id: int) -> None:
"""
Remove all blocks associated with a specific worker.
Args:
worker_id: ID of the worker to remove
"""
...
def clear_all_blocks(self, worker_id: int) -> None:
"""
Clear all blocks for a specific worker.
Args:
worker_id: ID of the worker whose blocks should be cleared
"""
...
class KvIndexer:
"""
......@@ -919,3 +1014,34 @@ class BlockManager:
List of allocated blocks
"""
...
class ZmqKvEventListener:
"""
A ZMQ-based key-value cache event listener that operates independently
of the dynamo runtime or event plane infrastructure.
"""
def __init__(
self, zmq_endpoint: str, zmq_topic: str, kv_block_size: int
) -> None:
"""
Create a new ZmqKvEventListener instance.
Args:
zmq_endpoint: ZeroMQ endpoint to connect to (e.g., "tcp://127.0.0.1:5557")
zmq_topic: ZeroMQ topic to subscribe to
kv_block_size: Size of KV cache blocks
"""
...
async def get_events(self) -> List[str]:
"""
Get all available KV cache events from the ZMQ listener.
Returns:
List of JSON-serialized KV cache events as strings
Raises:
ValueError: If events cannot be serialized to JSON
"""
...
......@@ -33,9 +33,12 @@ from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import ModelType as ModelType
from dynamo._core import OverlapScores as OverlapScores
from dynamo._core import RadixTree as RadixTree
from dynamo._core import WorkerMetricsPublisher as WorkerMetricsPublisher
from dynamo._core import ZmqKvEventListener as ZmqKvEventListener
from dynamo._core import ZmqKvEventPublisher as ZmqKvEventPublisher
from dynamo._core import ZmqKvEventPublisherConfig as ZmqKvEventPublisherConfig
from dynamo._core import compute_block_hash_for_seq_py as compute_block_hash_for_seq_py
from dynamo._core import register_llm as register_llm
try:
......
......@@ -28,6 +28,7 @@ from dynamo.llm import (
KvEventPublisher,
KvIndexer,
KvMetricsAggregator,
RadixTree,
WorkerMetricsPublisher,
)
from dynamo.runtime import Component, DistributedRuntime
......@@ -59,6 +60,56 @@ async def distributed_runtime():
return DistributedRuntime(loop, False)
async def test_radix_tree_binding(distributed_runtime):
"""Test RadixTree binding directly with store event and find matches"""
import json
# Create RadixTree instance
radix_tree = RadixTree()
# Create a store event with parent_hash=None, block_hash=0
# Following the KvCacheEvent format from the Rust protocols
store_event = {
"event_id": 1,
"data": {
"stored": {
"parent_hash": None,
"blocks": [
{
"block_hash": 0,
"tokens_hash": 0, # Using 0 for both hashes to match tokens [0]
}
],
}
},
}
# Convert to JSON bytes
event_bytes = json.dumps(store_event).encode("utf-8")
# Apply the event to worker_id 0
worker_id = 0
radix_tree.apply_event(worker_id, event_bytes)
# Find matches for tokens [0]
# The sequence parameter expects token hashes, so we use [0] to match tokens_hash=0
overlap_scores = radix_tree.find_matches([0])
# Verify the results
assert overlap_scores.scores is not None
assert (
len(overlap_scores.scores) == 1
), f"Expected 1 worker in scores, got {len(overlap_scores.scores)}"
assert worker_id in overlap_scores.scores, f"Worker {worker_id} not found in scores"
assert (
overlap_scores.scores[worker_id] == 1
), f"Expected score 1 for worker {worker_id}, got {overlap_scores.scores[worker_id]}"
print(
f"✓ RadixTree test passed: worker {worker_id} has score {overlap_scores.scores[worker_id]}"
)
# TODO Figure out how to test with different kv_block_size
# Right now I get an error in EventPublisher init when I run this test
# back to back. It occurs when calling dynamo_llm_init and I think is related to the
......
......@@ -218,7 +218,7 @@ fn calculate_backoff_ms(consecutive_errors: u32) -> u64 {
)
}
async fn start_zmq_listener(
pub async fn start_zmq_listener(
zmq_endpoint: String,
zmq_topic: String,
tx: mpsc::UnboundedSender<KvCacheEvent>,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment