Unverified Commit e01c6e99 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

feat: bake prefill router into frontend, supporting vllm for now (#3762)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent cd5e7e33
......@@ -118,27 +118,14 @@ python -m dynamo.frontend --help
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
#### Launching a Standalone Router for Prefill Workers (Optional)
#### Disaggregated Serving with Automatic Prefill Routing
If you're using disaggregated serving with separate prefill and decode workers, you should also launch a standalone router for prefill workers. This router handles routing prefill requests to dedicated prefill workers. When using a standalone prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the standalone router):
When you launch prefill workers using `run_engines.sh --prefill`, the frontend automatically detects them and activates an internal prefill router. This prefill router:
- Automatically routes initial token processing to dedicated prefill workers
- Uses KV-aware routing regardless of the frontend's `--router-mode` setting
- Seamlessly integrates with your decode workers for token generation
```bash
# Start the decode router with pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--router-reset-states \
--http-port 8000 \
--kv-overlap-score-weight 0
# In another terminal, start the standalone router for prefill workers
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks
```
The `--router-reset-states` flag clears any previous state, and `--no-track-active-blocks` disables active block tracking (suitable for prefill-only routing where decode load is not relevant).
No additional configuration is needed - simply launch both decode and prefill workers, and the system handles the rest. See the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md#disaggregated-serving-prefill-and-decode) for more details.
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
......
......@@ -11,20 +11,13 @@ export PYTHONHASHSEED=0
MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64
# run decode router with kv-overlap-score-weight 0 for pure load balancing
# Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router
python -m dynamo.frontend \
--router-mode kv \
--http-port 8000 \
--kv-overlap-score-weight 0 \
--router-reset-states &
# run standalone router service for prefill workers
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size $BLOCK_SIZE \
--router-reset-states \
--no-track-active-blocks &
# two decode workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
......@@ -38,6 +31,8 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--enforce-eager &
# two prefill workers
# When registered with --is-prefill-worker, these workers are automatically detected
# by the frontend, which activates an internal prefill router for KV-aware prefill routing
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \
--model $MODEL \
--block-size $BLOCK_SIZE \
......
......@@ -40,7 +40,12 @@ The standalone router exposes two endpoints via the Dynamo runtime:
Clients query the `find_best_worker` endpoint to determine which worker should process each request, then call the selected worker directly.
## Example: Disaggregated Serving with Prefill Workers
## Example: Manual Disaggregated Serving (Alternative Setup)
> [!Note]
> **This is an alternative advanced setup.** The recommended approach for disaggregated serving is to use the frontend's automatic prefill routing, which activates when you register workers with `ModelType.Prefill`. See the [KV Cache Routing documentation](/docs/architecture/kv_cache_routing.md#disaggregated-serving-prefill-and-decode) for the default setup.
>
> Use this manual setup if you need explicit control over prefill routing configuration or want to manage prefill and decode routers separately.
See [`components/backends/vllm/launch/disagg_router.sh`](/components/backends/vllm/launch/disagg_router.sh) for a complete example.
......
......@@ -4,13 +4,10 @@
import asyncio
import logging
import os
import uuid
from abc import ABC, abstractmethod
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict
import msgspec
from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError
......@@ -23,6 +20,35 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__)
def build_sampling_params(
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
) -> SamplingParams:
"""
Build SamplingParams from a PreprocessedRequest.
Args:
request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions'
default_sampling_params: Default sampling parameters to initialize with
Returns:
SamplingParams configured from the request
"""
sampling_params = SamplingParams(**default_sampling_params)
sampling_params.detokenize = False
# Apply sampling_options
for key, value in request["sampling_options"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# Apply stop_conditions
for key, value in request["stop_conditions"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
return sampling_params
class BaseWorkerHandler(ABC):
"""
Request handler for the generate and clear_kv_blocks endpoints.
......@@ -130,93 +156,31 @@ class DecodeWorkerHandler(BaseWorkerHandler):
component,
engine,
default_sampling_params,
prefill_worker_client=None,
prefill_router_client=None,
):
super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client
async def generate(self, request, context):
request_id = str(uuid.uuid4().hex)
logger.debug(f"New Request ID: {request_id}")
# Use context ID for request tracking and correlation
request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")
prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params)
sampling_params.detokenize = False
for key, value in request["sampling_options"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
for key, value in request["stop_conditions"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# Use prefill router or worker if available
can_prefill = (
self.prefill_worker_client is not None
) and self.prefill_worker_client.instance_ids()
if can_prefill:
# Create prefill sampling params with modifications
prefill_sampling_params = deepcopy(sampling_params)
if prefill_sampling_params.extra_args is None:
prefill_sampling_params.extra_args = {}
prefill_sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
prefill_sampling_params.max_tokens = 1
prefill_sampling_params.min_tokens = 1
try:
# Send request with sampling_params and request_id in extra_args
prefill_request = request.copy()
# TODO (PeaBrane): this smells a bit bad as not we have two nestings
# of extra_args (an inner one again in sampling_params)
prefill_request["extra_args"] = {
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
# Try router first if available, fallback to worker
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
# Call router's generate endpoint which returns LLMEngineOutput
prefill_response = await anext(
await self.prefill_router_client.generate(
prefill_request, context=context
)
)
else:
# Fallback to direct worker with same format
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
prefill_output = prefill_response.data()
# Extract kv_transfer_params from response
kv_transfer_params = prefill_output.get("extra_args", {}).get(
"kv_transfer_params"
)
if kv_transfer_params:
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args[
"kv_transfer_params"
] = kv_transfer_params
except Exception as e:
if context.is_stopped() or context.is_killed():
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
return
logger.warning(f"Prefill error: {e}, falling back to local prefill")
# Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params)
# Extract disaggregated_params from request (set by prefill router in Rust frontend)
disaggregated_params = request.get("disaggregated_params")
if disaggregated_params:
# Prefill was performed - use the disaggregated params
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = disaggregated_params.get(
"kv_transfer_params"
)
logger.debug(
f"Using disaggregated params from prefill for request {request_id}"
)
dp_rank = request.get("dp_rank", None)
......@@ -238,17 +202,25 @@ class PrefillWorkerHandler(BaseWorkerHandler):
super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context):
# Extract from PreprocessedRequest format - request_id and sampling_params from extra_args
extra_args = request.get("extra_args", {})
request_id = extra_args.get("request_id", str(uuid.uuid4().hex))
logger.debug(f"New Prefill Request ID: {request_id}")
# Use context ID for request tracking and correlation with decode phase
request_id = context.id()
logger.debug(f"Prefill Request ID: {request_id}")
token_ids = request["token_ids"]
prompt = TokensPrompt(prompt_token_ids=token_ids)
# Get sampling_params from extra_args
sampling_params_dict = extra_args.get("sampling_params", {})
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
# Build sampling params from request using shared utility
sampling_params = build_sampling_params(request, self.default_sampling_params)
# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
# Override for prefill: only generate 1 token
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
dp_rank = request.get("dp_rank", None)
......@@ -271,10 +243,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
output: Dict[str, Any] = {
"token_ids": list(token_ids),
"extra_args": (
"disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params
else {}
else None
),
}
......
......@@ -85,12 +85,12 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
"""
vLLM-specific health check payload for prefill workers in disaggregated mode.
The prefill handler expects a different structure with 'request_id' and 'sampling_params'.
The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions.
"""
def __init__(self, engine_client=None):
"""
Initialize vLLM prefill health check payload with proper structure.
Initialize vLLM prefill health check payload with proper PreprocessedRequest structure.
Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
......@@ -98,25 +98,21 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
"""
bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Prefill handler expects request_id, token_ids, and sampling_params
# The sampling_params are converted via msgspec in the handler
# Prefill handler expects PreprocessedRequest format: token_ids, sampling_options, stop_conditions
# The handler will override max_tokens/min_tokens to 1 and add do_remote_decode
self.default_payload = {
"request_id": "health_check",
"token_ids": [bos_token_id],
"sampling_params": {
"max_tokens": 1,
"min_tokens": 1,
"sampling_options": {
"temperature": 0.0,
"top_p": 1.0,
"top_k": -1,
"detokenize": False,
},
"stop_conditions": {
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False,
"ignore_eos": False,
"extra_args": {
"kv_transfer_params": {
"do_remote_decode": True,
}
},
"min_tokens": 0,
},
}
super().__init__()
......@@ -191,6 +191,60 @@ def setup_vllm_engine(config, stat_logger=None):
return engine_client, vllm_config, default_sampling_params
async def register_vllm_model(
model_input: ModelInput,
model_type: ModelType,
generate_endpoint,
config: Config,
engine_client: AsyncLLM,
vllm_config,
migration_limit: int,
):
"""
Helper function to register a vLLM model with runtime configuration.
Args:
model_input: Input type for the model (e.g., ModelInput.Tokens)
model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
generate_endpoint: Endpoint to register
config: Configuration object
engine_client: vLLM engine client
vllm_config: vLLM configuration
migration_limit: Migration limit for the model
"""
runtime_config = ModelRuntimeConfig()
# Get runtime configuration from vLLM engine
logging.info(
f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
)
runtime_values = get_engine_cache_info(engine_client)
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
# Add tool/reasoning parsers for decode models
if model_type != ModelType.Prefill:
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
model_input,
model_type,
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
)
async def init_prefill(runtime: DistributedRuntime, config: Config):
"""
Instantiate and serve
......@@ -214,6 +268,18 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
if kv_publishers:
handler.kv_publishers = kv_publishers
# Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
await register_vllm_model(
ModelInput.Tokens,
ModelType.Prefill,
generate_endpoint,
config,
engine_client,
vllm_config,
migration_limit=0, # Prefill doesn't support migration
)
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
try:
......@@ -255,20 +321,6 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks")
prefill_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("generate")
.client()
)
prefill_worker_client = (
await runtime.namespace(config.namespace)
.component("prefill") # TODO don't hardcode
.endpoint("generate")
.client()
)
factory = StatLoggerFactory(
component,
config.engine_args.data_parallel_rank or 0,
......@@ -288,8 +340,6 @@ async def init(runtime: DistributedRuntime, config: Config):
component,
engine_client,
default_sampling_params,
prefill_worker_client,
prefill_router_client,
)
# Set up KV event publishers for prefix caching if enabled (one per dp_rank)
......@@ -305,35 +355,14 @@ async def init(runtime: DistributedRuntime, config: Config):
register_engine_metrics_callback(generate_endpoint, REGISTRY, "vllm:", "vLLM")
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
runtime_config = ModelRuntimeConfig()
# make a `collective_rpc` call to get runtime configuration values
logging.info(
"Getting engine runtime configuration metadata from vLLM engine..."
)
runtime_values = get_engine_cache_info(engine_client)
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
await register_vllm_model(
ModelInput.Tokens,
ModelType.Chat | ModelType.Completions,
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
config,
engine_client,
vllm_config,
migration_limit=config.migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
)
health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()
......
......@@ -54,6 +54,51 @@ The main KV-aware routing arguments:
For basic model registration without KV routing, you can use `--router-mode round-robin` or `--router-mode random` with both static and dynamic endpoints.
## Disaggregated Serving (Prefill and Decode)
Dynamo supports disaggregated serving where prefill (prompt processing) and decode (token generation) are handled by separate worker pools. When you register workers with `ModelType.Prefill` (see [Backend Guide](../development/backend-guide.md#model-types)), the frontend automatically detects them and activates an internal prefill router.
### Automatic Prefill Router Activation
The prefill router is automatically created when:
1. A decode model is registered (e.g., via `register_llm()` with `ModelType.Chat | ModelType.Completions`)
2. A prefill worker is detected with the same model name and `ModelType.Prefill`
**Key characteristics of the prefill router:**
- **Always uses KV-aware routing** regardless of the frontend's `--router-mode` setting
- **Always disables active block tracking** (`track_active_blocks=false`) since prefill workers don't perform decode
- **Seamlessly integrated** into the request pipeline between preprocessing and decode routing
- **Falls back gracefully** to decode-only mode if prefill fails or no prefill workers are available
### Setup Example
```python
# Decode worker registration (in your decode worker)
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=generate_endpoint,
model_name="meta-llama/Llama-2-7b-hf",
# ... other parameters
)
# Prefill worker registration (in your prefill worker)
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Prefill, # <-- Mark as prefill worker
endpoint=generate_endpoint,
model_name="meta-llama/Llama-2-7b-hf", # Must match decode model name
# ... other parameters
)
```
When both workers are registered, requests are automatically routed:
1. **Prefill phase** → Prefill router selects best prefill worker (KV-aware)
2. **Decode phase** → Decode router selects decode worker (uses frontend's `--router-mode`)
> [!Note]
> **WIP**: Currently, the prefill router always uses KV routing. Future updates will provide more fine-grained control over prefill routing behavior to match user-specified frontend router modes.
## Overview
The KV-aware router operates on two key principles to optimize request routing:
......
......@@ -929,6 +929,7 @@ pub async fn build_worker_selection_pipeline_chat(
busy_threshold,
chooser,
hf_tokenizer,
None,
)
.await?;
......@@ -991,12 +992,7 @@ pub async fn create_worker_selection_pipeline_chat(
let model_manager = std::sync::Arc::new(ModelManager::new());
Some(
model_manager
.kv_chooser_for(
&card.display_name,
&component,
card.kv_cache_block_size,
kv_router_config,
)
.kv_chooser_for(&component, card.kv_cache_block_size, kv_router_config)
.await?,
)
} else {
......
......@@ -779,12 +779,7 @@ async fn create_kv_router_from_endpoint(
// Create ModelManager and use it to create KvRouter (ensures etcd registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager
.kv_chooser_for(
"dummy_name", // does not matter, never cached
component,
block_size as u32,
kv_router_config,
)
.kv_chooser_for(component, block_size as u32, kv_router_config)
.await
.map_err(to_pyerr)?;
......
......@@ -7,24 +7,34 @@ use std::{
};
use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use dynamo_runtime::component::Component;
use dynamo_runtime::component::{Component, Endpoint};
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{
kv_router::KvRouter,
types::generic::tensor::TensorStreamingEngine,
types::openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
discovery::KV_ROUTERS_ROOT_PATH,
kv_router::{KvRouter, KvRouterConfig, scheduler::DefaultWorkerSelector},
model_card::ModelDeploymentCard,
model_type::ModelType,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine,
},
},
};
/// State for prefill router activation rendezvous
enum PrefillActivationState {
/// Decode model registered, waiting for prefill endpoint
DecodeWaiting(oneshot::Sender<Endpoint>),
/// Prefill endpoint arrived, waiting for decode model to register
PrefillReady(oneshot::Receiver<Endpoint>),
}
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
#[error("Model not found: {0}")]
......@@ -41,11 +51,13 @@ pub struct ModelManager {
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
prefill_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
// These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, // Key: component service_name
prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
}
impl Default for ModelManager {
......@@ -64,6 +76,7 @@ impl ModelManager {
prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()),
}
}
......@@ -188,10 +201,9 @@ impl ModelManager {
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.add(model, card_checksum, engine)
clients.add(model, card_checksum, ())
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
......@@ -263,17 +275,6 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_prefill_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.prefill_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
......@@ -288,19 +289,20 @@ impl ModelManager {
pub async fn kv_chooser_for(
&self,
model_name: &str,
component: &Component,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> anyhow::Result<Arc<KvRouter>> {
if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
let service_name = component.service_name();
if let Some(kv_chooser) = self.get_kv_chooser(&service_name) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
model_name = %model_name,
component = %service_name,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
"KV Router block size mismatch! Component is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
......@@ -339,12 +341,109 @@ impl ModelManager {
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers
.lock()
.insert(model_name.to_string(), new_kv_chooser.clone());
.insert(service_name, new_kv_chooser.clone());
Ok(new_kv_chooser)
}
fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
self.kv_choosers.lock().get(model_name).cloned()
fn get_kv_chooser(&self, service_name: &str) -> Option<Arc<KvRouter>> {
self.kv_choosers.lock().get(service_name).cloned()
}
/// Register a prefill router for a decode model. Returns a receiver that will be
/// activated when the corresponding prefill model is discovered.
/// Returns None if the decode model was already registered.
pub fn register_prefill_router(
&self,
model_name: String,
) -> Option<oneshot::Receiver<Endpoint>> {
let mut activators = self.prefill_router_activators.lock();
match activators.remove(&model_name) {
Some(PrefillActivationState::PrefillReady(rx)) => {
// Prefill endpoint already arrived - rx will immediately resolve
tracing::debug!(
model_name = %model_name,
"Prefill endpoint already available, returning receiver with endpoint"
);
Some(rx)
}
Some(PrefillActivationState::DecodeWaiting(tx)) => {
// Decode already registered - this shouldn't happen, restore state and return None
tracing::error!(
model_name = %model_name,
"Decode model already registered for this prefill router"
);
activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx));
None
}
None => {
// New registration: create tx/rx pair, store sender and return receiver
let (tx, rx) = oneshot::channel();
activators.insert(
model_name.clone(),
PrefillActivationState::DecodeWaiting(tx),
);
tracing::debug!(
model_name = %model_name,
"No prefill endpoint available yet, storing sender for future activation"
);
Some(rx)
}
}
}
/// Activate a prefill router by sending the endpoint through the oneshot channel.
/// If no decode model has registered yet, stores the endpoint for future retrieval.
pub fn activate_prefill_router(
&self,
model_name: &str,
endpoint: Endpoint,
) -> anyhow::Result<()> {
let mut activators = self.prefill_router_activators.lock();
match activators.remove(model_name) {
Some(PrefillActivationState::DecodeWaiting(sender)) => {
// Decode model already registered
sender.send(endpoint).map_err(|_| {
anyhow::anyhow!(
"Failed to send endpoint to prefill router activator for model: {}",
model_name
)
})?;
tracing::info!(
model_name = %model_name,
"Activated prefill router for already-registered decode model"
);
Ok(())
}
Some(PrefillActivationState::PrefillReady(_)) => {
// Prefill already activated - this shouldn't happen
anyhow::bail!("Prefill router for model {} already activated", model_name);
}
None => {
// Decode model not registered yet - create pair and immediately send endpoint
let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| {
anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
})?;
// Store the receiver for when decode model registers
activators.insert(
model_name.to_string(),
PrefillActivationState::PrefillReady(rx),
);
tracing::info!(
model_name = %model_name,
"Stored prefill endpoint for future decode model registration"
);
Ok(())
}
}
}
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
......
......@@ -20,7 +20,7 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
entrypoint,
kv_router::KvRouterConfig,
kv_router::{KvRouterConfig, PrefillRouter},
model_card::{self, ModelDeploymentCard},
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
......@@ -318,15 +318,26 @@ impl ModelWatcher {
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let endpoint = component.endpoint(&endpoint_id.name);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?;
if self.manager.has_model_any(card.name()) {
// Check if we should skip registration:
// - Skip if a model with this name already exists
// - UNLESS this is a prefill model and no prefill model exists yet for this name
let is_new_prefill = card.model_type.supports_prefill()
&& !self
.manager
.list_prefill_models()
.contains(&card.name().to_string());
if self.manager.has_model_any(card.name()) && !is_new_prefill {
tracing::debug!(
model_name = card.name(),
namespace = endpoint_id.namespace,
"New endpoint for existing model"
model_type = %card.model_type,
"New endpoint for existing model, skipping"
);
return Ok(());
}
......@@ -346,12 +357,7 @@ impl ModelWatcher {
let kv_chooser = if self.router_mode == RouterMode::KV {
Some(
self.manager
.kv_chooser_for(
card.name(),
&component,
card.kv_cache_block_size,
self.kv_router_config,
)
.kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
.await?,
)
} else {
......@@ -361,6 +367,25 @@ impl ModelWatcher {
// This is expensive, we are loading ~10MiB JSON, so only do it once
let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
// Create prefill chooser once if we're building pipelines
// Both chat and completions will share the same prefill chooser instance
let prefill_chooser = self
.manager
.register_prefill_router(card.name().to_string())
.map(|rx| {
// Create prefill-specific config with track_active_blocks disabled
let mut prefill_config = self.kv_router_config.unwrap_or_default();
prefill_config.router_track_active_blocks = false;
PrefillRouter::new(
rx,
self.manager.clone(),
self.router_mode,
card.kv_cache_block_size,
Some(prefill_config),
)
});
// Add chat engine only if the model supports chat
if card.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::<
......@@ -373,6 +398,7 @@ impl ModelWatcher {
self.busy_threshold,
kv_chooser.clone(),
tokenizer_hf.clone(),
prefill_chooser.clone(),
)
.await
.context("build_routed_pipeline")?;
......@@ -403,6 +429,7 @@ impl ModelWatcher {
kv_chooser,
preprocessor,
tokenizer_hf,
prefill_chooser,
)
.await
.context("build_routed_pipeline_with_preprocessor")?;
......@@ -503,11 +530,28 @@ impl ModelWatcher {
);
}
// This is effectively a guardrail + passthrough for now
// TODO: Build proper prefill pipeline with KV router (track_active_blocks=false)
tracing::info!(
model_name = card.name(),
"Prefill model registered (passthrough, not yet functional)"
"Prefill model detected, registering and activating prefill router"
);
// Register prefill model for tracking (no engine needed, just lifecycle)
self.manager
.add_prefill_model(card.name(), checksum)
.context("add_prefill_model")?;
// Activate the prefill router with the endpoint for this prefill model
let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
tracing::warn!(
model_name = card.name(),
"Failed to activate prefill router - prefill model may already be activated"
);
return Ok(());
};
tracing::info!(
model_name = card.name(),
"Prefill model registered and router activated successfully"
);
} else {
// Reject unsupported combinations
......
......@@ -8,7 +8,7 @@ use crate::{
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter},
kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::{self, ModelDeploymentCard},
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
......@@ -122,7 +122,6 @@ pub async fn prepare_engine(
Some(
model_manager
.kv_chooser_for(
local_model.display_name(),
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
......@@ -144,6 +143,7 @@ pub async fn prepare_engine(
None,
kv_chooser.clone(),
hf_tokenizer,
None, // No prefill chooser in static mode
)
.await?;
......@@ -232,6 +232,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
busy_threshold: Option<f64>,
chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -254,10 +255,12 @@ where
chooser,
preprocessor,
hf_tokenizer,
prefill_chooser,
)
.await
}
#[allow(clippy::too_many_arguments)]
pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
......@@ -266,6 +269,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
chooser: Option<Arc<KvRouter>>,
preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -298,6 +302,7 @@ where
worker_monitor,
)
.await?;
let service_backend = match router_mode {
RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
ServiceBackend::from_engine(Arc::new(router))
......@@ -311,14 +316,22 @@ where
}
};
// Use the provided prefill chooser, or create a disabled one if not provided
let prefill_chooser = prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode));
let prefill_op = prefill_chooser.into_operator();
// Link with prefill chooser including backward edge for response flow
let engine = frontend
.link(preprocessor_op.forward_edge())?
.link(backend.forward_edge())?
.link(migration.forward_edge())?
.link(prefill_op.forward_edge())?
.link(service_backend)?
.link(prefill_op.backward_edge())?
.link(migration.backward_edge())?
.link(backend.backward_edge())?
.link(preprocessor_op.backward_edge())?
.link(frontend)?;
Ok(engine)
}
......@@ -81,7 +81,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(
manager
.kv_chooser_for(
local_model.display_name(),
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
......@@ -103,6 +102,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_chat_completions_model(
......@@ -111,12 +111,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in grpc static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
......
......@@ -121,7 +121,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(
manager
.kv_chooser_for(
local_model.display_name(),
&component,
card.kv_cache_block_size,
Some(local_model.router_config().kv_router_config),
......@@ -143,6 +142,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
None,
kv_chooser.clone(),
tokenizer_hf.clone(),
None, // No prefill chooser in http static mode
)
.await?;
manager.add_chat_completions_model(
......@@ -151,12 +151,19 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
let completions_engine = entrypoint::build_routed_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
card,
&client,
router_mode,
None,
kv_chooser,
tokenizer_hf,
None, // No prefill chooser in http static mode
)
.await?;
manager.add_completions_model(
local_model.display_name(),
checksum,
......
......@@ -22,6 +22,7 @@ use serde::{Deserialize, Serialize};
pub mod approx;
pub mod indexer;
pub mod prefill_router;
pub mod protocols;
pub mod publisher;
pub mod recorder;
......@@ -30,6 +31,8 @@ pub mod scoring;
pub mod sequence;
pub mod subscriber;
pub use prefill_router::PrefillRouter;
use crate::{
kv_router::{
approx::ApproxKvIndexer,
......@@ -247,7 +250,7 @@ impl KvRouter {
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
model_card::ROOT_PATH,
&format!("{}/{}", model_card::ROOT_PATH, component.path()),
key_extractors::lease_id,
|card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(),
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, OnceLock};
use anyhow::{Result, bail};
use futures::StreamExt;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
component::Endpoint,
pipeline::{
AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator,
PushRouter, RouterMode, ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};
/// The inner router used by PrefillRouter
enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>),
/// Simple routing (RoundRobin, Random, Direct)
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken,
router_mode: RouterMode,
}
impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(router_mode: RouterMode) -> Arc<Self> {
Arc::new(Self {
prefill_router: OnceLock::new(),
cancel_token: CancellationToken::new(),
router_mode,
})
}
pub fn new(
activation_rx: oneshot::Receiver<Endpoint>,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
let router = Arc::new(Self {
prefill_router,
cancel_token: cancel_token.clone(),
router_mode,
});
// Spawn background task to wait for activation
let router_clone = router.clone();
tokio::spawn(async move {
tokio::select! {
result = activation_rx => {
let Ok(endpoint) = result else {
tracing::debug!("Prefill router activation channel closed without receiving endpoint");
return;
};
if let Err(e) = router_clone.activate(
endpoint,
model_manager,
kv_cache_block_size,
kv_router_config,
).await {
tracing::error!(error = %e, "Failed to activate prefill router");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Prefill router activation cancelled");
}
}
});
router
}
/// Activate the prefill router with the provided endpoint
async fn activate(
&self,
endpoint: Endpoint,
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
"Activating prefill router"
);
let client = endpoint.client().await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the component from the endpoint
let kv_chooser = model_manager
.kv_chooser_for(endpoint.component(), kv_cache_block_size, kv_router_config)
.await?;
// Build the PushRouter for prefill with KV mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None, // busy_threshold
None, // worker_monitor
)
.await?;
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create simple push router with the frontend's router mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
self.router_mode,
None, // busy_threshold
None, // worker_monitor
)
.await?;
InnerPrefillRouter::SimpleRouter(Arc::new(push_router))
};
// Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router);
tracing::info!(
router_mode = ?self.router_mode,
"Prefill router activated successfully"
);
Ok(())
}
/// Call the prefill router and extract disaggregated_params
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<serde_json::Value> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
bail!("Prefill router not yet activated");
};
// Call the appropriate router based on the type
let mut prefill_response = match prefill_router {
InnerPrefillRouter::KvRouter(router) => router.generate(request).await?,
InnerPrefillRouter::SimpleRouter(router) => router.generate(request).await?,
};
let Some(first_output) = prefill_response.next().await else {
bail!("Prefill router returned no output (stream ended)");
};
if let Some(err) = first_output.err() {
while prefill_response.next().await.is_some() {}
bail!("Prefill router returned error in output: {:?}", err);
}
let Some(output) = &first_output.data else {
while prefill_response.next().await.is_some() {}
bail!("Prefill router output has no data field");
};
let Some(disaggregated_params) = output.disaggregated_params.clone() else {
while prefill_response.next().await.is_some() {}
bail!("Prefill router output missing disaggregated_params");
};
while prefill_response.next().await.is_some() {}
Ok(disaggregated_params)
}
}
impl Drop for PrefillRouter {
fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
self.cancel_token.cancel();
}
}
#[async_trait]
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
> for PrefillRouter
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context
let (req, context) = request.into_parts();
let request_id = context.id().to_string();
// Prepare prefill request with linked context for cancellation propagation
let prefill_req = req.clone();
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate
context.controller().link_child(prefill_context.context());
let prefill_request = prefill_context;
// Attempt prefill and handle results
match self.call_prefill(prefill_request).await {
Ok(disaggregated_params) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
// Update request with disaggregated_params and router config
let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params);
// Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride {
overlap_score_weight: Some(0.0),
..existing_override.unwrap_or_default()
});
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(e) => {
tracing::debug!(error = %e, "Remote prefill failed, falling back to decode-only");
next.generate(context.map(|_| req)).await
}
}
}
}
......@@ -306,8 +306,7 @@ pub async fn start_kv_router_background(
let key = String::from_utf8_lossy(kv.key());
// Extract the hex worker ID after the colon (e.g., "generate:694d99badb9f7c07" -> "694d99badb9f7c07")
let Some(worker_id_str) = key.split(':').next_back() else {
let Some(worker_id_str) = key.split(&['/', ':'][..]).next_back() else {
tracing::warn!("Could not extract worker ID from instance key: {key}");
continue;
};
......
......@@ -40,18 +40,22 @@ class DynamoWorkerProcess(ManagedProcess):
"3",
]
health_check_urls = [
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set port based on worker type
port = "8082" if is_prefill else "8081"
# Add prefill worker flag if needed
# Configure health check based on worker type
if is_prefill:
# Prefill workers check their own status endpoint
command.append("--is-prefill-worker")
health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
else:
# Decode workers should also check their own status endpoint first,
# then verify the frontend sees the model
health_check_urls = [
(f"http://localhost:{port}/health", self.is_ready),
(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
]
# Set debug logging environment
env = os.environ.copy()
......@@ -155,10 +159,10 @@ def test_request_cancellation_vllm_aggregated(
# Send the request (non-blocking)
cancellable_req = send_cancellable_request(request_type)
# Poll for "New Request ID" pattern
# Poll for "Decode Request ID" pattern (vLLM v2 pattern)
request_id, worker_log_offset = poll_for_pattern(
process=worker,
pattern="New Request ID: ",
pattern="Decode Request ID: ",
log_offset=worker_log_offset,
match_type="contains",
)
......@@ -223,17 +227,17 @@ def test_request_cancellation_vllm_decode_cancel(
# Send streaming request (non-blocking)
cancellable_req = send_cancellable_request("chat_completion_stream")
# Poll for "New Request ID" pattern in decode worker
# Poll for "Decode Request ID" pattern in decode worker (vLLM v2 pattern)
request_id, decode_log_offset = poll_for_pattern(
process=decode_worker,
pattern="New Request ID: ",
pattern="Decode Request ID: ",
match_type="contains",
)
# Verify same request ID reached prefill worker (as "New Prefill Request ID")
# Verify same request ID reached prefill worker (as "Prefill Request ID")
_, prefill_log_offset = poll_for_pattern(
process=prefill_worker,
pattern=f"New Prefill Request ID: {request_id}",
pattern=f"Prefill Request ID: {request_id}",
)
# Read 5 streaming responses (decode phase)
......@@ -288,9 +292,11 @@ def test_request_cancellation_vllm_remote_prefill_cancel(
with DynamoWorkerProcess(request, is_prefill=False) as decode_worker:
logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
# Step 4: Test request cancellation during remote prefill phase
# Step 4: Test request cancellation during prefill phase
# Note: With the new architecture, prefill routing happens in the frontend,
# so the request goes directly to the prefill worker first
logger.info(
"Testing completion request cancellation during remote prefill phase..."
"Testing completion request cancellation during prefill phase..."
)
# Send request with long prompt (non-blocking)
......@@ -298,37 +304,25 @@ def test_request_cancellation_vllm_remote_prefill_cancel(
"completion", use_long_prompt=True
)
# Poll for "New Request ID" pattern in decode worker
request_id, decode_log_offset = poll_for_pattern(
process=decode_worker,
pattern="New Request ID: ",
match_type="contains",
)
# Poll for same request ID in prefill worker (as "New Prefill Request ID")
_, prefill_log_offset = poll_for_pattern(
# Poll for "Prefill Request ID" pattern in prefill worker (vLLM v2 pattern)
# With new architecture, prefill is routed by frontend's internal router
request_id, prefill_log_offset = poll_for_pattern(
process=prefill_worker,
pattern=f"New Prefill Request ID: {request_id}",
pattern="Prefill Request ID: ",
match_type="contains",
)
# Cancel during prefill phase
cancellable_req.cancel()
logger.info(f"Cancelled request ID: {request_id} during remote prefill")
logger.info(f"Cancelled request ID: {request_id} during prefill")
# Poll for "Aborted Prefill Request ID" in prefill worker first (where cancellation happens)
# Poll for "Aborted Prefill Request ID" in prefill worker (where cancellation happens)
_, prefill_log_offset = poll_for_pattern(
process=prefill_worker,
pattern=f"Aborted Prefill Request ID: {request_id}",
log_offset=prefill_log_offset,
)
# Then poll for "Aborted Remote Prefill Request ID" in decode worker
_, decode_log_offset = poll_for_pattern(
process=decode_worker,
pattern=f"Aborted Remote Prefill Request ID: {request_id}",
log_offset=decode_log_offset,
)
# Verify frontend log has kill message
_, frontend_log_offset = poll_for_pattern(
process=frontend,
......
......@@ -173,13 +173,23 @@ def validate_openai_response(response: requests.Response) -> None:
def check_worker_received_request(worker_process: DynamoWorkerProcess) -> bool:
"""Check if the worker logs contain 'New Request ID:' message indicating it received a request"""
"""Check if the worker logs contain request ID message indicating it received a request.
Supports multiple backend patterns:
- vLLM: "Decode Request ID:" or "Prefill Request ID:"
- SGLang/TensorRT-LLM: "New Request ID:"
"""
log_path = worker_process._log_path
if log_path and os.path.exists(log_path):
try:
with open(log_path, "r") as f:
log_content = f.read()
return "New Request ID: " in log_content
# Check for any of the supported patterns
return (
"New Request ID: " in log_content
or "Decode Request ID: " in log_content
or "Prefill Request ID: " in log_content
)
except Exception as e:
logger.warning(f"Could not read worker log file {log_path}: {e}")
return False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment