Commit 031590fc authored by Yan Ru Pei, committed by GitHub

feat: vllm prefill router (#3155)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent 51b4cd7e
......@@ -66,6 +66,24 @@ First, start the vLLM worker engines in a terminal.
--tensor-parallel-size 2
```
#### Prefill Workers
You can also launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks:
```bash
# Launch 4 decode workers (GPUs 0-3)
./run_engines.sh \
--num-workers 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Launch 4 prefill workers (GPUs 4-7)
./run_engines.sh \
--prefills \
--num-workers 4 \
--base-gpu-offset 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
```
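As a rough illustration of how `--num-workers`, `--tensor-parallel-size`, and `--base-gpu-offset` map workers onto GPUs, the sketch below mirrors the start/end GPU arithmetic in `run_engines.sh`. The function name `gpu_ranges` is hypothetical and not part of the repository, and joining the indices with commas is an assumption about how the script builds `CUDA_VISIBLE_DEVICES`.

```python
def gpu_ranges(num_workers: int, tensor_parallel_size: int = 1, base_gpu_offset: int = 0):
    """Return one GPU index string per worker (hypothetical helper, not in the repo)."""
    ranges = []
    for i in range(1, num_workers + 1):
        # Same arithmetic as the script: offset + (i - 1) * tensor-parallel size
        start = base_gpu_offset + (i - 1) * tensor_parallel_size
        end = start + tensor_parallel_size - 1
        # Assumption: device indices are joined with commas
        ranges.append(",".join(str(gpu) for gpu in range(start, end + 1)))
    return ranges


decode = gpu_ranges(4)                      # decode workers, no offset
prefill = gpu_ranges(4, base_gpu_offset=4)  # prefill workers start at GPU 4
print(decode, prefill)
```

With the two commands above, the decode workers would land on GPUs 0-3 and the prefill workers on GPUs 4-7, so the two groups do not overlap.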
#### Alternative: Launch vLLM Mock Workers
We also support running lightweight mock engines that simulate vLLM behavior without performing actual model inference. Mocker engines are useful for testing router logic and performance without GPU requirements. Use the `--mockers` flag to run mocker engines instead of real vLLM workers.
......@@ -106,6 +124,27 @@ python -m dynamo.frontend --help
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
#### Launching a Prefill Router (Optional)
If you're using disaggregated serving with separate prefill and decode workers, you should also launch a prefill router. The prefill router handles routing prefill requests to dedicated prefill workers. When using a prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the prefill router):
```bash
# Start the decode router with pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--kv-cache-block-size 64 \
--router-reset-states \
--http-port 8000 \
--kv-overlap-score-weight 0
# In another terminal, start the prefill router (currently only supports vLLM)
python -m dynamo.vllm_prefill_router \
--namespace dynamo \
--block-size 64
```
The prefill router will automatically coordinate with the decode router to handle request routing between prefill and decode workers.
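Based on the decode handler changes in this commit, each decode worker appears to ask the prefill router for a worker, send the prefill request directly to that worker, fall back to round-robin when no worker ID comes back, and always free the router state afterwards. The following is a minimal, self-contained sketch of that control flow. `StubPrefillRouter` and `route_prefill` are hypothetical stand-ins written for illustration; they are not the Dynamo client API, and the stub's worker selection is deliberately trivial.

```python
import asyncio
from itertools import cycle


class StubPrefillRouter:
    """Hypothetical stand-in for the router's find_best_worker and free endpoints."""

    def __init__(self, worker_ids):
        self.worker_ids = list(worker_ids)
        self.active = set()  # request IDs the router is currently tracking

    async def find_best_worker(self, request):
        # A real router scores workers by KV cache overlap; this stub picks a
        # worker from the prompt length so the example stays deterministic.
        self.active.add(request["request_id"])
        index = len(request["token_ids"]) % len(self.worker_ids)
        return {"worker_id": self.worker_ids[index]}

    async def free(self, request):
        self.active.discard(request["request_id"])


async def route_prefill(router, round_robin, request):
    """Ask the router for a worker, fall back to round-robin, always free state."""
    used_router = False
    try:
        worker_id = None
        if router is not None and router.worker_ids:
            used_router = True
            reply = await router.find_best_worker(request)
            worker_id = reply.get("worker_id")
        if worker_id is None:
            worker_id = next(round_robin)  # no router or no answer: round-robin
        return worker_id
    finally:
        if used_router:
            await router.free({"request_id": request["request_id"]})


async def main():
    router = StubPrefillRouter(worker_ids=[10, 11])
    fallback = cycle([10, 11])
    request = {"request_id": "req-1", "token_ids": [1, 2, 3]}
    routed = await route_prefill(router, fallback, request)
    unrouted = await route_prefill(None, fallback, request)
    return routed, unrouted, router.active


routed, unrouted, active = asyncio.run(main())
print(routed, unrouted, active)
```

Placing the `free` call in a `finally` block means the router's per-request state is released even if the prefill call raises, which is the same shape the handler uses.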
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
```bash
......
......@@ -3,8 +3,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
# Get port from first argument, default to 8080 if not provided
PORT=${1:-8080}
# Get port from first argument, default to 8000 if not provided
PORT=${1:-8000}
curl -X POST http://localhost:${PORT}/v1/chat/completions \
-H "Content-Type: application/json" \
......
......@@ -309,7 +309,7 @@ def main():
"--url",
type=str,
nargs="+", # Accept multiple URLs
default=["http://localhost:8080"],
default=["http://localhost:8000"],
# default=["http://localhost:8090", "http://localhost:8090"],
help="Server URL(s). Can specify multiple URLs for parallel benchmarking",
)
......
......@@ -118,7 +118,7 @@ def main():
parser.add_argument(
"--url",
type=str,
default="http://localhost:8080",
default="http://localhost:8000",
help="Server URL",
)
parser.add_argument(
......
......@@ -8,6 +8,8 @@ NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1
USE_MOCKERS=false
USE_PREFILLS=false
BASE_GPU_OFFSET=0
EXTRA_ARGS=()
# Parse arguments
......@@ -29,6 +31,14 @@ while [[ $# -gt 0 ]]; do
USE_MOCKERS=true
shift
;;
--prefills)
USE_PREFILLS=true
shift
;;
--base-gpu-offset)
BASE_GPU_OFFSET="$2"
shift 2
;;
--)
shift
EXTRA_ARGS+=("$@")
......@@ -71,14 +81,22 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1
fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1
fi
# Calculate total GPUs needed
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:"
echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")"
echo " Worker Type: $([ "$USE_PREFILLS" = true ] && echo "Prefill" || echo "Decode")"
echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}"
echo ""
......@@ -93,14 +111,15 @@ cleanup() {
trap cleanup SIGINT SIGTERM
echo "Starting $NUM_WORKERS workers..."
WORKER_TYPE=$([ "$USE_PREFILLS" = true ] && echo "prefill" || echo "decode")
echo "Starting $NUM_WORKERS $WORKER_TYPE workers..."
for i in $(seq 1 $NUM_WORKERS); do
{
echo "[Worker-$i] Starting..."
echo "[${WORKER_TYPE^} Worker-$i] Starting..."
# Calculate GPU indices for this worker
START_GPU=$(( (i - 1) * TENSOR_PARALLEL_SIZE ))
# Calculate GPU indices for this worker (with base offset)
START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 ))
# Build CUDA_VISIBLE_DEVICES string
......@@ -124,17 +143,22 @@ for i in $(seq 1 $NUM_WORKERS); do
--endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}"
else
echo "[Worker-$i] Using GPUs: $GPU_DEVICES"
echo "[${WORKER_TYPE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$USE_PREFILLS" = true ]; then
VLLM_ARGS+=("--is-prefill-worker")
fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \
--model "$MODEL_PATH" \
--endpoint dyn://test.vllm.generate \
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
"${EXTRA_ARGS[@]}"
"${VLLM_ARGS[@]}"
fi
} &
PIDS+=($!)
echo "Started worker $i (PID: $!)"
echo "Started $WORKER_TYPE worker $i (PID: $!)"
done
echo "All workers started. Press Ctrl+C to stop."
......
......@@ -4,11 +4,29 @@
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress
python -m dynamo.frontend --router-mode kv --http-port=8000 &
# Set deterministic hash for KV event IDs
export PYTHONHASHSEED=0
# Common configuration
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64
# run frontend + KV router
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--router-reset-states &
# run workers
# --enforce-eager is added for quick deployment; remove this flag for production use
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none &
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--connector none &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--connector none
......@@ -2,19 +2,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
set -e
trap 'echo Cleaning up...; kill 0' EXIT
# run ingress
python -m dynamo.frontend --router-mode kv --http-port=8000 &
# Set deterministic hash for KV event IDs
export PYTHONHASHSEED=0
# Common configuration
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64
# run decode router with kv-overlap-score-weight 0 for pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--kv-overlap-score-weight 0 \
--router-reset-states &
# routing will happen between the two decode workers
# run prefill router service
python -m dynamo.vllm_prefill_router \
--namespace dynamo \
--block-size $BLOCK_SIZE &
# two decode workers
# --enforce-eager is added for quick deployment; remove this flag for production use
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager &
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager &
# two prefill workers
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--is-prefill-worker &
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--is-prefill-worker
......@@ -94,9 +94,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine,
default_sampling_params,
prefill_worker_client=None,
prefill_router_client=None,
prefill_router_free_client=None,
):
super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client
self.prefill_router_free_client = prefill_router_free_client
self.can_prefill = 0
self._prefill_check_task = None
......@@ -143,7 +147,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# TODO Change to prefill queue
# TODO: Change to prefill queue
# TODO: (PeaBrane) eventually, do not use a router_client and a free_client directly.
# This is least intrusive for now, but quite error prone. Should consider (major) refactoring
# TODO: (PeaBrane) longer term, decode workers should not handle prefill routing at all.
# Prefill routing logic should be integrated directly into the frontend service potentially.
if self.can_prefill:
# Create a copy for prefill with specific modifications
prefill_sampling_params = deepcopy(sampling_params)
......@@ -162,12 +170,37 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"request_id": request_id,
}
used_prefill_router = False
try:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
prefill_worker_id = None
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
used_prefill_router = True
best_worker_response = await anext(
await self.prefill_router_client.generate(
{
"token_ids": request["token_ids"],
"request_id": request_id,
}
)
)
)
prefill_worker_id = best_worker_response.data().get("worker_id")
if prefill_worker_id is not None:
prefill_response = await anext(
await self.prefill_worker_client.direct(
prefill_request, prefill_worker_id, context=context
)
)
else:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
except Exception as e:
# TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed():
......@@ -176,6 +209,15 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return
raise e
finally:
if used_prefill_router:
await anext(
await self.prefill_router_free_client.generate(
{"request_id": request_id}
)
)
logger.debug(f"Freed router state for request {request_id}")
prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data()
)
......
......@@ -5,6 +5,7 @@ import asyncio
import logging
import os
import signal
from typing import Optional
import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher
......@@ -87,6 +88,40 @@ async def worker(runtime: DistributedRuntime):
logger.debug("Worker function completed, exiting...")
def setup_kv_event_publisher(
config: Config,
component,
generate_endpoint,
vllm_config,
) -> Optional[ZmqKvEventPublisher]:
"""
Set up KV event publisher for prefix caching if enabled.
Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise.
"""
if not config.engine_args.enable_prefix_caching:
return None
# TODO: We start off with a valid endpoint, then we increment it by dp_rank,
# so it may no longer be valid. Let's remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Worker reading KV events from {zmq_endpoint}")
return kv_publisher
def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
......@@ -137,9 +172,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
engine_client, _, default_sampling_params = setup_vllm_engine(config)
# TODO register_prefill in similar vein to register_llm
engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params
......@@ -184,6 +217,20 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
prefill_router_client = (
await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode
.endpoint("find_best_worker")
.client()
)
prefill_router_free_client = (
await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode
.endpoint("free")
.client()
)
prefill_worker_client = (
await runtime.namespace(config.namespace)
.component("prefill") # TODO don't hardcode
......@@ -213,25 +260,15 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_client,
default_sampling_params,
prefill_worker_client,
prefill_router_client,
prefill_router_free_client,
)
if config.engine_args.enable_prefix_caching:
# TODO: We start off with a valid endpoint, then we increment it by dp_rank,
# so it may no longer be valid. Let's remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
# Set up KV event publisher for prefix caching if enabled
kv_publisher = setup_kv_event_publisher(
config, component, generate_endpoint, vllm_config
)
if kv_publisher:
handler.kv_publisher = kv_publisher
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
try:
from ._version import __version__
except Exception:
try:
from importlib.metadata import version as _pkg_version
__version__ = _pkg_version("ai-dynamo")
except Exception:
__version__ = "0.0.0+unknown"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Centralized Prefill Router Service
Usage: python -m dynamo.vllm_prefill_router [args]
This service provides a single KV-aware router for all prefill workers in a
disaggregated vLLM deployment. Instead of each decode worker maintaining its own
round-robin client to prefill workers, this service uses KvRouter to make
intelligent routing decisions based on KV cache state.
"""
import argparse
import asyncio
import logging
import os
from typing import Optional
import uvloop
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class PrefillRouterHandler:
"""Handles routing requests to prefill workers using KV-aware routing."""
def __init__(self, runtime: DistributedRuntime, namespace: str, block_size: int):
self.runtime = runtime
self.namespace = namespace
self.block_size = block_size
self.kv_router: Optional[KvRouter] = None
self.prefill_client: Optional[Client] = None
async def initialize(self):
"""Initialize the KV router for prefill workers."""
try:
# Get prefill endpoint
prefill_endpoint = (
self.runtime.namespace(self.namespace)
.component("prefill")
.endpoint("generate")
)
self.prefill_client = await prefill_endpoint.client()
# Create KvRouter with specified configuration
kv_router_config = KvRouterConfig(
router_track_active_blocks=False, # this won't matter for prefill workers
router_reset_states=True, # reset for now
)
self.kv_router = KvRouter(
endpoint=prefill_endpoint,
block_size=self.block_size,
kv_router_config=kv_router_config,
)
logger.info(
f"KvRouter initialized for prefill workers with block_size={self.block_size}"
)
except Exception as e:
logger.error(f"Failed to initialize KvRouter: {e}")
raise
async def find_best_worker(self, request):
"""
Find the best prefill worker based on KV cache state.
This endpoint is called by decode workers to determine which prefill
worker should handle a request.
"""
if self.kv_router is None:
# Fallback to round-robin if router not initialized
logger.warning("KvRouter not initialized, falling back to round-robin")
yield {
"status": "fallback",
"message": "Router not initialized",
}
return
try:
# Get current prefill workers
if self.prefill_client is None:
yield {
"status": "error",
"message": "Prefill client not initialized",
}
return
instance_ids = self.prefill_client.instance_ids()
if not instance_ids:
yield {
"status": "error",
"message": "No prefill workers available",
}
return
logger.debug(f"Routing request with {len(instance_ids)} available workers")
# Validate required fields
if "token_ids" not in request:
raise ValueError("Missing required field 'token_ids' in request")
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
token_ids = request["token_ids"]
request_id = request["request_id"]
# Use KvRouter to find the best worker with state updates
best_worker_id, overlap_blocks = await self.kv_router.find_best_match(
request_id=request_id,
tokens=token_ids,
update_states=True, # Always update states for prefill routing
)
logger.debug(
f"Selected worker {best_worker_id} with {overlap_blocks} overlap blocks for request {request_id}"
)
yield {
"worker_id": best_worker_id,
"overlap_blocks": overlap_blocks,
}
except Exception as e:
logger.error(f"Error finding best worker: {e}")
yield {
"status": "error",
"message": str(e),
}
async def free(self, request):
"""
Free resources associated with a request.
This endpoint is called when a request is completed to clean up
router state.
"""
if self.kv_router is None:
logger.warning("KvRouter not initialized")
yield {
"status": "error",
"message": "Router not initialized",
}
return
try:
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
request_id = request["request_id"]
# Free the request from the router
await self.kv_router.free(request_id=request_id)
logger.debug(f"Freed resources for request {request_id}")
yield {
"status": "success",
"message": f"Request {request_id} freed successfully",
}
except Exception as e:
logger.error(f"Error freeing request: {e}")
yield {
"status": "error",
"message": str(e),
}
def parse_args():
parser = argparse.ArgumentParser(
description="Dynamo Prefill Router Service: Centralized KV-aware routing for prefill workers",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--namespace",
type=str,
default=os.environ.get("DYN_NAMESPACE", "dynamo"),
help="Dynamo namespace for discovering prefill workers (default: dynamo or DYN_NAMESPACE env var)",
)
parser.add_argument(
"--block-size",
type=int,
default=128,
help="KV cache block size for routing decisions (default: 128)",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level (default: INFO)",
)
return parser.parse_args()
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
"""Main worker function for the prefill router service."""
args = parse_args()
# Set logging level
logging.getLogger().setLevel(getattr(logging, args.log_level))
logger.info(f"Starting Prefill Router Service for namespace: {args.namespace}")
logger.debug(f"Configuration: block_size={args.block_size}")
# Create service component
component = runtime.namespace(args.namespace).component("prefill_router")
await component.create_service()
# Create handler
handler = PrefillRouterHandler(runtime, args.namespace, args.block_size)
await handler.initialize()
# Expose endpoints
find_best_worker_endpoint = component.endpoint("find_best_worker")
free_endpoint = component.endpoint("free")
logger.debug("Starting to serve find_best_worker and free endpoints...")
try:
await asyncio.gather(
find_best_worker_endpoint.serve_endpoint(
handler.find_best_worker,
graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")],
),
free_endpoint.serve_endpoint(
handler.free,
graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")],
),
)
except Exception as e:
logger.error(f"Failed to serve endpoint: {e}")
raise
finally:
logger.info("Prefill Router Service shutting down")
def main():
"""Entry point for the prefill router service."""
uvloop.run(worker())
if __name__ == "__main__":
main()
......@@ -69,12 +69,6 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 1.0
#[arg(long)]
......@@ -236,7 +230,6 @@ impl Flags {
self.use_kv_events,
self.router_replica_sync,
self.router_track_active_blocks,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run)
None,
None,
......
......@@ -116,6 +116,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?;
m.add_class::<llm::kv::SpecDecodeStats>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?;
......
......@@ -61,7 +61,6 @@ impl KvRouterConfig {
router_track_active_blocks,
router_snapshot_threshold,
router_reset_states,
..Default::default()
},
}
}
......
......@@ -14,6 +14,7 @@ use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
use llm_rs::kv_router::protocols::KvStats as RsKvStats;
use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats;
use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats;
use rs::pipeline::{AsyncEngine, SingleIn};
use rs::traits::events::EventSubscriber;
use tracing;
......@@ -832,10 +833,185 @@ impl SpecDecodeStats {
}
}
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config=None, consumer_uuid=None))]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<&super::entrypoint::KvRouterConfig>,
consumer_uuid: Option<String>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
// Get component from endpoint
let component = endpoint.inner.component();
// Verify we're not in static mode
if component.drt().primary_lease().is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
));
}
// Create KvRouter with provided or generated consumer UUID
let consumer_uuid = consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(),
block_size as u32,
None, // default selector
kv_router_config.map(|c| c.inner()),
consumer_uuid,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(kv_router),
})
})
}
#[pyo3(signature = (request_id, tokens, update_states=false, router_config_override=None))]
fn find_best_match<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
update_states: bool,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(
Some(&request_id),
&tokens,
router_config_override.as_ref(),
update_states,
)
.await
.map_err(to_pyerr)?;
Ok((worker_id, overlap_blocks))
})
}
fn add_request<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
overlap_blocks: u32,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.add_request(request_id, &tokens, overlap_blocks, worker_id)
.await;
Ok(())
})
}
fn mark_prefill_completed<'p>(
&self,
py: Python<'p>,
request_id: String,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.mark_prefill_completed(&request_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.free(&request_id).await.map_err(to_pyerr)?;
Ok(())
})
}
#[getter]
fn block_size(&self) -> PyResult<u32> {
Ok(self.inner.block_size())
}
}
#[pyclass]
pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>,
primary_token: tokio_util::sync::CancellationToken,
}
// TODO: can this reuse the stream conversion method in Client bindings?
impl KvPushRouter {
/// Helper method to process a request and create a Python async generator
fn process_request_to_stream<'p>(
py: Python<'p>,
inner: Arc<llm_rs::kv_router::KvPushRouter>,
request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
) -> PyResult<Bound<'p, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
.map_err(|e| e.to_string())
});
match py_response {
Ok(obj) => {
if tx.send(obj).await.is_err() {
break; // Receiver dropped
}
}
Err(e) => {
tracing::error!("Failed to pythonize response: {}", e);
break;
}
}
}
});
// Return a Python async generator wrapper
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
})
}
}
#[pymethods]
......@@ -866,16 +1042,12 @@ impl KvPushRouter {
// Get component from endpoint
let component = endpoint.inner.component();
// Get the primary token from the component's primary lease
let primary_token = component
.drt()
.primary_lease()
.ok_or_else(|| {
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
)
})?
.primary_token();
// Verify we're not in static mode
if component.drt().primary_lease().is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
));
}
// Create KvRouter with a unique consumer UUID
let consumer_uuid = uuid::Uuid::new_v4().to_string();
......@@ -895,7 +1067,6 @@ impl KvPushRouter {
Ok(Self {
inner: Arc::new(kv_push_router),
primary_token,
})
})
}
......@@ -967,54 +1138,27 @@ impl KvPushRouter {
let request = request_builder.build().map_err(to_pyerr)?;
let inner = self.inner.clone();
// Create a Python async generator that wraps the Rust stream
pyo3_async_runtimes::tokio::future_into_py(py, async move {
use rs::pipeline::{AsyncEngine, SingleIn};
use tokio_stream::StreamExt;
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
.map_err(|e| e.to_string())
});
// Use the helper method to process the request
Self::process_request_to_stream(py, self.inner.clone(), request)
}
match py_response {
Ok(obj) => {
if tx.send(obj).await.is_err() {
break; // Receiver dropped
}
}
Err(e) => {
tracing::error!("Failed to pythonize response: {}", e);
break;
}
}
}
});
fn generate_from_request<'p>(
&self,
py: Python<'p>,
request: PyObject,
) -> PyResult<Bound<'p, PyAny>> {
// Depythonize the request directly into PreprocessedRequest
let request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
Python::with_gil(|py| depythonize(request.bind(py)).map_err(to_pyerr))?;
// Return a Python async generator wrapper
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
})
// Use the helper method to process the request
Self::process_request_to_stream(py, self.inner.clone(), request)
}
#[pyo3(signature = (context_id, token_ids, router_config_override=None))]
#[pyo3(signature = (token_ids, router_config_override=None))]
fn best_worker_id<'p>(
&self,
py: Python<'p>,
context_id: String,
token_ids: Vec<u32>,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
......@@ -1032,7 +1176,7 @@ impl KvPushRouter {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(&context_id, &token_ids, router_config_override.as_ref())
.find_best_match(&token_ids, router_config_override.as_ref())
.await
.map_err(to_pyerr)?;
......@@ -1076,13 +1220,6 @@ impl KvPushRouter {
}
}
impl Drop for KvPushRouter {
fn drop(&mut self) {
// Cancel the primary token to shut down background tasks
self.primary_token.cancel();
}
}
// Python async generator wrapper for the stream
#[pyclass]
pub(crate) struct KvPushRouterStream {
......
......@@ -1154,6 +1154,103 @@ class ZmqKvEventListener:
"""
...
class KvRouter:
"""
A KV Router that decides which worker to use based on KV cache overlap.
This router tracks request states and manages KV cache distribution across workers.
"""
def __init__(
self,
endpoint: Endpoint,
block_size: int,
kv_router_config: Optional[KvRouterConfig] = None,
consumer_uuid: Optional[str] = None,
) -> None:
"""
Create a new KvRouter instance.
Args:
endpoint: The endpoint to associate with this router
block_size: The KV cache block size
kv_router_config: Optional configuration for the KV router
consumer_uuid: Optional unique identifier for this router instance.
If not provided, a UUID will be generated.
"""
...
async def find_best_match(
self,
request_id: str,
tokens: List[int],
*,
update_states: bool = False,
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens.
Args:
request_id: Unique identifier for the request used for tracking
tokens: List of token IDs to find matches for
update_states: Whether to update router states for this request (default: False)
router_config_override: Optional router configuration override with fields:
- overlap_score_weight: Optional weight for overlap score
- router_temperature: Optional temperature for worker selection
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def add_request(
self,
request_id: str,
tokens: List[int],
overlap_blocks: int,
worker_id: int,
) -> None:
"""
Add a request to the router's tracking system.
Args:
request_id: Unique identifier for the request
tokens: List of token IDs for the request
overlap_blocks: Number of overlapping blocks found
worker_id: ID of the worker handling this request
"""
...
async def mark_prefill_completed(self, request_id: str) -> None:
"""
Mark that prefill has been completed for a request.
Args:
request_id: The request ID to mark as prefill completed
"""
...
async def free(self, request_id: str) -> None:
"""
Free resources associated with a request.
Args:
request_id: The request ID to free
"""
...
@property
def block_size(self) -> int:
"""
Get the KV cache block size.
Returns:
The block size in tokens
"""
...
class KvPushRouter:
"""
A KV-aware push router that performs intelligent routing based on KV cache overlap.
......@@ -1211,7 +1308,6 @@ class KvPushRouter:
async def best_worker_id(
self,
context_id: str,
token_ids: List[int],
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
......@@ -1219,7 +1315,6 @@ class KvPushRouter:
Find the best matching worker for the given tokens without updating states.
Args:
context_id: String identifier for the request
token_ids: List of token IDs to find matches for
router_config_override: Optional router configuration override
......
......@@ -25,7 +25,9 @@ from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput
......
......@@ -104,10 +104,6 @@ pub struct KvRouterConfig {
/// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool,
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
/// Threshold for triggering snapshots. If None, no snapshots will be performed.
pub router_snapshot_threshold: Option<u32>,
......@@ -123,7 +119,6 @@ impl Default for KvRouterConfig {
use_kv_events: true,
router_replica_sync: false,
router_track_active_blocks: true,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000),
router_reset_states: false,
}
......@@ -140,7 +135,6 @@ impl KvRouterConfig {
use_kv_events: Option<bool>,
replica_sync: Option<bool>,
track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>,
) -> Self {
......@@ -152,8 +146,6 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
......@@ -216,6 +208,8 @@ pub struct KvRouter {
block_size: u32,
kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
}
impl KvRouter {
......@@ -314,19 +308,25 @@ impl KvRouter {
scheduler,
block_size,
kv_router_config,
cancellation_token,
})
}
/// Given these tokens, find the worker with the best match in its KV cache.
/// Returned overlap amount is in number of blocks.
/// Now also takes context_id for request tracking
async fn find_best_match(
/// Now also takes optional context_id for request tracking
pub async fn find_best_match(
&self,
context_id: &str,
context_id: Option<&str>,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
update_states: bool,
) -> anyhow::Result<(i64, u32)> {
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
}
let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
......@@ -350,7 +350,7 @@ impl KvRouter {
let best_worker_id = self
.scheduler
.schedule(
context_id.to_string(),
context_id.map(|s| s.to_string()),
isl_tokens,
maybe_seq_hashes_2,
overlap_scores.clone(),
......@@ -448,7 +448,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let response = match request {
RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self
.find_best_match(&context_id, &tokens, None, true)
.find_best_match(Some(&context_id), &tokens, None, true)
.await?;
RouterResponse::New {
......@@ -486,12 +486,11 @@ impl KvPushRouter {
/// Find the best matching worker for the given tokens without updating states
pub async fn find_best_match(
&self,
context_id: &str,
tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>,
) -> Result<(i64, u32)> {
self.chooser
.find_best_match(context_id, tokens, router_config_override, false)
.find_best_match(None, tokens, router_config_override, false)
.await
}
......@@ -554,7 +553,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Otherwise, find the best match
self.chooser
.find_best_match(
&context_id,
Some(&context_id),
&request.token_ids,
request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id
......@@ -610,3 +609,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
}
}
}
impl Drop for KvRouter {
fn drop(&mut self) {
tracing::info!("Dropping KvRouter - cancelling background tasks");
self.cancellation_token.cancel();
}
}
......@@ -56,7 +56,7 @@ pub struct SchedulingResponse {
}
pub struct SchedulingRequest {
pub request_id: String,
pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize,
pub overlaps: OverlapScores,
......@@ -248,7 +248,13 @@ impl KvScheduler {
continue;
}
let request_id = request.request_id;
let Some(request_id) = request.maybe_request_id else {
tracing::error!(
"No request_id provided to add_request to the slot tracker"
);
continue;
};
if let Err(e) = slots_clone
.add_request(
request_id.clone(),
......@@ -290,7 +296,7 @@ impl KvScheduler {
pub async fn schedule(
&self,
request_id: String,
maybe_request_id: Option<String>,
isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores,
......@@ -299,7 +305,7 @@ impl KvScheduler {
) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest {
request_id,
maybe_request_id,
token_seq,
isl_tokens,
overlaps,
......
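The `KvRouter` stub added in this change exposes `find_best_match`, `add_request`, `mark_prefill_completed`, and `free`, which suggests a query, track, release lifecycle for each request. The sketch below illustrates that call order only. `FakeKvRouter` and `route_one` are hypothetical names invented for this example, the worker selection is a placeholder least-loaded policy rather than the real KV-overlap scoring, and nothing here imports `dynamo`.

```python
import asyncio
from typing import Dict, List, Optional, Set, Tuple


class FakeKvRouter:
    """Hypothetical in-memory stand-in that mirrors the KvRouter stub's method names."""

    def __init__(self, block_size: int, worker_ids: List[int]) -> None:
        self._block_size = block_size
        self._worker_ids = worker_ids
        self.active: Dict[str, int] = {}  # request_id -> worker_id
        self.prefill_done: Set[str] = set()

    @property
    def block_size(self) -> int:
        return self._block_size

    async def find_best_match(
        self,
        request_id: str,
        tokens: List[int],
        *,
        update_states: bool = False,
        router_config_override: Optional[dict] = None,
    ) -> Tuple[int, int]:
        # Placeholder policy (an assumption): pick the worker tracking the
        # fewest requests. The real router scores KV cache overlap instead.
        load = {w: 0 for w in self._worker_ids}
        for w in self.active.values():
            load[w] += 1
        worker_id = min(self._worker_ids, key=lambda w: load[w])
        overlap_blocks = 0  # this fake keeps no KV index
        if update_states:
            self.active[request_id] = worker_id
        return worker_id, overlap_blocks

    async def add_request(
        self, request_id: str, tokens: List[int], overlap_blocks: int, worker_id: int
    ) -> None:
        self.active[request_id] = worker_id

    async def mark_prefill_completed(self, request_id: str) -> None:
        self.prefill_done.add(request_id)

    async def free(self, request_id: str) -> None:
        self.active.pop(request_id, None)
        self.prefill_done.discard(request_id)


async def route_one(router: FakeKvRouter, request_id: str, tokens: List[int]) -> int:
    # 1. Query without touching router state (update_states defaults to False).
    worker_id, overlap_blocks = await router.find_best_match(request_id, tokens)
    # 2. Register the request explicitly once a worker has been chosen.
    await router.add_request(request_id, tokens, overlap_blocks, worker_id)
    try:
        # 3. The request would be sent to the worker here; after prefill finishes:
        await router.mark_prefill_completed(request_id)
    finally:
        # 4. Always release the tracked state, even if the request fails.
        await router.free(request_id)
    return worker_id
```

Calling `asyncio.run(route_one(FakeKvRouter(64, [0, 1]), "req-1", list(range(128))))` walks one request through all four steps. Separating the query from `add_request` lets a caller inspect or override the chosen worker before any state is recorded, which appears to be the reason `update_states` defaults to `False` in the stub.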