Commit ef1078e4 authored by Yan Ru Pei, committed by GitHub

chore: nuke standalone fast api router (#5845)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 6c870810
......@@ -46,7 +46,6 @@ Platform-specific deployment guides for production environments:
- **[Amazon EKS](/examples/deployments/EKS/)** - Deploy Dynamo on Amazon Elastic Kubernetes Service
- **[Azure AKS](/examples/deployments/AKS/)** - Deploy Dynamo on Azure Kubernetes Service
- **[Amazon ECS](/examples/deployments/ECS/)** - Deploy Dynamo on Amazon Elastic Container Service
- **[Router Standalone](/examples/deployments/router_standalone/)** - Standalone router deployment patterns
- **Google GKE** - _Coming soon_
## Runtime Examples
......
<!--
SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
SPDX-License-Identifier: Apache-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
-->
# Router Standalone
A toy implementation of KvRouter that demonstrates standalone usage without dependency on the dynamo runtime, etcd control plane, or nats event plane.
## Overview
This example shows how to use KvRouter in a standalone fashion to intelligently route requests across multiple vLLM workers based on KV cache overlap and load metrics. The router maintains a view of each worker's cached blocks and routes new requests to the worker with the best combination of cache overlap and available capacity.
> [!Tip]
> The main focus should be put on `router.py` as it contains the bulk of the non-boilerplate code and core routing logic.
## How It Works
### Core Architecture
The router uses a **RadixTree** data structure (written in Rust) to efficiently track which blocks each worker has cached. When a new request arrives, the router:
1. Uses `find_matches` to calculate overlap scores (number of matching blocks) between the request and each worker's cached blocks
2. Combines this with current load metrics to select the optimal worker
3. Routes the request to the chosen worker for processing
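
The three steps above can be sketched in plain Python. This is a minimal illustration rather than the router itself: the function name `pick_worker` and its argument names are hypothetical, and the weights (overlap counted twice, minus cache usage, minus normalized queue length) mirror the `get_best_worker()` logic in the `router.py` removed by this commit.

```python
import random


def pick_worker(overlap_blocks, kv_usages, waitings, block_size, num_tokens, rng=random):
    """Return the worker id with the highest score; ties are broken at random."""
    if num_tokens <= 0:
        raise ValueError("num_tokens must be positive")
    max_waiting = max(waitings) if waitings else 0
    scores = []
    for worker_id in range(len(kv_usages)):
        # Fraction of the prompt that is already cached on this worker.
        overlap = overlap_blocks.get(worker_id, 0) * block_size / num_tokens
        # Queue length normalized by the busiest worker.
        waiting = waitings[worker_id] / max_waiting if max_waiting else 0.0
        scores.append(2 * overlap - kv_usages[worker_id] - waiting)
    best = max(scores)
    return rng.choice([w for w, s in enumerate(scores) if s == best])


# Worker 1 already holds 3 of the prompt's 4 blocks, which outweighs its higher load.
chosen = pick_worker(
    overlap_blocks={0: 1, 1: 3},
    kv_usages=[0.2, 0.5],
    waitings=[1, 2],
    block_size=64,
    num_tokens=256,
)
print(chosen)
```

Weighting overlap more heavily than either load term biases the choice toward cache reuse while still letting a heavily loaded worker lose to an idle one.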
### Event-Driven Updates
The router receives two types of events from vLLM engines:
1. **KV Events**: Emitted automatically by vLLM engines when blocks are cached/evicted
2. **Load Metrics**: KV cache usage (`kv_cache_usage`) and waiting request count (`num_waiting_reqs`), published via a custom stat logger callback
These events keep the router's view of worker state up-to-date in real-time.
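
As a rough sketch of how a load-metrics message could update the router's view, the snippet below parses one JSON message and overwrites the stored values for that worker. The class name `WorkerLoadView` is hypothetical; the two field names match what `MetricsPublisher` in `worker.py` sends. The ZMQ transport is left out so the example stays self-contained.

```python
import json
from dataclasses import dataclass, field


@dataclass
class WorkerLoadView:
    """Hypothetical container for the latest load metrics of each worker."""

    num_workers: int
    kv_usages: list = field(default_factory=list)
    waitings: list = field(default_factory=list)

    def __post_init__(self):
        self.kv_usages = [0.0] * self.num_workers
        self.waitings = [0] * self.num_workers

    def apply_metrics(self, worker_id, raw_message):
        # Each message simply overwrites the previous values for that worker.
        metrics = json.loads(raw_message)
        self.kv_usages[worker_id] = float(metrics["kv_cache_usage"])
        self.waitings[worker_id] = int(metrics["num_waiting_reqs"])


view = WorkerLoadView(num_workers=2)
view.apply_metrics(1, '{"num_waiting_reqs": 3, "kv_cache_usage": 0.42}')
print(view.kv_usages, view.waitings)
```

Because only the latest value matters, a subscriber can safely drop older messages, which is why the removed `router.py` sets `zmq.CONFLATE` on its metrics sockets.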
### Alternative: Pure Predictive Routing
While not implemented in this example, the router can also operate in a purely predictive mode, estimating the radix tree state and worker loads solely from the requests it receives, without relying on backend events. This requires simulating or mocking the backend engine's block management (e.g. eviction) and scheduling policies. This mode is not recommended: without real-time feedback from the engines, the router state may drift out of sync with the engine state. It is nevertheless a work in progress and may be supported in the future via our mocker engines.
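
To make the idea concrete, here is a toy model of predictive tracking for a single worker. Everything in it is an assumption for illustration: the class name `PredictedWorkerCache`, the fixed block capacity, and the LRU eviction policy are stand-ins, since the real engine's block manager may behave differently (which is exactly the source of the drift described above).

```python
from collections import OrderedDict


class PredictedWorkerCache:
    """Hypothetical LRU model of one worker's cached blocks."""

    def __init__(self, capacity_blocks):
        self.capacity_blocks = capacity_blocks
        self.blocks = OrderedDict()

    def overlap(self, block_hashes):
        """Count the leading blocks of a request that are predicted to be cached."""
        count = 0
        for block_hash in block_hashes:
            if block_hash not in self.blocks:
                break
            count += 1
        return count

    def record_request(self, block_hashes):
        """Assume the worker caches every block of a routed request, evicting LRU blocks."""
        for block_hash in block_hashes:
            self.blocks[block_hash] = True
            self.blocks.move_to_end(block_hash)
            while len(self.blocks) > self.capacity_blocks:
                self.blocks.popitem(last=False)


cache = PredictedWorkerCache(capacity_blocks=3)
cache.record_request([11, 12, 13])
cache.record_request([11, 14])  # block 12 is the least recently used and gets evicted
print(cache.overlap([11, 12]), cache.overlap([11, 14]))
```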
## Components
> [!Note]
> This is a standalone toy implementation created for pedagogical purposes to demonstrate the core KvRouter concepts in isolation.
> Our default dynamo router is already very efficient and uses NATS for event communication and etcd for endpoint registration.
> This example intentionally avoids these production components to provide a simpler, self-contained demonstration of the routing logic and cache overlap mechanics.
>
> The toy communication pattern is as follows:
> - **OpenAI Compatible Frontend** – FastAPI application serving OpenAI compatible HTTP API.
> - **Router** – Standalone FastAPI endpoint for best worker selection, with core routines implemented in Rust exposed via Python bindings.
> - **Workers** – Served in-process within the frontend application to reduce complexity and boilerplate, rather than as separate endpoints.
### `router.py`
- **KvRouter**: Core routing logic using RadixTree
- Subscribes to KV cache events and load metrics from workers
- Implements `get_best_worker()` to select optimal routing destination
- Runs background tasks to periodically update worker states
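
The background-task pattern can be sketched with `asyncio` alone. This is an illustrative pattern, not the removed implementation: `PollingTasks` and its counter are hypothetical stand-ins for the per-worker polling loops. The point it shows is that every created task is kept in a list so that shutdown can cancel it.

```python
import asyncio


class PollingTasks:
    """Hypothetical holder for one polling loop per worker."""

    def __init__(self, num_workers):
        self.num_workers = num_workers
        self.ticks = [0] * num_workers
        self.tasks = []

    async def _poll(self, worker_id, interval):
        while True:
            self.ticks[worker_id] += 1  # stand-in for reading events or metrics
            await asyncio.sleep(interval)

    def start(self, interval=0.01):
        # Keep a reference to every task so it can be cancelled later.
        for worker_id in range(self.num_workers):
            self.tasks.append(asyncio.create_task(self._poll(worker_id, interval)))

    async def shutdown(self):
        for task in self.tasks:
            task.cancel()
        # return_exceptions=True absorbs the CancelledError raised by each task.
        await asyncio.gather(*self.tasks, return_exceptions=True)


async def demo():
    poller = PollingTasks(num_workers=2)
    poller.start()
    await asyncio.sleep(0.05)
    await poller.shutdown()
    return poller


poller = asyncio.run(demo())
print(all(task.cancelled() for task in poller.tasks))
```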
### `worker.py`
- **VllmWorkers**: Manages multiple vLLM worker processes
- Each worker runs on a separate port with KV cache event emission enabled
- Provides `direct()` method for sending requests to specific workers
- Handles worker lifecycle and configuration
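
Each worker's ports are derived by adding its index to a base port. The helper below is a hypothetical sketch of that convention (the name `worker_endpoints` and the overlap check are not part of the example code); the default base ports match the ones used elsewhere in this example.

```python
def worker_endpoints(num_workers, base_kv_events_port=5557, base_metrics_port=5657, host="localhost"):
    """List the KV-event and metrics endpoints for each worker."""
    kv_ports = {base_kv_events_port + i for i in range(num_workers)}
    metrics_ports = {base_metrics_port + i for i in range(num_workers)}
    if kv_ports & metrics_ports:
        raise ValueError("KV event and metrics port ranges overlap")
    return [
        {
            "worker_id": worker_id,
            "kv_events": f"tcp://{host}:{base_kv_events_port + worker_id}",
            "metrics": f"tcp://{host}:{base_metrics_port + worker_id}",
        }
        for worker_id in range(num_workers)
    ]


for endpoint in worker_endpoints(num_workers=2):
    print(endpoint)
```

With the default base ports the two ranges stay disjoint for up to 100 workers.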
### `api.py`
- **ServiceAPI**: Minimal FastAPI server providing an OpenAI-compatible chat completions endpoint
- Enables in-process communication between router and workers
- Can be easily modified to use external communication (FastAPI clients, dynamo endpoints, etc.)
- Integrates with vLLM's OpenAI serving components for request preprocessing and response formatting
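
`api.py` sends block hashes to the router rather than raw tokens, since they are much smaller and only need to be deterministic across processes. The sketch below illustrates the idea with a stand-in hash. It is not `compute_block_hash_for_seq_py` and will not produce the same values; the choice of BLAKE2b and the rule that a trailing partial block is ignored are assumptions made for this illustration.

```python
import hashlib
import struct


def block_hashes(tokens, block_size):
    """Hash each full block of token ids; a trailing partial block is ignored."""
    hashes = []
    for start in range(0, len(tokens) - block_size + 1, block_size):
        block = tokens[start:start + block_size]
        # Pack token ids as little-endian unsigned 32-bit integers.
        payload = struct.pack(f"<{len(block)}I", *block)
        digest = hashlib.blake2b(payload, digest_size=8).digest()
        hashes.append(int.from_bytes(digest, "little"))
    return hashes


tokens = list(range(10))
print(len(block_hashes(tokens, block_size=4)))  # two full blocks, two leftover tokens
```

Unlike Python's built-in `hash()`, a hash from `hashlib` does not depend on `PYTHONHASHSEED`, so separate processes agree on the result.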
### `perf.sh`
- Benchmarking script using `aiperf` to test the router setup
- Configured for streaming chat completions with synthetic workloads
- Tests concurrent requests to evaluate routing performance
## Usage
1. **Install upstream vLLM**:
```bash
uv pip uninstall ai-dynamo-vllm
uv pip install vllm==0.9.0
```
*Note: This uninstalls the local vLLM patch (`ai-dynamo-vllm`) and replaces it with the standard vLLM package, pinned here to version 0.9.0.*
2. **Start the router API**:
For example:
```bash
python api.py \
--model deepseek-ai/DeepSeek-R1-Distill-Llama-8B \
--num-workers 4 \
--block-size 64 \
--base-kv-events-port 5557 \
--base-metrics-port 5657 \
--router-port 7000 \
--http-port 8000
```
3. **Ping the endpoint (optional)**:
```bash
./ping.sh
```
4. **Run performance benchmark**:
```bash
./perf.sh
```
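
For readers who prefer Python over `curl`, the ping request can be sketched with the standard library. The helper names `build_ping_request` and `ping` are hypothetical; the URL, headers, and body mirror `ping.sh`. Calling `ping()` requires the server from step 2 to be running.

```python
import json
import urllib.request


def build_ping_request(port=8000, model="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"):
    """Build the same streaming chat completion request that ping.sh sends."""
    body = {
        "model": model,
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": True,
        "max_tokens": 100,
    }
    return urllib.request.Request(
        f"http://localhost:{port}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers={"Content-Type": "application/json", "Accept": "text/event-stream"},
        method="POST",
    )


def ping(port=8000):
    """Send the request and print each streamed line (needs a running server)."""
    with urllib.request.urlopen(build_ping_request(port), timeout=30) as response:
        for line in response:
            print(line.decode("utf-8").rstrip())


request = build_ping_request()
print(request.full_url)
```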
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import logging
import uuid
from dataclasses import dataclass
from typing import Optional

import httpx
import uvicorn
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from router import RouterAPI, RouterRequest, RouterResponse
from transformers import PreTrainedTokenizerBase
from vllm.config import ModelConfig
from vllm.entrypoints.openai.protocol import (
    ChatCompletionRequest,
    ErrorResponse,
    RequestResponseMetadata,
)
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels
from vllm.inputs.data import TokensPrompt
from vllm.transformers_utils.tokenizer import get_tokenizer
from worker import VllmWorkers

from dynamo._core import compute_block_hash_for_seq_py

logger = logging.getLogger(__name__)


@dataclass(frozen=True)
class ServingParams:
    model: str
    block_size: int
    num_workers: int
    base_kv_events_port: int
    base_metrics_port: int
    router_port: int
    http_port: int


class ServiceAPI:
    def __init__(self, init_params: ServingParams):
        self.init_params = init_params
        self.app = FastAPI(title="Router API", version="0.0.1")

        # These will be initialized in start()
        self.workers: Optional[VllmWorkers] = None
        self.tokenizer: Optional[PreTrainedTokenizerBase] = None
        self.openai_serving_chat: Optional[OpenAIServingChat] = None
        self.model_config: Optional[ModelConfig] = None
        self.http_client: Optional[httpx.AsyncClient] = None

        self.setup_routes()

    def setup_routes(self):
        @self.app.post("/v1/chat/completions")
        async def chat_completions(request: ChatCompletionRequest):
            if (
                self.workers is None
                or self.tokenizer is None
                or self.openai_serving_chat is None
                or self.http_client is None
            ):
                return ErrorResponse(
                    error={
                        "message": "Service not ready",
                        "type": "service_unavailable",
                        "code": 503,
                    },
                )

            try:
                # Determine max_tokens: use max_completion_tokens first, then max_tokens, or error
                max_tokens_value = None
                if (
                    hasattr(request, "max_completion_tokens")
                    and request.max_completion_tokens is not None
                ):
                    max_tokens_value = request.max_completion_tokens
                elif hasattr(request, "max_tokens") and request.max_tokens is not None:
                    max_tokens_value = request.max_tokens
                else:
                    return ErrorResponse(
                        error={
                            "message": "Either max_tokens or max_completion_tokens must be specified",
                            "type": "invalid_request_error",
                            "code": 400,
                        },
                    )

                # Use vLLM's preprocessing to convert chat to prompt
                # In newer vLLM, _preprocess_chat returns (conversation, engine_prompts) - 2 values
                (
                    conversation,
                    engine_prompts,
                ) = await self.openai_serving_chat._preprocess_chat(
                    request,
                    self.tokenizer,
                    request.messages,
                    chat_template=request.chat_template or self.tokenizer.chat_template,
                    chat_template_content_format=self.openai_serving_chat.chat_template_content_format,
                    add_generation_prompt=request.add_generation_prompt,
                    add_special_tokens=False,
                )
                engine_prompt = engine_prompts[0]

                # Convert request to sampling parameters with our determined max_tokens
                sampling_params = request.to_sampling_params(
                    max_tokens=max_tokens_value,
                    logits_processor_pattern=None,
                    default_sampling_params={},
                )

                # Get best worker using HTTP request to router
                tokens: list[int] = engine_prompt["prompt_token_ids"]
                num_tokens = len(tokens)
                if num_tokens == 0:
                    return ErrorResponse(
                        error={
                            "message": "Input prompt is empty",
                            "type": "invalid_request_error",
                            "code": 400,
                        }
                    )

                # It is much preferred to communicate block hashes to the router instead of
                # raw text prompts or tokens, especially when over network using pydantic validation,
                # as block hashes can be orders of magnitude smaller.
                # Note that the hashing function needs to be deterministic (across processes),
                # and has to be consistent with the hashing function used to send KV Events to the Router.
                local_hashes = compute_block_hash_for_seq_py(
                    tokens, self.init_params.block_size
                )

                # Call router via HTTP
                try:
                    router_request = RouterRequest(
                        local_hashes=local_hashes, num_tokens=num_tokens
                    )
                    router_response = await self.http_client.post(
                        f"http://localhost:{self.init_params.router_port}/find_best_worker",
                        json=router_request.model_dump(),
                        timeout=1,
                    )
                    router_response.raise_for_status()
                    router_data = RouterResponse.model_validate(router_response.json())
                    best_worker_id = router_data.worker_id
                except (httpx.RequestError, httpx.HTTPStatusError) as e:
                    logger.error(f"Router request failed: {e}")
                    return ErrorResponse(
                        error={
                            "message": "Router service unavailable",
                            "type": "service_unavailable",
                            "code": 503,
                        }
                    )

                logger.info(f"Selected worker {best_worker_id} for request")

                # Generate request ID
                request_id = f"chatcmpl-{uuid.uuid4()}"
                request_metadata = RequestResponseMetadata(request_id=request_id)

                # Convert engine_prompt dict to TokensPrompt object
                tokens_prompt = TokensPrompt(prompt_token_ids=tokens)
                logger.info(f"Created TokensPrompt with {len(tokens)} tokens")

                # Get the generator from the selected worker with sampling params
                result_generator = self.workers.direct(
                    tokens_prompt, best_worker_id, sampling_params
                )
                assert request.stream

                # Use vLLM's streaming response generator
                return StreamingResponse(
                    self.openai_serving_chat.chat_completion_stream_generator(
                        request,
                        result_generator,
                        request_id,
                        self.init_params.model,
                        conversation,
                        self.tokenizer,
                        request_metadata,
                        enable_force_include_usage=False,
                    ),
                    media_type="text/event-stream",
                    headers={"Cache-Control": "no-cache", "Connection": "keep-alive"},
                )
            except Exception as e:
                logger.error(f"Error processing request: {e}")
                return ErrorResponse(
                    error={"message": str(e), "type": "internal_error", "code": 500}
                )

    async def initialize_services(self):
        """Initialize workers, HTTP client, and OpenAI serving components"""
        logger.info("Initializing VllmWorkers...")
        self.workers = VllmWorkers(
            model=self.init_params.model,
            block_size=self.init_params.block_size,
            base_kv_events_port=self.init_params.base_kv_events_port,
            base_metrics_port=self.init_params.base_metrics_port,
            num_workers=self.init_params.num_workers,
        )

        # Initialize HTTP client for router communication
        self.http_client = httpx.AsyncClient()

        logger.info("Initializing OpenAI serving components...")
        # Initialize tokenizer and model config
        self.tokenizer = get_tokenizer(self.init_params.model)

        # Create a mock model config
        self.model_config = ModelConfig(
            model=self.init_params.model,
            enforce_eager=True,
        )

        # Initialize OpenAI serving models
        base_model_paths = [
            BaseModelPath(
                name=self.init_params.model, model_path=self.init_params.model
            )
        ]
        openai_serving_models = OpenAIServingModels(
            engine_client=None,
            model_config=self.model_config,
            base_model_paths=base_model_paths,
        )

        # Initialize OpenAI serving chat
        self.openai_serving_chat = OpenAIServingChat(
            engine_client=None,
            model_config=self.model_config,
            models=openai_serving_models,
            response_role="assistant",
            request_logger=None,
            chat_template=None,
            chat_template_content_format="auto",
        )

        logger.info("Waiting 2 seconds for services to initialize...")
        await asyncio.sleep(2)
        logger.info("Services initialized successfully!")

    async def start(self):
        """Start the API server"""
        # Initialize services first
        await self.initialize_services()

        # Start the API server
        logger.info(f"Starting API server on port {self.init_params.http_port}")
        config = uvicorn.Config(
            self.app, host="0.0.0.0", port=self.init_params.http_port, log_level="info"
        )
        server = uvicorn.Server(config)
        await server.serve()

    async def shutdown(self):
        """Proper shutdown handler"""
        logger.info("Shutting down API...")
        if self.http_client:
            await self.http_client.aclose()
        logger.info("API shutdown completed")


def main():
    parser = argparse.ArgumentParser(description="Router API Server")

    # Arguments from worker.py
    parser.add_argument(
        "--model",
        type=str,
        default="deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        help="Model name to use",
    )

    # Common arguments
    parser.add_argument(
        "--block-size", type=int, default=64, help="Block size for caching"
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="Number of worker processes"
    )
    parser.add_argument(
        "--base-kv-events-port", type=int, default=5557, help="Base port for KV events"
    )
    parser.add_argument(
        "--base-metrics-port", type=int, default=5657, help="Base port for metrics"
    )
    parser.add_argument(
        "--router-port",
        type=int,
        default=7000,
        help="Port for router service",
    )
    parser.add_argument(
        "--http-port", type=int, default=8000, help="Port to serve the API on"
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO)

    init_params = ServingParams(
        model=args.model,
        block_size=args.block_size,
        num_workers=args.num_workers,
        base_kv_events_port=args.base_kv_events_port,
        base_metrics_port=args.base_metrics_port,
        router_port=args.router_port,
        http_port=args.http_port,
    )

    # Create both services
    api = ServiceAPI(init_params=init_params)
    router_api = RouterAPI(
        block_size=args.block_size,
        num_workers=args.num_workers,
        base_kv_events_port=args.base_kv_events_port,
        base_metrics_port=args.base_metrics_port,
        port=args.router_port,
    )

    async def run_with_shutdown():
        try:
            # Start both services concurrently
            await asyncio.gather(
                api.start(), router_api.start(), return_exceptions=True
            )
        except KeyboardInterrupt:
            logger.info("Received KeyboardInterrupt, shutting down services...")
        except Exception as e:
            logger.exception(f"Unhandled exception: {e}")
        finally:
            await api.shutdown()

    try:
        asyncio.run(run_with_shutdown())
    except KeyboardInterrupt:
        # Just in case KeyboardInterrupt happens outside of the event loop
        logger.info("Force shutdown via KeyboardInterrupt.")


if __name__ == "__main__":
    main()
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
model=deepseek-ai/DeepSeek-R1-Distill-Llama-8B
type=chat
endpoint=/v1/chat/completions
port=8000
isl=3000
osl=100
concurrency=25
num_requests=100
num_unique_prompts=10
seed=42
aiperf profile \
--model ${model} \
--tokenizer ${model} \
--endpoint-type ${type} \
--endpoint ${endpoint} \
--streaming \
--url http://localhost:${port} \
--synthetic-input-tokens-mean ${isl} \
--synthetic-input-tokens-stddev 0 \
--output-tokens-mean ${osl} \
--output-tokens-stddev 0 \
--extra-inputs max_tokens:${osl} \
--extra-inputs min_tokens:${osl} \
--extra-inputs ignore_eos:true \
--extra-inputs "{\"nvext\":{\"ignore_eos\":true}}" \
--concurrency ${concurrency} \
--request-count ${num_requests} \
--num-dataset-entries ${num_unique_prompts} \
--random-seed ${seed} \
-H 'Authorization: Bearer NOT USED' \
-H 'Accept: text/event-stream'
#!/bin/bash
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
curl -X POST http://localhost:8000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Accept: text/event-stream" \
-d '{
"model": "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true,
"max_tokens": 100
}'
\ No newline at end of file
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import asyncio
import json
import logging
from contextlib import asynccontextmanager
from typing import List

import numpy as np
import uvicorn
import zmq
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

from dynamo._core import RadixTree, ZmqKvEventListener

logger = logging.getLogger(__name__)


class RouterRequest(BaseModel):
    local_hashes: List[int]
    num_tokens: int


class RouterResponse(BaseModel):
    worker_id: int


class LoadMetrics(BaseModel):
    kv_cache_usage: float
    num_waiting_reqs: int


def setup_zmq_subscriber(context: zmq.Context, endpoint: str) -> zmq.Socket[bytes]:
    socket = context.socket(zmq.SUB)
    socket.connect(endpoint)
    socket.setsockopt(zmq.SUBSCRIBE, b"")  # Subscribe to all messages
    socket.setsockopt(zmq.CONFLATE, 1)  # Only keep latest message
    socket.setsockopt(zmq.RCVTIMEO, 1)  # 1ms timeout (very short)
    return socket


class KvRouter:
    def __init__(
        self,
        block_size: int = 64,
        num_workers: int = 4,
        base_kv_events_port: int = 5557,
        base_metrics_port: int = 5657,
    ):
        self.num_workers = num_workers
        self.block_size = block_size
        self.radix_tree = RadixTree()
        self.kv_usages = [0.0] * num_workers
        self.waitings = [0] * num_workers

        self.context = zmq.Context()
        self.load_listeners = [
            setup_zmq_subscriber(
                self.context, f"tcp://localhost:{base_metrics_port + worker_id}"
            )
            for worker_id in range(num_workers)
        ]
        self.kv_listeners = [
            ZmqKvEventListener(
                f"tcp://localhost:{base_kv_events_port + worker_id}", "", block_size
            )
            for worker_id in range(num_workers)
        ]

        self.background_tasks: list[asyncio.Task] = []
        logger.info("Router initialized")

    async def start_background_tasks(self):
        """Start background tasks for load and indexer updates"""
        logger.info("Starting router background tasks...")
        self.background_tasks.append(asyncio.create_task(self.periodic_update_load()))
        self.background_tasks.append(
            asyncio.create_task(self.periodic_update_indexer())
        )

    async def periodic_update_load(self):
        async def update_load(worker_id: int):
            while True:
                try:
                    metrics_dict = self.load_listeners[worker_id].recv_json(zmq.NOBLOCK)
                    metrics = LoadMetrics.model_validate(metrics_dict)
                    self.kv_usages[worker_id] = metrics.kv_cache_usage
                    self.waitings[worker_id] = metrics.num_waiting_reqs
                except zmq.Again:
                    # No new metrics message yet; keep the previous values
                    pass
                except Exception as e:
                    logger.warning(
                        f"Error receiving metrics for worker {worker_id}: {e}"
                    )
                await asyncio.sleep(0.1)

        for worker_id in range(self.num_workers):
            # Track the per-worker task so shutdown() can cancel it
            self.background_tasks.append(asyncio.create_task(update_load(worker_id)))

    async def periodic_update_indexer(self):
        async def update_tree(worker_id: int):
            while True:
                try:
                    kv_events: list[str] = await self.kv_listeners[
                        worker_id
                    ].get_events()
                    for event in kv_events:
                        event = json.loads(event)
                        self.radix_tree.apply_event(
                            worker_id, json.dumps(event).encode("utf-8")
                        )
                except zmq.Again:
                    pass
                except Exception as e:
                    logger.warning(
                        f"Error receiving KV events for worker {worker_id}: {e}"
                    )
                await asyncio.sleep(0.1)

        for worker_id in range(self.num_workers):
            # Track the per-worker task so shutdown() can cancel it
            self.background_tasks.append(asyncio.create_task(update_tree(worker_id)))

    async def get_best_worker(self, local_hashes: list[int], num_tokens: int) -> int:
        try:
            if num_tokens <= 0:
                raise ValueError("num_tokens must be positive")

            # local_hashes can be empty
            raw_scores = self.radix_tree.find_matches(local_hashes).scores
            overlap_scores = {
                worker_id: raw_scores.get(worker_id, 0) * self.block_size / num_tokens
                for worker_id in range(self.num_workers)
            }

            kv_usages = self.kv_usages[:]
            waitings = self.waitings[:]
            max_waiting = max(waitings) if waitings else 0
            waitings_normalized = [
                waiting / max_waiting if max_waiting else 0.0 for waiting in waitings
            ]

            logits = []
            for worker_id in range(self.num_workers):
                overlap = overlap_scores[worker_id]
                usage = kv_usages[worker_id]
                waiting = waitings_normalized[worker_id]
                logit = 2 * overlap - usage - waiting
                logits.append(logit)
                logger.info(
                    f"worker_id: {worker_id}, logit = 2 * {overlap:.3f} - {usage:.3f} - {waiting:.3f} = {logit:.3f}"
                )

            logits_array = np.array(logits)
            best_worker_id = int(
                np.random.choice(np.flatnonzero(logits_array == logits_array.max()))
            )

            # this is a predictive update which will be reset as new metrics are polled
            # but it is helpful for handling short bursts of highly concurrent requests
            # we omit updating the gpu_usage_perc as done in the rusty router for simplicity
            # as this requires obtaining num_gpu_blocks from the engines and can be intrusive
            # no need for async lock here, as the state is intended to be continuously overwritten
            self.waitings[best_worker_id] += 1

            return best_worker_id
        except Exception as e:
            logger.error(f"Error in get_best_worker: {e}")
            raise

    async def shutdown(self):
        """Shutdown ZMQ listeners, context, and background tasks"""
        logger.info("Shutting down KvRouter...")

        # Cancel background tasks
        for task in self.background_tasks:
            task.cancel()
        if self.background_tasks:
            await asyncio.gather(*self.background_tasks, return_exceptions=True)

        # Close load listeners (ZMQ sockets)
        for listener in self.load_listeners:
            try:
                listener.close()
            except Exception as e:
                logger.error(f"Error closing load listener: {e}")

        # Terminate ZMQ context
        try:
            self.context.term()
            logger.info("ZMQ context terminated successfully")
        except Exception as e:
            logger.error(f"Error terminating ZMQ context: {e}")

        logger.info("KvRouter shutdown completed")


class RouterAPI:
    def __init__(
        self,
        block_size: int = 64,
        num_workers: int = 4,
        base_kv_events_port: int = 5557,
        base_metrics_port: int = 5657,
        port: int = 7000,
    ):
        self.port = port
        self.block_size = block_size
        self.num_workers = num_workers
        self.base_kv_events_port = base_kv_events_port
        self.base_metrics_port = base_metrics_port
        self.router = None
        self.app = FastAPI(
            title="KV Router API", version="0.0.1", lifespan=self.lifespan
        )
        self.setup_routes()

    @asynccontextmanager
    async def lifespan(self, app: FastAPI):
        # Startup
        self.router = KvRouter(
            block_size=self.block_size,
            num_workers=self.num_workers,
            base_kv_events_port=self.base_kv_events_port,
            base_metrics_port=self.base_metrics_port,
        )
        await self.router.start_background_tasks()
        logger.info("Router API started successfully")

        yield

        # Shutdown
        if self.router:
            await self.router.shutdown()

    def setup_routes(self):
        @self.app.post("/find_best_worker", response_model=RouterResponse)
        async def find_best_worker(request: RouterRequest):
            if self.router is None:
                raise HTTPException(status_code=503, detail="Router not initialized")

            try:
                worker_id = await self.router.get_best_worker(
                    request.local_hashes, request.num_tokens
                )
                return RouterResponse(worker_id=worker_id)
            except ValueError as e:
                raise HTTPException(status_code=400, detail=str(e))
            except Exception as e:
                logger.error(f"Error finding best worker: {e}")
                raise HTTPException(status_code=500, detail="Internal server error")

    async def start(self):
        """Start the router API server"""
        logger.info(f"Starting Router API server on port {self.port}")
        config = uvicorn.Config(
            self.app, host="0.0.0.0", port=self.port, log_level="info"
        )
        server = uvicorn.Server(config)
        await server.serve()


def main():
    parser = argparse.ArgumentParser(description="KV Router API Server")
    parser.add_argument(
        "--block-size", type=int, default=64, help="Block size for caching"
    )
    parser.add_argument(
        "--num-workers", type=int, default=2, help="Number of worker processes"
    )
    parser.add_argument(
        "--base-kv-events-port", type=int, default=5557, help="Base port for KV events"
    )
    parser.add_argument(
        "--base-metrics-port", type=int, default=5657, help="Base port for metrics"
    )
    parser.add_argument(
        "--port", type=int, default=7000, help="Port to serve the Router API on"
    )
    args = parser.parse_args()

    # Setup logging
    logging.basicConfig(level=logging.INFO)

    api = RouterAPI(
        block_size=args.block_size,
        num_workers=args.num_workers,
        base_kv_events_port=args.base_kv_events_port,
        base_metrics_port=args.base_metrics_port,
        port=args.port,
    )

    async def run_with_shutdown():
        try:
            await api.start()
        except KeyboardInterrupt:
            logger.info(
                "Received KeyboardInterrupt, shutting down Router API server..."
            )
        except Exception as e:
            logger.exception(f"Unhandled exception: {e}")

    try:
        asyncio.run(run_with_shutdown())
    except KeyboardInterrupt:
        logger.info("Force shutdown via KeyboardInterrupt.")


if __name__ == "__main__":
    main()
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os

if "PYTHONHASHSEED" not in os.environ:
    os.environ["PYTHONHASHSEED"] = "0"

import logging
import uuid
from typing import AsyncGenerator, Optional

import zmq
from vllm.config import (
    CacheConfig,
    ModelConfig,
    ObservabilityConfig,
    SchedulerConfig,
    VllmConfig,
)
from vllm.distributed.kv_events import KVEventsConfig
from vllm.inputs.data import TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.async_llm import AsyncLLM
from vllm.v1.metrics.loggers import StatLoggerBase
from vllm.v1.metrics.stats import IterationStats, SchedulerStats

logger = logging.getLogger(__name__)


class MetricsPublisher(StatLoggerBase):
    """Stat logger publisher. Wrapper for the WorkerMetricsPublisher to match the StatLoggerBase interface."""

    def __init__(self, port: int) -> None:
        self.context = zmq.Context()
        self.socket = self.context.socket(zmq.PUB)
        self.socket.bind(f"tcp://*:{port}")
        logger.info(f"ZMQ publisher initialized on port {port}")

    def record(
        self,
        scheduler_stats: SchedulerStats,
        iteration_stats: Optional[IterationStats],
        engine_idx: int = 0,
    ):
        # Send metrics over ZMQ
        metrics_data = {
            "num_waiting_reqs": scheduler_stats.num_waiting_reqs,
            "kv_cache_usage": scheduler_stats.kv_cache_usage,
        }
        self.socket.send_json(metrics_data)

    def log_engine_initialized(self) -> None:
        pass


class LoggerFactory:
    """Factory for creating stat logger publishers. Required by vLLM."""

    def __init__(self, port: int) -> None:
        self.port = port

    def __call__(self, vllm_config: VllmConfig, dp_rank: int) -> StatLoggerBase:
        return MetricsPublisher(port=self.port)


class VllmWorkers:
    def __init__(
        self,
        model: str = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B",
        block_size: int = 64,
        base_kv_events_port: int = 5557,
        base_metrics_port: int = 5657,
        num_workers: int = 1,
    ):
        os.environ["VLLM_NO_USAGE_STATS"] = "1"

        self.num_workers = num_workers
        self.llms: list[AsyncLLM] = []

        for worker_id in range(num_workers):
            os.environ["CUDA_VISIBLE_DEVICES"] = str(worker_id)
            zmq_port = base_kv_events_port + worker_id
            metrics_port = base_metrics_port + worker_id

            model_config = ModelConfig(
                model=model,
                enforce_eager=True,
            )
            cache_config = CacheConfig(
                block_size=block_size,
                enable_prefix_caching=True,
            )
            kv_events_config = KVEventsConfig(
                enable_kv_cache_events=True,
                publisher="zmq",
                endpoint=f"tcp://*:{zmq_port}",
            )
            scheduler_config = SchedulerConfig(
                scheduler_cls="vllm.v1.core.sched.scheduler.Scheduler"
            )
            observability_config = ObservabilityConfig()

            vllm_config = VllmConfig(
                model_config=model_config,
                cache_config=cache_config,
                kv_events_config=kv_events_config,
                scheduler_config=scheduler_config,
                observability_config=observability_config,
            )

            self.llms.append(
                AsyncLLM.from_vllm_config(
                    vllm_config=vllm_config,
                    stat_loggers=[LoggerFactory(port=metrics_port)],
                )
            )

    async def direct(
        self, prompt: TokensPrompt, worker_id: int, sampling_params: SamplingParams
    ) -> AsyncGenerator[RequestOutput, None]:
        outputs = self.llms[worker_id].generate(
            prompt,
            sampling_params=sampling_params,
            request_id=str(uuid.uuid4()),
        )
        async for output in outputs:
            yield output
......@@ -245,5 +245,4 @@ When processing multimodal requests:
## See Also
- [vLLM Router Standalone](../router_standalone/) - Original vLLM version
- [TensorRT-LLM KV Event Documentation](https://nvidia.github.io/TensorRT-LLM/0.21.0/examples/llm_inference_kv_events.html)
......@@ -34,7 +34,6 @@ Platform-specific deployment guides for production environments:
- **[Amazon EKS](https://github.com/ai-dynamo/dynamo/blob/main/examples/deployments/EKS/)** - Deploy Dynamo on Amazon Elastic Kubernetes Service
- **[Azure AKS](https://github.com/ai-dynamo/dynamo/blob/main/examples/deployments/AKS/)** - Deploy Dynamo on Azure Kubernetes Service
- **[Amazon ECS](https://github.com/ai-dynamo/dynamo/blob/main/examples/deployments/ECS/)** - Deploy Dynamo on Amazon Elastic Container Service
- **[Router Standalone](https://github.com/ai-dynamo/dynamo/blob/main/examples/deployments/router_standalone/)** - Standalone router deployment patterns
- **Google GKE** - _Coming soon_
## Runtime Examples
......