Unverified Commit 031590fc authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: vllm prefill router (#3155)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 51b4cd7e
...@@ -66,6 +66,24 @@ First, start the vLLM worker engines in a terminal. ...@@ -66,6 +66,24 @@ First, start the vLLM worker engines in a terminal.
--tensor-parallel-size 2 --tensor-parallel-size 2
``` ```
#### Prefill Workers
You can also launch separate decode and prefill workers for disaggregated serving. This allows you to dedicate specific GPUs to prefill (prompt processing) and decode (token generation) tasks:
```bash
# Launch 4 decode workers (GPUs 0-3)
./run_engines.sh \
--num-workers 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
# Launch 4 prefill workers (GPUs 4-7)
./run_engines.sh \
--prefills \
--num-workers 4 \
--base-gpu-offset 4 \
--model-path deepseek-ai/DeepSeek-R1-Distill-Llama-8B
```
#### Alternative: Launch vLLM Mock Workers #### Alternative: Launch vLLM Mock Workers
We also supports running lightweight mock engines that simulate vLLM behavior without performing actual model inference. Mocker engines are useful for testing router logic and performance without GPU requirements. Use the `--mockers` flag to run mocker engines instead of real vLLM workers. We also supports running lightweight mock engines that simulate vLLM behavior without performing actual model inference. Mocker engines are useful for testing router logic and performance without GPU requirements. Use the `--mockers` flag to run mocker engines instead of real vLLM workers.
...@@ -106,6 +124,27 @@ python -m dynamo.frontend --help ...@@ -106,6 +124,27 @@ python -m dynamo.frontend --help
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md). For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
#### Launching a Prefill Router (Optional)
If you're using disaggregated serving with separate prefill and decode workers, you should also launch a prefill router. The prefill router handles routing prefill requests to dedicated prefill workers. When using a prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the prefill router):
```bash
# Start the decode router with pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--kv-cache-block-size 64 \
--router-reset-states \
--http-port 8000 \
--kv-overlap-score-weight 0
# In another terminal, start the prefill router (currently only supports vLLM)
python -m dynamo.vllm_prefill_router \
--namespace dynamo \
--block-size 64
```
The prefill router will automatically coordinate with the decode router to handle request routing between prefill and decode workers.
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead: **Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
```bash ```bash
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
# Get port from first argument, default to 8080 if not provided # Get port from first argument, default to 8000 if not provided
PORT=${1:-8080} PORT=${1:-8000}
curl -X POST http://localhost:${PORT}/v1/chat/completions \ curl -X POST http://localhost:${PORT}/v1/chat/completions \
-H "Content-Type: application/json" \ -H "Content-Type: application/json" \
......
...@@ -309,7 +309,7 @@ def main(): ...@@ -309,7 +309,7 @@ def main():
"--url", "--url",
type=str, type=str,
nargs="+", # Accept multiple URLs nargs="+", # Accept multiple URLs
default=["http://localhost:8080"], default=["http://localhost:8000"],
# default=["http://localhost:8090", "http://localhost:8090"], # default=["http://localhost:8090", "http://localhost:8090"],
help="Server URL(s). Can specify multiple URLs for parallel benchmarking", help="Server URL(s). Can specify multiple URLs for parallel benchmarking",
) )
......
...@@ -118,7 +118,7 @@ def main(): ...@@ -118,7 +118,7 @@ def main():
parser.add_argument( parser.add_argument(
"--url", "--url",
type=str, type=str,
default="http://localhost:8080", default="http://localhost:8000",
help="Server URL", help="Server URL",
) )
parser.add_argument( parser.add_argument(
......
...@@ -8,6 +8,8 @@ NUM_WORKERS=8 ...@@ -8,6 +8,8 @@ NUM_WORKERS=8
MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B" MODEL_PATH="deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
TENSOR_PARALLEL_SIZE=1 TENSOR_PARALLEL_SIZE=1
USE_MOCKERS=false USE_MOCKERS=false
USE_PREFILLS=false
BASE_GPU_OFFSET=0
EXTRA_ARGS=() EXTRA_ARGS=()
# Parse arguments # Parse arguments
...@@ -29,6 +31,14 @@ while [[ $# -gt 0 ]]; do ...@@ -29,6 +31,14 @@ while [[ $# -gt 0 ]]; do
USE_MOCKERS=true USE_MOCKERS=true
shift shift
;; ;;
--prefills)
USE_PREFILLS=true
shift
;;
--base-gpu-offset)
BASE_GPU_OFFSET="$2"
shift 2
;;
--) --)
shift shift
EXTRA_ARGS+=("$@") EXTRA_ARGS+=("$@")
...@@ -71,14 +81,22 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt ...@@ -71,14 +81,22 @@ if ! [[ "$TENSOR_PARALLEL_SIZE" =~ ^[0-9]+$ ]] || [ "$TENSOR_PARALLEL_SIZE" -lt
exit 1 exit 1
fi fi
if ! [[ "$BASE_GPU_OFFSET" =~ ^[0-9]+$ ]]; then
echo "Error: BASE_GPU_OFFSET must be a non-negative integer"
exit 1
fi
# Calculate total GPUs needed # Calculate total GPUs needed
TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE)) TOTAL_GPUS_NEEDED=$((NUM_WORKERS * TENSOR_PARALLEL_SIZE))
LAST_GPU=$((BASE_GPU_OFFSET + TOTAL_GPUS_NEEDED - 1))
echo "Configuration:" echo "Configuration:"
echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")" echo " Engine Type: $([ "$USE_MOCKERS" = true ] && echo "Mocker" || echo "vLLM")"
echo " Worker Type: $([ "$USE_PREFILLS" = true ] && echo "Prefill" || echo "Decode")"
echo " Workers: $NUM_WORKERS" echo " Workers: $NUM_WORKERS"
echo " Model: $MODEL_PATH" echo " Model: $MODEL_PATH"
echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE" echo " Tensor Parallel Size: $TENSOR_PARALLEL_SIZE"
echo " Total GPUs needed: $TOTAL_GPUS_NEEDED" echo " Total GPUs needed: $TOTAL_GPUS_NEEDED"
echo " GPU Range: $BASE_GPU_OFFSET-$LAST_GPU"
echo " Engine args: ${EXTRA_ARGS[*]}" echo " Engine args: ${EXTRA_ARGS[*]}"
echo "" echo ""
...@@ -93,14 +111,15 @@ cleanup() { ...@@ -93,14 +111,15 @@ cleanup() {
trap cleanup SIGINT SIGTERM trap cleanup SIGINT SIGTERM
echo "Starting $NUM_WORKERS workers..." WORKER_TYPE=$([ "$USE_PREFILLS" = true ] && echo "prefill" || echo "decode")
echo "Starting $NUM_WORKERS $WORKER_TYPE workers..."
for i in $(seq 1 $NUM_WORKERS); do for i in $(seq 1 $NUM_WORKERS); do
{ {
echo "[Worker-$i] Starting..." echo "[${WORKER_TYPE^} Worker-$i] Starting..."
# Calculate GPU indices for this worker # Calculate GPU indices for this worker (with base offset)
START_GPU=$(( (i - 1) * TENSOR_PARALLEL_SIZE )) START_GPU=$(( BASE_GPU_OFFSET + (i - 1) * TENSOR_PARALLEL_SIZE ))
END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 )) END_GPU=$(( START_GPU + TENSOR_PARALLEL_SIZE - 1 ))
# Build CUDA_VISIBLE_DEVICES string # Build CUDA_VISIBLE_DEVICES string
...@@ -124,17 +143,22 @@ for i in $(seq 1 $NUM_WORKERS); do ...@@ -124,17 +143,22 @@ for i in $(seq 1 $NUM_WORKERS); do
--endpoint dyn://test.mocker.generate \ --endpoint dyn://test.mocker.generate \
"${EXTRA_ARGS[@]}" "${EXTRA_ARGS[@]}"
else else
echo "[Worker-$i] Using GPUs: $GPU_DEVICES" echo "[${WORKER_TYPE^} Worker-$i] Using GPUs: $GPU_DEVICES"
# Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing # Run vLLM engine with PYTHONHASHSEED=0 for deterministic event IDs in KV-aware routing
VLLM_ARGS=()
VLLM_ARGS+=("--model" "$MODEL_PATH")
VLLM_ARGS+=("--tensor-parallel-size" "$TENSOR_PARALLEL_SIZE")
if [ "$USE_PREFILLS" = true ]; then
VLLM_ARGS+=("--is-prefill-worker")
fi
VLLM_ARGS+=("${EXTRA_ARGS[@]}")
exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \ exec env PYTHONHASHSEED=0 CUDA_VISIBLE_DEVICES=$GPU_DEVICES python -m dynamo.vllm \
--model "$MODEL_PATH" \ "${VLLM_ARGS[@]}"
--endpoint dyn://test.vllm.generate \
--tensor-parallel-size $TENSOR_PARALLEL_SIZE \
"${EXTRA_ARGS[@]}"
fi fi
} & } &
PIDS+=($!) PIDS+=($!)
echo "Started worker $i (PID: $!)" echo "Started $WORKER_TYPE worker $i (PID: $!)"
done done
echo "All workers started. Press Ctrl+C to stop." echo "All workers started. Press Ctrl+C to stop."
......
...@@ -4,11 +4,29 @@ ...@@ -4,11 +4,29 @@
set -e set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# run ingress # Set deterministic hash for KV event IDs
python -m dynamo.frontend --router-mode kv --http-port=8000 & export PYTHONHASHSEED=0
# Common configuration
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64
# run frontend + KV router
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--router-reset-states &
# run workers # run workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none & CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--connector none &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager --connector none CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--connector none
...@@ -2,19 +2,48 @@ ...@@ -2,19 +2,48 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
set -e set -e
trap 'echo Cleaning up...; kill 0' EXIT trap 'echo Cleaning up...; kill 0' EXIT
# run ingress # Set deterministic hash for KV event IDs
python -m dynamo.frontend --router-mode kv --http-port=8000 & export PYTHONHASHSEED=0
# Common configuration
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64
# run decode router with kv-overlap-score-weight 0 for pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--kv-overlap-score-weight 0 \
--router-reset-states &
# routing will happen between the two decode workers # run prefill router service
python -m dynamo.vllm_prefill_router \
--namespace dynamo \
--block-size $BLOCK_SIZE &
# two decode workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager & CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager &
CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm --model Qwen/Qwen3-0.6B --enforce-eager & CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager &
# two prefill workers
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \
--model Qwen/Qwen3-0.6B \ --model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \
--is-prefill-worker &
CUDA_VISIBLE_DEVICES=3 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
--enforce-eager \ --enforce-eager \
--is-prefill-worker --is-prefill-worker
...@@ -94,9 +94,13 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -94,9 +94,13 @@ class DecodeWorkerHandler(BaseWorkerHandler):
engine, engine,
default_sampling_params, default_sampling_params,
prefill_worker_client=None, prefill_worker_client=None,
prefill_router_client=None,
prefill_router_free_client=None,
): ):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client
self.prefill_router_free_client = prefill_router_free_client
self.can_prefill = 0 self.can_prefill = 0
self._prefill_check_task = None self._prefill_check_task = None
...@@ -143,7 +147,11 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -143,7 +147,11 @@ class DecodeWorkerHandler(BaseWorkerHandler):
if value is not None and hasattr(sampling_params, key): if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value) setattr(sampling_params, key, value)
# TODO Change to prefill queue # TODO: Change to prefill queue
# TODO: (PeaBrane) eventually, do not use a router_client and a free_client directly.
# This is least intrusive for now, but quite error prone. Should consider (major) refactoring
# TODO: (PeaBrane) longer term, decode workers should not handle prefill routing at all.
# Prefill routing logic should be integrated directly into the frontend service potentially.
if self.can_prefill: if self.can_prefill:
# Create a copy for prefill with specific modifications # Create a copy for prefill with specific modifications
prefill_sampling_params = deepcopy(sampling_params) prefill_sampling_params = deepcopy(sampling_params)
...@@ -162,12 +170,37 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -162,12 +170,37 @@ class DecodeWorkerHandler(BaseWorkerHandler):
"request_id": request_id, "request_id": request_id,
} }
used_prefill_router = False
try: try:
prefill_response = await anext( prefill_worker_id = None
await self.prefill_worker_client.round_robin( if (
prefill_request, context=context self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
used_prefill_router = True
best_worker_response = await anext(
await self.prefill_router_client.generate(
{
"token_ids": request["token_ids"],
"request_id": request_id,
}
)
) )
) prefill_worker_id = best_worker_response.data().get("worker_id")
if prefill_worker_id is not None:
prefill_response = await anext(
await self.prefill_worker_client.direct(
prefill_request, prefill_worker_id, context=context
)
)
else:
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
except Exception as e: except Exception as e:
# TODO: Cancellation does not propagate until the first token is received # TODO: Cancellation does not propagate until the first token is received
if context.is_stopped() or context.is_killed(): if context.is_stopped() or context.is_killed():
...@@ -176,6 +209,15 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -176,6 +209,15 @@ class DecodeWorkerHandler(BaseWorkerHandler):
return return
raise e raise e
finally:
if used_prefill_router:
await anext(
await self.prefill_router_free_client.generate(
{"request_id": request_id}
)
)
logger.debug(f"Freed router state for request {request_id}")
prefill_response = MyRequestOutput.model_validate_json( prefill_response = MyRequestOutput.model_validate_json(
prefill_response.data() prefill_response.data()
) )
......
...@@ -5,6 +5,7 @@ import asyncio ...@@ -5,6 +5,7 @@ import asyncio
import logging import logging
import os import os
import signal import signal
from typing import Optional
import uvloop import uvloop
from vllm.distributed.kv_events import ZmqEventPublisher from vllm.distributed.kv_events import ZmqEventPublisher
...@@ -87,6 +88,40 @@ async def worker(runtime: DistributedRuntime): ...@@ -87,6 +88,40 @@ async def worker(runtime: DistributedRuntime):
logger.debug("Worker function completed, exiting...") logger.debug("Worker function completed, exiting...")
def setup_kv_event_publisher(
config: Config,
component,
generate_endpoint,
vllm_config,
) -> Optional[ZmqKvEventPublisher]:
"""
Set up KV event publisher for prefix caching if enabled.
Returns:
ZmqKvEventPublisher if prefix caching is enabled, None otherwise.
"""
if not config.engine_args.enable_prefix_caching:
return None
# TODO: We start off with a valid endpoint, then we increment it by dp_rank
# May no longer be valid. Lets remove the increment behavior from vLLM and here
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port(
config.engine_args.kv_events_config.endpoint,
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Worker reading KV events from {zmq_endpoint}")
return kv_publisher
def setup_vllm_engine(config, stat_logger=None): def setup_vllm_engine(config, stat_logger=None):
os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests os.environ["VLLM_NO_USAGE_STATS"] = "1" # Avoid internal HTTP requests
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
...@@ -137,9 +172,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -137,9 +172,7 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
engine_client, _, default_sampling_params = setup_vllm_engine(config) engine_client, vllm_config, default_sampling_params = setup_vllm_engine(config)
# TODO register_prefill in similar vein to register_llm
handler = PrefillWorkerHandler( handler = PrefillWorkerHandler(
runtime, component, engine_client, default_sampling_params runtime, component, engine_client, default_sampling_params
...@@ -184,6 +217,20 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -184,6 +217,20 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
prefill_router_client = (
await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode
.endpoint("find_best_worker")
.client()
)
prefill_router_free_client = (
await runtime.namespace(config.namespace)
.component("prefill_router") # TODO don't hardcode
.endpoint("free")
.client()
)
prefill_worker_client = ( prefill_worker_client = (
await runtime.namespace(config.namespace) await runtime.namespace(config.namespace)
.component("prefill") # TODO don't hardcode .component("prefill") # TODO don't hardcode
...@@ -213,25 +260,15 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -213,25 +260,15 @@ async def init(runtime: DistributedRuntime, config: Config):
engine_client, engine_client,
default_sampling_params, default_sampling_params,
prefill_worker_client, prefill_worker_client,
prefill_router_client,
prefill_router_free_client,
) )
if config.engine_args.enable_prefix_caching: # Set up KV event publisher for prefix caching if enabled
# TODO: We start off with a valid endpoint, then we increment it by dp_rank kv_publisher = setup_kv_event_publisher(
# May no longer be valid. Lets remove the increment behavior from vLLM and here config, component, generate_endpoint, vllm_config
zmq_endpoint = ZmqEventPublisher.offset_endpoint_port( )
config.engine_args.kv_events_config.endpoint, if kv_publisher:
data_parallel_rank=config.engine_args.data_parallel_rank or 0,
).replace("*", "127.0.0.1")
zmq_config = ZmqKvEventPublisherConfig(
worker_id=generate_endpoint.lease_id(),
kv_block_size=vllm_config.cache_config.block_size,
zmq_endpoint=zmq_endpoint,
)
kv_publisher = ZmqKvEventPublisher(component=component, config=zmq_config)
logger.info(f"Reading Events from {zmq_endpoint}")
handler.kv_publisher = kv_publisher handler.kv_publisher = kv_publisher
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
......
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
try:
from ._version import __version__
except Exception:
try:
from importlib.metadata import version as _pkg_version
__version__ = _pkg_version("ai-dynamo")
except Exception:
__version__ = "0.0.0+unknown"
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
"""
Centralized Prefill Router Service
Usage: python -m dynamo.vllm_prefill_router [args]
This service provides a single KV-aware router for all prefill workers in a
disaggregated vLLM deployment. Instead of each decode worker maintaining its own
round-robin client to prefill workers, this service uses KvRouter to make
intelligent routing decisions based on KV cache state.
"""
import argparse
import asyncio
import logging
import os
from typing import Optional
import uvloop
from dynamo.llm import KvRouter, KvRouterConfig
from dynamo.runtime import Client, DistributedRuntime, dynamo_worker
from dynamo.runtime.logging import configure_dynamo_logging
configure_dynamo_logging()
logger = logging.getLogger(__name__)
class PrefillRouterHandler:
"""Handles routing requests to prefill workers using KV-aware routing."""
def __init__(self, runtime: DistributedRuntime, namespace: str, block_size: int):
self.runtime = runtime
self.namespace = namespace
self.block_size = block_size
self.kv_router: Optional[KvRouter] = None
self.prefill_client: Optional[Client] = None
async def initialize(self):
"""Initialize the KV router for prefill workers."""
try:
# Get prefill endpoint
prefill_endpoint = (
self.runtime.namespace(self.namespace)
.component("prefill")
.endpoint("generate")
)
self.prefill_client = await prefill_endpoint.client()
# Create KvRouter with specified configuration
kv_router_config = KvRouterConfig(
router_track_active_blocks=False, # this won't matter for prefill workers
router_reset_states=True, # reset for now
)
self.kv_router = KvRouter(
endpoint=prefill_endpoint,
block_size=self.block_size,
kv_router_config=kv_router_config,
)
logger.info(
f"KvRouter initialized for prefill workers with block_size={self.block_size}"
)
except Exception as e:
logger.error(f"Failed to initialize KvRouter: {e}")
raise
async def find_best_worker(self, request):
"""
Find the best prefill worker based on KV cache state.
This endpoint is called by decode workers to determine which prefill
worker should handle a request.
"""
if self.kv_router is None:
# Fallback to round-robin if router not initialized
logger.warning("KvRouter not initialized, falling back to round-robin")
yield {
"status": "fallback",
"message": "Router not initialized",
}
return
try:
# Get current prefill workers
if self.prefill_client is None:
yield {
"status": "error",
"message": "Prefill client not initialized",
}
return
instance_ids = self.prefill_client.instance_ids()
if not instance_ids:
yield {
"status": "error",
"message": "No prefill workers available",
}
return
logger.debug(f"Routing request with {len(instance_ids)} available workers")
# Validate required fields
if "token_ids" not in request:
raise ValueError("Missing required field 'token_ids' in request")
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
token_ids = request["token_ids"]
request_id = request["request_id"]
# Use KvRouter to find the best worker with state updates
best_worker_id, overlap_blocks = await self.kv_router.find_best_match(
request_id=request_id,
tokens=token_ids,
update_states=True, # Always update states for prefill routing
)
logger.debug(
f"Selected worker {best_worker_id} with {overlap_blocks} overlap blocks for request {request_id}"
)
yield {
"worker_id": best_worker_id,
"overlap_blocks": overlap_blocks,
}
except Exception as e:
logger.error(f"Error finding best worker: {e}")
yield {
"status": "error",
"message": str(e),
}
async def free(self, request):
"""
Free resources associated with a request.
This endpoint is called when a request is completed to clean up
router state.
"""
if self.kv_router is None:
logger.warning("KvRouter not initialized")
yield {
"status": "error",
"message": "Router not initialized",
}
return
try:
if "request_id" not in request:
raise ValueError("Missing required field 'request_id' in request")
request_id = request["request_id"]
# Free the request from the router
await self.kv_router.free(request_id=request_id)
logger.debug(f"Freed resources for request {request_id}")
yield {
"status": "success",
"message": f"Request {request_id} freed successfully",
}
except Exception as e:
logger.error(f"Error freeing request: {e}")
yield {
"status": "error",
"message": str(e),
}
def parse_args():
parser = argparse.ArgumentParser(
description="Dynamo Prefill Router Service: Centralized KV-aware routing for prefill workers",
formatter_class=argparse.RawTextHelpFormatter,
)
parser.add_argument(
"--namespace",
type=str,
default=os.environ.get("DYN_NAMESPACE", "dynamo"),
help="Dynamo namespace for discovering prefill workers (default: dynamo or DYN_NAMESPACE env var)",
)
parser.add_argument(
"--block-size",
type=int,
default=128,
help="KV cache block size for routing decisions (default: 128)",
)
parser.add_argument(
"--log-level",
type=str,
default="INFO",
choices=["DEBUG", "INFO", "WARNING", "ERROR"],
help="Logging level (default: INFO)",
)
return parser.parse_args()
@dynamo_worker(static=False)
async def worker(runtime: DistributedRuntime):
"""Main worker function for the prefill router service."""
args = parse_args()
# Set logging level
logging.getLogger().setLevel(getattr(logging, args.log_level))
logger.info(f"Starting Prefill Router Service for namespace: {args.namespace}")
logger.debug(f"Configuration: block_size={args.block_size}")
# Create service component
component = runtime.namespace(args.namespace).component("prefill_router")
await component.create_service()
# Create handler
handler = PrefillRouterHandler(runtime, args.namespace, args.block_size)
await handler.initialize()
# Expose endpoints
find_best_worker_endpoint = component.endpoint("find_best_worker")
free_endpoint = component.endpoint("free")
logger.debug("Starting to serve find_best_worker and free endpoints...")
try:
await asyncio.gather(
find_best_worker_endpoint.serve_endpoint(
handler.find_best_worker,
graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")],
),
free_endpoint.serve_endpoint(
handler.free,
graceful_shutdown=True,
metrics_labels=[("service", "prefill_router")],
),
)
except Exception as e:
logger.error(f"Failed to serve endpoint: {e}")
raise
finally:
logger.info("Prefill Router Service shutting down")
def main():
"""Entry point for the prefill router service."""
uvloop.run(worker())
if __name__ == "__main__":
main()
...@@ -69,12 +69,6 @@ pub struct Flags { ...@@ -69,12 +69,6 @@ pub struct Flags {
#[arg(long, default_value = "round-robin")] #[arg(long, default_value = "round-robin")]
pub router_mode: RouterMode, pub router_mode: RouterMode,
/// Maximum number of batched tokens for KV routing
/// Needed for informing the KV router
/// NOTE: this is not actually used for now
#[arg(long, default_value = "8192")]
pub max_num_batched_tokens: Option<u32>,
/// KV Router: Weight for overlap score in worker selection. /// KV Router: Weight for overlap score in worker selection.
/// Higher values prioritize KV cache reuse. Default: 1.0 /// Higher values prioritize KV cache reuse. Default: 1.0
#[arg(long)] #[arg(long)]
...@@ -236,7 +230,6 @@ impl Flags { ...@@ -236,7 +230,6 @@ impl Flags {
self.use_kv_events, self.use_kv_events,
self.router_replica_sync, self.router_replica_sync,
self.router_track_active_blocks, self.router_track_active_blocks,
self.max_num_batched_tokens,
// defaulting below args (no longer maintaining new flags for dynamo-run) // defaulting below args (no longer maintaining new flags for dynamo-run)
None, None,
None, None,
......
...@@ -116,6 +116,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> { ...@@ -116,6 +116,7 @@ fn _core(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<llm::kv::WorkerStats>()?; m.add_class::<llm::kv::WorkerStats>()?;
m.add_class::<llm::kv::KvStats>()?; m.add_class::<llm::kv::KvStats>()?;
m.add_class::<llm::kv::SpecDecodeStats>()?; m.add_class::<llm::kv::SpecDecodeStats>()?;
m.add_class::<llm::kv::KvRouter>()?;
m.add_class::<llm::kv::KvPushRouter>()?; m.add_class::<llm::kv::KvPushRouter>()?;
m.add_class::<llm::kv::KvPushRouterStream>()?; m.add_class::<llm::kv::KvPushRouterStream>()?;
m.add_class::<RouterMode>()?; m.add_class::<RouterMode>()?;
......
...@@ -61,7 +61,6 @@ impl KvRouterConfig { ...@@ -61,7 +61,6 @@ impl KvRouterConfig {
router_track_active_blocks, router_track_active_blocks,
router_snapshot_threshold, router_snapshot_threshold,
router_reset_states, router_reset_states,
..Default::default()
}, },
} }
} }
......
...@@ -14,6 +14,7 @@ use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics; ...@@ -14,6 +14,7 @@ use llm_rs::kv_router::protocols::ForwardPassMetrics as RsForwardPassMetrics;
use llm_rs::kv_router::protocols::KvStats as RsKvStats; use llm_rs::kv_router::protocols::KvStats as RsKvStats;
use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats; use llm_rs::kv_router::protocols::SpecDecodeStats as RsSpecDecodeStats;
use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats; use llm_rs::kv_router::protocols::WorkerStats as RsWorkerStats;
use rs::pipeline::{AsyncEngine, SingleIn};
use rs::traits::events::EventSubscriber; use rs::traits::events::EventSubscriber;
use tracing; use tracing;
...@@ -832,10 +833,185 @@ impl SpecDecodeStats { ...@@ -832,10 +833,185 @@ impl SpecDecodeStats {
} }
} }
#[pyclass]
pub(crate) struct KvRouter {
inner: Arc<llm_rs::kv_router::KvRouter>,
}
#[pymethods]
impl KvRouter {
#[new]
#[pyo3(signature = (endpoint, block_size, kv_router_config=None, consumer_uuid=None))]
fn new(
endpoint: &Endpoint,
block_size: usize,
kv_router_config: Option<&super::entrypoint::KvRouterConfig>,
consumer_uuid: Option<String>,
) -> PyResult<Self> {
let runtime = pyo3_async_runtimes::tokio::get_runtime();
runtime.block_on(async move {
// Get component from endpoint
let component = endpoint.inner.component();
// Verify we're not in static mode
if component.drt().primary_lease().is_none() {
return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
"Failed to get primary lease: Cannot KV route static workers",
));
}
// Create KvRouter with provided or generated consumer UUID
let consumer_uuid = consumer_uuid.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
let kv_router = llm_rs::kv_router::KvRouter::new(
component.clone(),
block_size as u32,
None, // default selector
kv_router_config.map(|c| c.inner()),
consumer_uuid,
)
.await
.map_err(to_pyerr)?;
Ok(Self {
inner: Arc::new(kv_router),
})
})
}
#[pyo3(signature = (request_id, tokens, update_states=false, router_config_override=None))]
fn find_best_match<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
update_states: bool,
router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> {
let router_config_override = if let Some(obj) = router_config_override {
Python::with_gil(|py| {
let override_config: llm_rs::kv_router::RouterConfigOverride =
depythonize(obj.bind(py)).map_err(to_pyerr)?;
Ok::<_, PyErr>(Some(override_config))
})?
} else {
None
};
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner
.find_best_match(
Some(&request_id),
&tokens,
router_config_override.as_ref(),
update_states,
)
.await
.map_err(to_pyerr)?;
Ok((worker_id, overlap_blocks))
})
}
fn add_request<'p>(
&self,
py: Python<'p>,
request_id: String,
tokens: Vec<u32>,
overlap_blocks: u32,
worker_id: i64,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.add_request(request_id, &tokens, overlap_blocks, worker_id)
.await;
Ok(())
})
}
fn mark_prefill_completed<'p>(
&self,
py: Python<'p>,
request_id: String,
) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner
.mark_prefill_completed(&request_id)
.await
.map_err(to_pyerr)?;
Ok(())
})
}
fn free<'p>(&self, py: Python<'p>, request_id: String) -> PyResult<Bound<'p, PyAny>> {
let inner = self.inner.clone();
pyo3_async_runtimes::tokio::future_into_py(py, async move {
inner.free(&request_id).await.map_err(to_pyerr)?;
Ok(())
})
}
#[getter]
fn block_size(&self) -> PyResult<u32> {
Ok(self.inner.block_size())
}
}
#[pyclass] #[pyclass]
pub(crate) struct KvPushRouter { pub(crate) struct KvPushRouter {
inner: Arc<llm_rs::kv_router::KvPushRouter>, inner: Arc<llm_rs::kv_router::KvPushRouter>,
primary_token: tokio_util::sync::CancellationToken, }
// TODO: can this reuse the stream conversion method in Client bindings?
impl KvPushRouter {
/// Helper method to process a request and create a Python async generator
fn process_request_to_stream<'p>(
py: Python<'p>,
inner: Arc<llm_rs::kv_router::KvPushRouter>,
request: llm_rs::protocols::common::preprocessor::PreprocessedRequest,
) -> PyResult<Bound<'p, PyAny>> {
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
.map_err(|e| e.to_string())
});
match py_response {
Ok(obj) => {
if tx.send(obj).await.is_err() {
break; // Receiver dropped
}
}
Err(e) => {
tracing::error!("Failed to pythonize response: {}", e);
break;
}
}
}
});
// Return a Python async generator wrapper
Ok(KvPushRouterStream {
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
})
}
} }
#[pymethods] #[pymethods]
...@@ -866,16 +1042,12 @@ impl KvPushRouter { ...@@ -866,16 +1042,12 @@ impl KvPushRouter {
// Get component from endpoint // Get component from endpoint
let component = endpoint.inner.component(); let component = endpoint.inner.component();
// Get the primary token from the component's primary lease // Verify we're not in static mode
let primary_token = component if component.drt().primary_lease().is_none() {
.drt() return Err(PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(
.primary_lease() "Failed to get primary lease: Cannot KV route static workers",
.ok_or_else(|| { ));
PyErr::new::<pyo3::exceptions::PyRuntimeError, _>( }
"Failed to get primary lease: Cannot KV route static workers",
)
})?
.primary_token();
// Create KvRouter with a unique consumer UUID // Create KvRouter with a unique consumer UUID
let consumer_uuid = uuid::Uuid::new_v4().to_string(); let consumer_uuid = uuid::Uuid::new_v4().to_string();
...@@ -895,7 +1067,6 @@ impl KvPushRouter { ...@@ -895,7 +1067,6 @@ impl KvPushRouter {
Ok(Self { Ok(Self {
inner: Arc::new(kv_push_router), inner: Arc::new(kv_push_router),
primary_token,
}) })
}) })
} }
...@@ -967,54 +1138,27 @@ impl KvPushRouter { ...@@ -967,54 +1138,27 @@ impl KvPushRouter {
let request = request_builder.build().map_err(to_pyerr)?; let request = request_builder.build().map_err(to_pyerr)?;
let inner = self.inner.clone(); // Use the helper method to process the request
Self::process_request_to_stream(py, self.inner.clone(), request)
// Create a Python async generator that wraps the Rust stream }
pyo3_async_runtimes::tokio::future_into_py(py, async move {
use rs::pipeline::{AsyncEngine, SingleIn};
use tokio_stream::StreamExt;
let single_in = SingleIn::new(request);
let stream = inner.generate(single_in).await.map_err(to_pyerr)?;
let (tx, rx) = tokio::sync::mpsc::channel(100);
// Spawn a task to process the stream
tokio::spawn(async move {
let mut stream = stream;
while let Some(response) = stream.next().await {
// Convert LLMEngineOutput to PyObject
let py_response = Python::with_gil(|py| {
pythonize(py, &response.data)
.map(|obj| obj.unbind())
.map_err(|e| e.to_string())
});
match py_response { fn generate_from_request<'p>(
Ok(obj) => { &self,
if tx.send(obj).await.is_err() { py: Python<'p>,
break; // Receiver dropped request: PyObject,
} ) -> PyResult<Bound<'p, PyAny>> {
} // Depythonize the request directly into PreprocessedRequest
Err(e) => { let request: llm_rs::protocols::common::preprocessor::PreprocessedRequest =
tracing::error!("Failed to pythonize response: {}", e); Python::with_gil(|py| depythonize(request.bind(py)).map_err(to_pyerr))?;
break;
}
}
}
});
// Return a Python async generator wrapper // Use the helper method to process the request
Ok(KvPushRouterStream { Self::process_request_to_stream(py, self.inner.clone(), request)
rx: Arc::new(tokio::sync::Mutex::new(rx)),
})
})
} }
#[pyo3(signature = (context_id, token_ids, router_config_override=None))] #[pyo3(signature = (token_ids, router_config_override=None))]
fn best_worker_id<'p>( fn best_worker_id<'p>(
&self, &self,
py: Python<'p>, py: Python<'p>,
context_id: String,
token_ids: Vec<u32>, token_ids: Vec<u32>,
router_config_override: Option<PyObject>, router_config_override: Option<PyObject>,
) -> PyResult<Bound<'p, PyAny>> { ) -> PyResult<Bound<'p, PyAny>> {
...@@ -1032,7 +1176,7 @@ impl KvPushRouter { ...@@ -1032,7 +1176,7 @@ impl KvPushRouter {
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
let (worker_id, overlap_blocks) = inner let (worker_id, overlap_blocks) = inner
.find_best_match(&context_id, &token_ids, router_config_override.as_ref()) .find_best_match(&token_ids, router_config_override.as_ref())
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
...@@ -1076,13 +1220,6 @@ impl KvPushRouter { ...@@ -1076,13 +1220,6 @@ impl KvPushRouter {
} }
} }
impl Drop for KvPushRouter {
fn drop(&mut self) {
// Cancel the primary token to shut down background tasks
self.primary_token.cancel();
}
}
// Python async generator wrapper for the stream // Python async generator wrapper for the stream
#[pyclass] #[pyclass]
pub(crate) struct KvPushRouterStream { pub(crate) struct KvPushRouterStream {
......
...@@ -1154,6 +1154,103 @@ class ZmqKvEventListener: ...@@ -1154,6 +1154,103 @@ class ZmqKvEventListener:
""" """
... ...
class KvRouter:
"""
A KV Router that decides which worker to use based on KV cache overlap.
This router tracks request states and manages KV cache distribution across workers.
"""
def __init__(
self,
endpoint: Endpoint,
block_size: int,
kv_router_config: Optional[KvRouterConfig] = None,
consumer_uuid: Optional[str] = None,
) -> None:
"""
Create a new KvRouter instance.
Args:
endpoint: The endpoint to associate with this router
block_size: The KV cache block size
kv_router_config: Optional configuration for the KV router
consumer_uuid: Optional unique identifier for this router instance.
If not provided, a UUID will be generated.
"""
...
async def find_best_match(
self,
request_id: str,
tokens: List[int],
*,
update_states: bool = False,
router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]:
"""
Find the best matching worker for the given tokens.
Args:
request_id: Unique identifier for the request used for tracking
tokens: List of token IDs to find matches for
update_states: Whether to update router states for this request (default: False)
router_config_override: Optional router configuration override with fields:
- overlap_score_weight: Optional weight for overlap score
- router_temperature: Optional temperature for worker selection
Returns:
A tuple of (worker_id, overlap_blocks) where:
- worker_id: The ID of the best matching worker
- overlap_blocks: The number of overlapping blocks found
"""
...
async def add_request(
self,
request_id: str,
tokens: List[int],
overlap_blocks: int,
worker_id: int,
) -> None:
"""
Add a request to the router's tracking system.
Args:
request_id: Unique identifier for the request
tokens: List of token IDs for the request
overlap_blocks: Number of overlapping blocks found
worker_id: ID of the worker handling this request
"""
...
async def mark_prefill_completed(self, request_id: str) -> None:
"""
Mark that prefill has been completed for a request.
Args:
request_id: The request ID to mark as prefill completed
"""
...
async def free(self, request_id: str) -> None:
"""
Free resources associated with a request.
Args:
request_id: The request ID to free
"""
...
@property
def block_size(self) -> int:
"""
Get the KV cache block size.
Returns:
The block size in tokens
"""
...
class KvPushRouter: class KvPushRouter:
""" """
A KV-aware push router that performs intelligent routing based on KV cache overlap. A KV-aware push router that performs intelligent routing based on KV cache overlap.
...@@ -1211,7 +1308,6 @@ class KvPushRouter: ...@@ -1211,7 +1308,6 @@ class KvPushRouter:
async def best_worker_id( async def best_worker_id(
self, self,
context_id: str,
token_ids: List[int], token_ids: List[int],
router_config_override: Optional[JsonLike] = None, router_config_override: Optional[JsonLike] = None,
) -> Tuple[int, int]: ) -> Tuple[int, int]:
...@@ -1219,7 +1315,6 @@ class KvPushRouter: ...@@ -1219,7 +1315,6 @@ class KvPushRouter:
Find the best matching worker for the given tokens without updating states. Find the best matching worker for the given tokens without updating states.
Args: Args:
context_id: String identifier for the request
token_ids: List of token IDs to find matches for token_ids: List of token IDs to find matches for
router_config_override: Optional router configuration override router_config_override: Optional router configuration override
......
...@@ -25,7 +25,9 @@ from dynamo._core import HttpService as HttpService ...@@ -25,7 +25,9 @@ from dynamo._core import HttpService as HttpService
from dynamo._core import KvEventPublisher as KvEventPublisher from dynamo._core import KvEventPublisher as KvEventPublisher
from dynamo._core import KvIndexer as KvIndexer from dynamo._core import KvIndexer as KvIndexer
from dynamo._core import KvMetricsAggregator as KvMetricsAggregator from dynamo._core import KvMetricsAggregator as KvMetricsAggregator
from dynamo._core import KvPushRouter as KvPushRouter
from dynamo._core import KvRecorder as KvRecorder from dynamo._core import KvRecorder as KvRecorder
from dynamo._core import KvRouter as KvRouter
from dynamo._core import KvRouterConfig as KvRouterConfig from dynamo._core import KvRouterConfig as KvRouterConfig
from dynamo._core import KvStats as KvStats from dynamo._core import KvStats as KvStats
from dynamo._core import ModelInput as ModelInput from dynamo._core import ModelInput as ModelInput
......
...@@ -104,10 +104,6 @@ pub struct KvRouterConfig { ...@@ -104,10 +104,6 @@ pub struct KvRouterConfig {
/// Whether to track active blocks in the router (default: true) /// Whether to track active blocks in the router (default: true)
pub router_track_active_blocks: bool, pub router_track_active_blocks: bool,
// TODO: this is not actually used for now
// Would need this (along with total kv blocks) to trigger AllWorkersBusy error for e.g. rate-limiting
pub max_num_batched_tokens: u32,
/// Threshold for triggering snapshots. If None, no snapshots will be performed. /// Threshold for triggering snapshots. If None, no snapshots will be performed.
pub router_snapshot_threshold: Option<u32>, pub router_snapshot_threshold: Option<u32>,
...@@ -123,7 +119,6 @@ impl Default for KvRouterConfig { ...@@ -123,7 +119,6 @@ impl Default for KvRouterConfig {
use_kv_events: true, use_kv_events: true,
router_replica_sync: false, router_replica_sync: false,
router_track_active_blocks: true, router_track_active_blocks: true,
max_num_batched_tokens: 8192,
router_snapshot_threshold: Some(10000), router_snapshot_threshold: Some(10000),
router_reset_states: false, router_reset_states: false,
} }
...@@ -140,7 +135,6 @@ impl KvRouterConfig { ...@@ -140,7 +135,6 @@ impl KvRouterConfig {
use_kv_events: Option<bool>, use_kv_events: Option<bool>,
replica_sync: Option<bool>, replica_sync: Option<bool>,
track_active_blocks: Option<bool>, track_active_blocks: Option<bool>,
max_num_batched_tokens: Option<u32>,
router_snapshot_threshold: Option<Option<u32>>, router_snapshot_threshold: Option<Option<u32>>,
router_reset_states: Option<bool>, router_reset_states: Option<bool>,
) -> Self { ) -> Self {
...@@ -152,8 +146,6 @@ impl KvRouterConfig { ...@@ -152,8 +146,6 @@ impl KvRouterConfig {
router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync), router_replica_sync: replica_sync.unwrap_or(default.router_replica_sync),
router_track_active_blocks: track_active_blocks router_track_active_blocks: track_active_blocks
.unwrap_or(default.router_track_active_blocks), .unwrap_or(default.router_track_active_blocks),
max_num_batched_tokens: max_num_batched_tokens
.unwrap_or(default.max_num_batched_tokens),
router_snapshot_threshold: router_snapshot_threshold router_snapshot_threshold: router_snapshot_threshold
.unwrap_or(default.router_snapshot_threshold), .unwrap_or(default.router_snapshot_threshold),
router_reset_states: router_reset_states.unwrap_or(default.router_reset_states), router_reset_states: router_reset_states.unwrap_or(default.router_reset_states),
...@@ -216,6 +208,8 @@ pub struct KvRouter { ...@@ -216,6 +208,8 @@ pub struct KvRouter {
block_size: u32, block_size: u32,
kv_router_config: KvRouterConfig, kv_router_config: KvRouterConfig,
cancellation_token: tokio_util::sync::CancellationToken,
} }
impl KvRouter { impl KvRouter {
...@@ -314,19 +308,25 @@ impl KvRouter { ...@@ -314,19 +308,25 @@ impl KvRouter {
scheduler, scheduler,
block_size, block_size,
kv_router_config, kv_router_config,
cancellation_token,
}) })
} }
/// Give these tokens, find the worker with the best match in it's KV cache. /// Give these tokens, find the worker with the best match in it's KV cache.
/// Returned overlap amount is in number of blocks. /// Returned overlap amount is in number of blocks.
/// Now also takes context_id for request tracking /// Now also takes optional context_id for request tracking
async fn find_best_match( pub async fn find_best_match(
&self, &self,
context_id: &str, context_id: Option<&str>,
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
update_states: bool, update_states: bool,
) -> anyhow::Result<(i64, u32)> { ) -> anyhow::Result<(i64, u32)> {
// Validate that context_id is provided when update_states is true
if update_states && context_id.is_none() {
panic!("context_id must be provided if update_states is true");
}
let isl_tokens = tokens.len(); let isl_tokens = tokens.len();
let block_hashes = compute_block_hash_for_seq(tokens, self.block_size); let block_hashes = compute_block_hash_for_seq(tokens, self.block_size);
...@@ -350,7 +350,7 @@ impl KvRouter { ...@@ -350,7 +350,7 @@ impl KvRouter {
let best_worker_id = self let best_worker_id = self
.scheduler .scheduler
.schedule( .schedule(
context_id.to_string(), context_id.map(|s| s.to_string()),
isl_tokens, isl_tokens,
maybe_seq_hashes_2, maybe_seq_hashes_2,
overlap_scores.clone(), overlap_scores.clone(),
...@@ -448,7 +448,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er ...@@ -448,7 +448,7 @@ impl AsyncEngine<SingleIn<RouterRequest>, ManyOut<Annotated<RouterResponse>>, Er
let response = match request { let response = match request {
RouterRequest::New { tokens } => { RouterRequest::New { tokens } => {
let (worker_id, overlap_blocks) = self let (worker_id, overlap_blocks) = self
.find_best_match(&context_id, &tokens, None, true) .find_best_match(Some(&context_id), &tokens, None, true)
.await?; .await?;
RouterResponse::New { RouterResponse::New {
...@@ -486,12 +486,11 @@ impl KvPushRouter { ...@@ -486,12 +486,11 @@ impl KvPushRouter {
/// Find the best matching worker for the given tokens without updating states /// Find the best matching worker for the given tokens without updating states
pub async fn find_best_match( pub async fn find_best_match(
&self, &self,
context_id: &str,
tokens: &[u32], tokens: &[u32],
router_config_override: Option<&RouterConfigOverride>, router_config_override: Option<&RouterConfigOverride>,
) -> Result<(i64, u32)> { ) -> Result<(i64, u32)> {
self.chooser self.chooser
.find_best_match(context_id, tokens, router_config_override, false) .find_best_match(None, tokens, router_config_override, false)
.await .await
} }
...@@ -554,7 +553,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -554,7 +553,7 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
// Otherwise, find the best match // Otherwise, find the best match
self.chooser self.chooser
.find_best_match( .find_best_match(
&context_id, Some(&context_id),
&request.token_ids, &request.token_ids,
request.router_config_override.as_ref(), request.router_config_override.as_ref(),
!query_instance_id, // Don't update states if query_instance_id !query_instance_id, // Don't update states if query_instance_id
...@@ -610,3 +609,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu ...@@ -610,3 +609,10 @@ impl AsyncEngine<SingleIn<PreprocessedRequest>, ManyOut<Annotated<LLMEngineOutpu
} }
} }
} }
impl Drop for KvRouter {
fn drop(&mut self) {
tracing::info!("Dropping KvRouter - cancelling background tasks");
self.cancellation_token.cancel();
}
}
...@@ -56,7 +56,7 @@ pub struct SchedulingResponse { ...@@ -56,7 +56,7 @@ pub struct SchedulingResponse {
} }
pub struct SchedulingRequest { pub struct SchedulingRequest {
pub request_id: String, pub maybe_request_id: Option<String>,
pub token_seq: Option<Vec<SequenceHash>>, pub token_seq: Option<Vec<SequenceHash>>,
pub isl_tokens: usize, pub isl_tokens: usize,
pub overlaps: OverlapScores, pub overlaps: OverlapScores,
...@@ -248,7 +248,13 @@ impl KvScheduler { ...@@ -248,7 +248,13 @@ impl KvScheduler {
continue; continue;
} }
let request_id = request.request_id; let Some(request_id) = request.maybe_request_id else {
tracing::error!(
"No request_id provided to add_request to the slot tracker"
);
continue;
};
if let Err(e) = slots_clone if let Err(e) = slots_clone
.add_request( .add_request(
request_id.clone(), request_id.clone(),
...@@ -290,7 +296,7 @@ impl KvScheduler { ...@@ -290,7 +296,7 @@ impl KvScheduler {
pub async fn schedule( pub async fn schedule(
&self, &self,
request_id: String, maybe_request_id: Option<String>,
isl_tokens: usize, isl_tokens: usize,
token_seq: Option<Vec<SequenceHash>>, token_seq: Option<Vec<SequenceHash>>,
overlaps: OverlapScores, overlaps: OverlapScores,
...@@ -299,7 +305,7 @@ impl KvScheduler { ...@@ -299,7 +305,7 @@ impl KvScheduler {
) -> Result<i64, KvSchedulerError> { ) -> Result<i64, KvSchedulerError> {
let (resp_tx, resp_rx) = tokio::sync::oneshot::channel(); let (resp_tx, resp_rx) = tokio::sync::oneshot::channel();
let request = SchedulingRequest { let request = SchedulingRequest {
request_id, maybe_request_id,
token_seq, token_seq,
isl_tokens, isl_tokens,
overlaps, overlaps,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment