Unverified commit e01c6e99, authored by Yan Ru Pei, committed by GitHub

feat: bake prefill router into frontend, supporting vllm for now (#3762)


Signed-off-by: PeaBrane <yanrpei@gmail.com>
parent cd5e7e33
...@@ -118,27 +118,14 @@ python -m dynamo.frontend --help ...@@ -118,27 +118,14 @@ python -m dynamo.frontend --help
For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md). For detailed explanations of router arguments (especially KV cache routing parameters), see the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md).
#### Launching a Standalone Router for Prefill Workers (Optional) #### Disaggregated Serving with Automatic Prefill Routing
If you're using disaggregated serving with separate prefill and decode workers, you should also launch a standalone router for prefill workers. This router handles routing prefill requests to dedicated prefill workers. When using a standalone prefill router, it's recommended to start the frontend (decode router) with `--kv-overlap-score-weight 0` for pure load balancing (as prefix-aware routing is now handled by the standalone router): When you launch prefill workers using `run_engines.sh --prefill`, the frontend automatically detects them and activates an internal prefill router. This prefill router:
- Automatically routes initial token processing to dedicated prefill workers
- Uses KV-aware routing regardless of the frontend's `--router-mode` setting
- Seamlessly integrates with your decode workers for token generation
```bash No additional configuration is needed - simply launch both decode and prefill workers, and the system handles the rest. See the [KV Cache Routing documentation](../../docs/architecture/kv_cache_routing.md#disaggregated-serving-prefill-and-decode) for more details.
# Start the decode router with pure load balancing
python -m dynamo.frontend \
--router-mode kv \
--router-reset-states \
--http-port 8000 \
--kv-overlap-score-weight 0
# In another terminal, start the standalone router for prefill workers
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size 64 \
--router-reset-states \
--no-track-active-blocks
```
The `--router-reset-states` flag clears any previous state, and `--no-track-active-blocks` disables active block tracking (suitable for prefill-only routing where decode load is not relevant).
**Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead: **Note**: If you're unsure whether your backend engines correctly emit KV events for certain models (e.g., hybrid models like gpt-oss or nemotron nano 2), use the `--no-kv-events` flag to disable KV event tracking and use approximate KV indexing instead:
......
...@@ -11,20 +11,13 @@ export PYTHONHASHSEED=0 ...@@ -11,20 +11,13 @@ export PYTHONHASHSEED=0
MODEL="Qwen/Qwen3-0.6B" MODEL="Qwen/Qwen3-0.6B"
BLOCK_SIZE=64 BLOCK_SIZE=64
# run decode router with kv-overlap-score-weight 0 for pure load balancing # Start frontend with KV routing
# The frontend will automatically detect prefill workers and activate an internal prefill router
python -m dynamo.frontend \ python -m dynamo.frontend \
--router-mode kv \ --router-mode kv \
--http-port 8000 \ --http-port 8000 \
--kv-overlap-score-weight 0 \
--router-reset-states & --router-reset-states &
# run standalone router service for prefill workers
python -m dynamo.router \
--endpoint dynamo.prefill.generate \
--block-size $BLOCK_SIZE \
--router-reset-states \
--no-track-active-blocks &
# two decode workers # two decode workers
# --enforce-eager is added for quick deployment. for production use, need to remove this flag # --enforce-eager is added for quick deployment. for production use, need to remove this flag
CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \ CUDA_VISIBLE_DEVICES=0 python3 -m dynamo.vllm \
...@@ -38,6 +31,8 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \ ...@@ -38,6 +31,8 @@ CUDA_VISIBLE_DEVICES=1 python3 -m dynamo.vllm \
--enforce-eager & --enforce-eager &
# two prefill workers # two prefill workers
# When registered with --is-prefill-worker, these workers are automatically detected
# by the frontend, which activates an internal prefill router for KV-aware prefill routing
CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \ CUDA_VISIBLE_DEVICES=2 python3 -m dynamo.vllm \
--model $MODEL \ --model $MODEL \
--block-size $BLOCK_SIZE \ --block-size $BLOCK_SIZE \
......
...@@ -40,7 +40,12 @@ The standalone router exposes two endpoints via the Dynamo runtime: ...@@ -40,7 +40,12 @@ The standalone router exposes two endpoints via the Dynamo runtime:
Clients query the `find_best_worker` endpoint to determine which worker should process each request, then call the selected worker directly. Clients query the `find_best_worker` endpoint to determine which worker should process each request, then call the selected worker directly.
## Example: Disaggregated Serving with Prefill Workers ## Example: Manual Disaggregated Serving (Alternative Setup)
> [!Note]
> **This is an alternative advanced setup.** The recommended approach for disaggregated serving is to use the frontend's automatic prefill routing, which activates when you register workers with `ModelType.Prefill`. See the [KV Cache Routing documentation](/docs/architecture/kv_cache_routing.md#disaggregated-serving-prefill-and-decode) for the default setup.
>
> Use this manual setup if you need explicit control over prefill routing configuration or want to manage prefill and decode routers separately.
See [`components/backends/vllm/launch/disagg_router.sh`](/components/backends/vllm/launch/disagg_router.sh) for a complete example. See [`components/backends/vllm/launch/disagg_router.sh`](/components/backends/vllm/launch/disagg_router.sh) for a complete example.
......
...@@ -4,13 +4,10 @@ ...@@ -4,13 +4,10 @@
import asyncio import asyncio
import logging import logging
import os import os
import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from copy import deepcopy
from typing import Any, AsyncGenerator, Dict from typing import Any, AsyncGenerator, Dict
import msgspec
from vllm.inputs import TokensPrompt from vllm.inputs import TokensPrompt
from vllm.sampling_params import SamplingParams from vllm.sampling_params import SamplingParams
from vllm.v1.engine.exceptions import EngineDeadError from vllm.v1.engine.exceptions import EngineDeadError
...@@ -23,6 +20,35 @@ configure_dynamo_logging() ...@@ -23,6 +20,35 @@ configure_dynamo_logging()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def build_sampling_params(
request: Dict[str, Any], default_sampling_params: Dict[str, Any]
) -> SamplingParams:
"""
Build SamplingParams from a PreprocessedRequest.
Args:
request: The PreprocessedRequest dict with 'sampling_options' and 'stop_conditions'
default_sampling_params: Default sampling parameters to initialize with
Returns:
SamplingParams configured from the request
"""
sampling_params = SamplingParams(**default_sampling_params)
sampling_params.detokenize = False
# Apply sampling_options
for key, value in request["sampling_options"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# Apply stop_conditions
for key, value in request["stop_conditions"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
return sampling_params
class BaseWorkerHandler(ABC): class BaseWorkerHandler(ABC):
""" """
Request handler for the generate and clear_kv_blocks endpoints. Request handler for the generate and clear_kv_blocks endpoints.
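The merge rule in `build_sampling_params` can be tried without vLLM installed. The following is a minimal sketch, not the project's code: `FakeSamplingParams` and `build_params` are hypothetical stand-ins for `SamplingParams` and `build_sampling_params`, and the field set is a small assumed subset. It shows defaults being applied first, then `sampling_options` and `stop_conditions` overriding them, with `None` values and unknown keys skipped.

```python
from dataclasses import dataclass
from typing import Any, Dict, List, Optional


@dataclass
class FakeSamplingParams:
    """Hypothetical stand-in for vllm.sampling_params.SamplingParams."""

    temperature: float = 1.0
    top_p: float = 1.0
    max_tokens: Optional[int] = 16
    min_tokens: int = 0
    stop: Optional[List[str]] = None
    detokenize: bool = True


def build_params(
    request: Dict[str, Any], defaults: Dict[str, Any]
) -> FakeSamplingParams:
    """Mirror the merge order of the diff: defaults, then request overrides."""
    params = FakeSamplingParams(**defaults)
    params.detokenize = False  # the worker returns token IDs, not text
    for section in ("sampling_options", "stop_conditions"):
        for key, value in request[section].items():
            # Skip unset values and keys the params object does not know.
            if value is not None and hasattr(params, key):
                setattr(params, key, value)
    return params


request = {
    "sampling_options": {"temperature": 0.2, "top_p": None, "unknown_knob": 5},
    "stop_conditions": {"max_tokens": 64, "stop": None},
}
params = build_params(request, {"temperature": 0.7, "top_p": 0.9})
```

Here `temperature` and `max_tokens` come from the request, `top_p` keeps its default because the request value is `None`, and `unknown_knob` is dropped silently. That last behaviour means a misspelled option name is ignored rather than rejected.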
...@@ -130,93 +156,31 @@ class DecodeWorkerHandler(BaseWorkerHandler): ...@@ -130,93 +156,31 @@ class DecodeWorkerHandler(BaseWorkerHandler):
component, component,
engine, engine,
default_sampling_params, default_sampling_params,
prefill_worker_client=None,
prefill_router_client=None,
): ):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
self.prefill_worker_client = prefill_worker_client
self.prefill_router_client = prefill_router_client
async def generate(self, request, context): async def generate(self, request, context):
request_id = str(uuid.uuid4().hex) # Use context ID for request tracking and correlation
logger.debug(f"New Request ID: {request_id}") request_id = context.id()
logger.debug(f"Decode Request ID: {request_id}")
prompt = TokensPrompt(prompt_token_ids=request["token_ids"]) prompt = TokensPrompt(prompt_token_ids=request["token_ids"])
sampling_params = SamplingParams(**self.default_sampling_params) # Build sampling params from request
sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params.detokenize = False
for key, value in request["sampling_options"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
for key, value in request["stop_conditions"].items():
if value is not None and hasattr(sampling_params, key):
setattr(sampling_params, key, value)
# Use prefill router or worker if available
can_prefill = (
self.prefill_worker_client is not None
) and self.prefill_worker_client.instance_ids()
if can_prefill:
# Create prefill sampling params with modifications
prefill_sampling_params = deepcopy(sampling_params)
if prefill_sampling_params.extra_args is None:
prefill_sampling_params.extra_args = {}
prefill_sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
prefill_sampling_params.max_tokens = 1
prefill_sampling_params.min_tokens = 1
try:
# Send request with sampling_params and request_id in extra_args
prefill_request = request.copy()
# TODO (PeaBrane): this smells a bit bad as not we have two nestings
# of extra_args (an inner one again in sampling_params)
prefill_request["extra_args"] = {
"sampling_params": msgspec.to_builtins(prefill_sampling_params),
"request_id": request_id,
}
# Try router first if available, fallback to worker
if (
self.prefill_router_client is not None
and self.prefill_router_client.instance_ids()
):
# Call router's generate endpoint which returns LLMEngineOutput
prefill_response = await anext(
await self.prefill_router_client.generate(
prefill_request, context=context
)
)
else:
# Fallback to direct worker with same format
prefill_response = await anext(
await self.prefill_worker_client.round_robin(
prefill_request, context=context
)
)
prefill_output = prefill_response.data()
# Extract kv_transfer_params from response # Extract disaggregated_params from request (set by prefill router in Rust frontend)
kv_transfer_params = prefill_output.get("extra_args", {}).get( disaggregated_params = request.get("disaggregated_params")
"kv_transfer_params" if disaggregated_params:
) # Prefill was performed - use the disaggregated params
if kv_transfer_params:
if sampling_params.extra_args is None: if sampling_params.extra_args is None:
sampling_params.extra_args = {} sampling_params.extra_args = {}
sampling_params.extra_args[ sampling_params.extra_args["kv_transfer_params"] = disaggregated_params.get(
"kv_transfer_params" "kv_transfer_params"
] = kv_transfer_params )
logger.debug(
except Exception as e: f"Using disaggregated params from prefill for request {request_id}"
if context.is_stopped() or context.is_killed(): )
logger.debug(f"Aborted Remote Prefill Request ID: {request_id}")
return
logger.warning(f"Prefill error: {e}, falling back to local prefill")
dp_rank = request.get("dp_rank", None) dp_rank = request.get("dp_rank", None)
...@@ -238,17 +202,25 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -238,17 +202,25 @@ class PrefillWorkerHandler(BaseWorkerHandler):
super().__init__(runtime, component, engine, default_sampling_params) super().__init__(runtime, component, engine, default_sampling_params)
async def generate(self, request, context): async def generate(self, request, context):
# Extract from PreprocessedRequest format - request_id and sampling_params from extra_args # Use context ID for request tracking and correlation with decode phase
extra_args = request.get("extra_args", {}) request_id = context.id()
request_id = extra_args.get("request_id", str(uuid.uuid4().hex)) logger.debug(f"Prefill Request ID: {request_id}")
logger.debug(f"New Prefill Request ID: {request_id}")
token_ids = request["token_ids"] token_ids = request["token_ids"]
prompt = TokensPrompt(prompt_token_ids=token_ids) prompt = TokensPrompt(prompt_token_ids=token_ids)
# Get sampling_params from extra_args # Build sampling params from request using shared utility
sampling_params_dict = extra_args.get("sampling_params", {}) sampling_params = build_sampling_params(request, self.default_sampling_params)
sampling_params = msgspec.convert(sampling_params_dict, SamplingParams)
# Configure for prefill-only mode with remote decode
if sampling_params.extra_args is None:
sampling_params.extra_args = {}
sampling_params.extra_args["kv_transfer_params"] = {
"do_remote_decode": True,
}
# Override for prefill: only generate 1 token
sampling_params.max_tokens = 1
sampling_params.min_tokens = 1
dp_rank = request.get("dp_rank", None) dp_rank = request.get("dp_rank", None)
...@@ -271,10 +243,10 @@ class PrefillWorkerHandler(BaseWorkerHandler): ...@@ -271,10 +243,10 @@ class PrefillWorkerHandler(BaseWorkerHandler):
output: Dict[str, Any] = { output: Dict[str, Any] = {
"token_ids": list(token_ids), "token_ids": list(token_ids),
"extra_args": ( "disaggregated_params": (
{"kv_transfer_params": res.kv_transfer_params} {"kv_transfer_params": res.kv_transfer_params}
if res.kv_transfer_params if res.kv_transfer_params
else {} else None
), ),
} }
......
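The handler changes above move the prefill-to-decode handoff into the request payload: the prefill worker returns `disaggregated_params`, and the decode worker reads the same key from its incoming request. A minimal sketch of that data flow with plain dicts follows. The function names are hypothetical, `attach_to_decode_request` only approximates what the Rust frontend's prefill router is described as doing, and the contents of `kv_transfer_params` are a made-up placeholder.

```python
from typing import Any, Dict, List, Optional


def prefill_output(
    token_ids: List[int], kv_transfer_params: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
    """Shape of the prefill worker's output after this commit."""
    return {
        "token_ids": list(token_ids),
        "disaggregated_params": (
            {"kv_transfer_params": kv_transfer_params}
            if kv_transfer_params
            else None
        ),
    }


def attach_to_decode_request(
    request: Dict[str, Any], prefill_out: Dict[str, Any]
) -> Dict[str, Any]:
    """Hypothetical stand-in for the frontend's prefill router step."""
    decode_request = dict(request)
    if prefill_out.get("disaggregated_params"):
        decode_request["disaggregated_params"] = prefill_out["disaggregated_params"]
    return decode_request


def decode_extra_args(request: Dict[str, Any]) -> Dict[str, Any]:
    """What the decode handler adds to sampling_params.extra_args."""
    extra_args: Dict[str, Any] = {}
    disaggregated_params = request.get("disaggregated_params")
    if disaggregated_params:
        extra_args["kv_transfer_params"] = disaggregated_params.get(
            "kv_transfer_params"
        )
    return extra_args


base_request = {"token_ids": [1, 2, 3]}
# Placeholder payload; real kv_transfer_params come from the vLLM engine.
transfer = {"placeholder_key": "placeholder_value"}

with_prefill = decode_extra_args(
    attach_to_decode_request(base_request, prefill_output([1, 2, 3], transfer))
)
without_prefill = decode_extra_args(
    attach_to_decode_request(base_request, prefill_output([1, 2, 3], None))
)
```

When no transfer parameters are available, the decode request carries no `disaggregated_params` and the decode worker runs prefill locally, which matches the fallback behaviour described in the routing documentation.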
...@@ -85,12 +85,12 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -85,12 +85,12 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
""" """
vLLM-specific health check payload for prefill workers in disaggregated mode. vLLM-specific health check payload for prefill workers in disaggregated mode.
The prefill handler expects a different structure with 'request_id' and 'sampling_params'. The prefill handler expects PreprocessedRequest format with sampling_options and stop_conditions.
""" """
def __init__(self, engine_client=None): def __init__(self, engine_client=None):
""" """
Initialize vLLM prefill health check payload with proper structure. Initialize vLLM prefill health check payload with proper PreprocessedRequest structure.
Args: Args:
engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from. engine_client: Optional vLLM AsyncLLM engine client to extract BOS token from.
...@@ -98,25 +98,21 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload): ...@@ -98,25 +98,21 @@ class VllmPrefillHealthCheckPayload(HealthCheckPayload):
""" """
bos_token_id = _get_bos_token_id_from_engine(engine_client) bos_token_id = _get_bos_token_id_from_engine(engine_client)
# Prefill handler expects request_id, token_ids, and sampling_params # Prefill handler expects PreprocessedRequest format: token_ids, sampling_options, stop_conditions
# The sampling_params are converted via msgspec in the handler # The handler will override max_tokens/min_tokens to 1 and add do_remote_decode
self.default_payload = { self.default_payload = {
"request_id": "health_check",
"token_ids": [bos_token_id], "token_ids": [bos_token_id],
"sampling_params": { "sampling_options": {
"max_tokens": 1,
"min_tokens": 1,
"temperature": 0.0, "temperature": 0.0,
"top_p": 1.0, "top_p": 1.0,
"top_k": -1, "top_k": -1,
"detokenize": False, },
"stop_conditions": {
"stop": None,
"stop_token_ids": None,
"include_stop_str_in_output": False, "include_stop_str_in_output": False,
"ignore_eos": False, "ignore_eos": False,
"extra_args": { "min_tokens": 0,
"kv_transfer_params": {
"do_remote_decode": True,
}
},
}, },
} }
super().__init__() super().__init__()
...@@ -191,6 +191,60 @@ def setup_vllm_engine(config, stat_logger=None): ...@@ -191,6 +191,60 @@ def setup_vllm_engine(config, stat_logger=None):
return engine_client, vllm_config, default_sampling_params return engine_client, vllm_config, default_sampling_params
async def register_vllm_model(
model_input: ModelInput,
model_type: ModelType,
generate_endpoint,
config: Config,
engine_client: AsyncLLM,
vllm_config,
migration_limit: int,
):
"""
Helper function to register a vLLM model with runtime configuration.
Args:
model_input: Input type for the model (e.g., ModelInput.Tokens)
model_type: Type of model (e.g., ModelType.Chat, ModelType.Prefill)
generate_endpoint: Endpoint to register
config: Configuration object
engine_client: vLLM engine client
vllm_config: vLLM configuration
migration_limit: Migration limit for the model
"""
runtime_config = ModelRuntimeConfig()
# Get runtime configuration from vLLM engine
logging.info(
f"Getting engine runtime configuration metadata from vLLM engine for {model_type}..."
)
runtime_values = get_engine_cache_info(engine_client)
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
# Add tool/reasoning parsers for decode models
if model_type != ModelType.Prefill:
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(vllm_config.parallel_config, "data_parallel_size", 1)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
model_input,
model_type,
generate_endpoint,
config.model,
config.served_model_name,
kv_cache_block_size=config.engine_args.block_size,
migration_limit=migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
)
async def init_prefill(runtime: DistributedRuntime, config: Config): async def init_prefill(runtime: DistributedRuntime, config: Config):
""" """
Instantiate and serve Instantiate and serve
...@@ -214,6 +268,18 @@ async def init_prefill(runtime: DistributedRuntime, config: Config): ...@@ -214,6 +268,18 @@ async def init_prefill(runtime: DistributedRuntime, config: Config):
if kv_publishers: if kv_publishers:
handler.kv_publishers = kv_publishers handler.kv_publishers = kv_publishers
# Register prefill model with ModelType.Prefill
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
await register_vllm_model(
ModelInput.Tokens,
ModelType.Prefill,
generate_endpoint,
config,
engine_client,
vllm_config,
migration_limit=0, # Prefill doesn't support migration
)
health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmPrefillHealthCheckPayload(engine_client).to_dict()
try: try:
...@@ -255,20 +321,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -255,20 +321,6 @@ async def init(runtime: DistributedRuntime, config: Config):
generate_endpoint = component.endpoint(config.endpoint) generate_endpoint = component.endpoint(config.endpoint)
clear_endpoint = component.endpoint("clear_kv_blocks") clear_endpoint = component.endpoint("clear_kv_blocks")
prefill_router_client = (
await runtime.namespace(config.namespace)
.component("router") # Standalone router for prefill workers
.endpoint("generate")
.client()
)
prefill_worker_client = (
await runtime.namespace(config.namespace)
.component("prefill") # TODO don't hardcode
.endpoint("generate")
.client()
)
factory = StatLoggerFactory( factory = StatLoggerFactory(
component, component,
config.engine_args.data_parallel_rank or 0, config.engine_args.data_parallel_rank or 0,
...@@ -288,8 +340,6 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -288,8 +340,6 @@ async def init(runtime: DistributedRuntime, config: Config):
component, component,
engine_client, engine_client,
default_sampling_params, default_sampling_params,
prefill_worker_client,
prefill_router_client,
) )
# Set up KV event publishers for prefix caching if enabled (one per dp_rank) # Set up KV event publishers for prefix caching if enabled (one per dp_rank)
...@@ -305,35 +355,14 @@ async def init(runtime: DistributedRuntime, config: Config): ...@@ -305,35 +355,14 @@ async def init(runtime: DistributedRuntime, config: Config):
register_engine_metrics_callback(generate_endpoint, REGISTRY, "vllm:", "vLLM") register_engine_metrics_callback(generate_endpoint, REGISTRY, "vllm:", "vLLM")
if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register if not config.engine_args.data_parallel_rank: # if rank is 0 or None then register
runtime_config = ModelRuntimeConfig() await register_vllm_model(
# make a `collective_rpc` call to get runtime configuration values
logging.info(
"Getting engine runtime configuration metadata from vLLM engine..."
)
runtime_values = get_engine_cache_info(engine_client)
runtime_config.total_kv_blocks = runtime_values["num_gpu_blocks"]
runtime_config.max_num_seqs = runtime_values["max_num_seqs"]
runtime_config.max_num_batched_tokens = runtime_values["max_num_batched_tokens"]
runtime_config.tool_call_parser = config.tool_call_parser
runtime_config.reasoning_parser = config.reasoning_parser
# Get data_parallel_size from vllm_config (defaults to 1)
data_parallel_size = getattr(
vllm_config.parallel_config, "data_parallel_size", 1
)
runtime_config.data_parallel_size = data_parallel_size
await register_llm(
ModelInput.Tokens, ModelInput.Tokens,
ModelType.Chat | ModelType.Completions, ModelType.Chat | ModelType.Completions,
generate_endpoint, generate_endpoint,
config.model, config,
config.served_model_name, engine_client,
kv_cache_block_size=config.engine_args.block_size, vllm_config,
migration_limit=config.migration_limit, migration_limit=config.migration_limit,
runtime_config=runtime_config,
custom_template_path=config.custom_jinja_template,
) )
health_check_payload = VllmHealthCheckPayload(engine_client).to_dict() health_check_payload = VllmHealthCheckPayload(engine_client).to_dict()
......
...@@ -54,6 +54,51 @@ The main KV-aware routing arguments: ...@@ -54,6 +54,51 @@ The main KV-aware routing arguments:
For basic model registration without KV routing, you can use `--router-mode round-robin` or `--router-mode random` with both static and dynamic endpoints. For basic model registration without KV routing, you can use `--router-mode round-robin` or `--router-mode random` with both static and dynamic endpoints.
## Disaggregated Serving (Prefill and Decode)
Dynamo supports disaggregated serving where prefill (prompt processing) and decode (token generation) are handled by separate worker pools. When you register workers with `ModelType.Prefill` (see [Backend Guide](../development/backend-guide.md#model-types)), the frontend automatically detects them and activates an internal prefill router.
### Automatic Prefill Router Activation
The prefill router is automatically created when:
1. A decode model is registered (e.g., via `register_llm()` with `ModelType.Chat | ModelType.Completions`)
2. A prefill worker is detected with the same model name and `ModelType.Prefill`
**Key characteristics of the prefill router:**
- **Always uses KV-aware routing** regardless of the frontend's `--router-mode` setting
- **Always disables active block tracking** (`track_active_blocks=false`) since prefill workers don't perform decode
- **Seamlessly integrated** into the request pipeline between preprocessing and decode routing
- **Falls back gracefully** to decode-only mode if prefill fails or no prefill workers are available
### Setup Example
```python
# Decode worker registration (in your decode worker)
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Chat | ModelType.Completions,
endpoint=generate_endpoint,
model_name="meta-llama/Llama-2-7b-hf",
# ... other parameters
)
# Prefill worker registration (in your prefill worker)
await register_llm(
model_input=ModelInput.Tokens,
model_type=ModelType.Prefill, # <-- Mark as prefill worker
endpoint=generate_endpoint,
model_name="meta-llama/Llama-2-7b-hf", # Must match decode model name
# ... other parameters
)
```
When both workers are registered, requests are automatically routed:
1. **Prefill phase** → Prefill router selects best prefill worker (KV-aware)
2. **Decode phase** → Decode router selects decode worker (uses frontend's `--router-mode`)
> [!Note]
> **WIP**: Currently, the prefill router always uses KV routing. Future updates will provide more fine-grained control over prefill routing behavior to match user-specified frontend router modes.
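The activation rule above holds whichever side registers first, so the frontend needs a small rendezvous between the decode model and the prefill endpoint. The sketch below illustrates that idea in Python, with `asyncio` futures standing in for the Rust one-shot channel. The class and method names are hypothetical, and `activate_prefill` is an assumed counterpart to the decode-side registration, not a documented API.

```python
import asyncio
from typing import Dict, Optional, Tuple

DECODE_WAITING = "decode_waiting"  # decode registered, prefill not yet seen
PREFILL_READY = "prefill_ready"    # prefill seen, decode not yet registered


class PrefillRendezvous:
    """Pair a decode model with a prefill endpoint, in either arrival order."""

    def __init__(self) -> None:
        self._state: Dict[str, Tuple[str, asyncio.Future]] = {}

    def register_decode(self, model_name: str) -> Optional[asyncio.Future]:
        """Return a future for the prefill endpoint, or None on a duplicate."""
        entry = self._state.pop(model_name, None)
        if entry is None:
            fut = asyncio.get_running_loop().create_future()
            self._state[model_name] = (DECODE_WAITING, fut)
            return fut
        kind, fut = entry
        if kind == PREFILL_READY:
            return fut  # endpoint already arrived; future is resolved
        self._state[model_name] = entry  # decode was already registered
        return None

    def activate_prefill(self, model_name: str, endpoint: str) -> None:
        """Deliver the prefill endpoint, or park it until decode registers."""
        entry = self._state.pop(model_name, None)
        if entry is None:
            fut = asyncio.get_running_loop().create_future()
            fut.set_result(endpoint)
            self._state[model_name] = (PREFILL_READY, fut)
        elif entry[0] == DECODE_WAITING:
            entry[1].set_result(endpoint)
        else:
            self._state[model_name] = entry  # keep the first endpoint


async def demo() -> Tuple[str, Optional[asyncio.Future], str]:
    rendezvous = PrefillRendezvous()
    # Order 1: decode registers first, prefill arrives later.
    waiter = rendezvous.register_decode("model-a")
    duplicate = rendezvous.register_decode("model-a")
    rendezvous.activate_prefill("model-a", "prefill-endpoint-a")
    first = await waiter
    # Order 2: prefill arrives first, decode registers later.
    rendezvous.activate_prefill("model-b", "prefill-endpoint-b")
    second = await rendezvous.register_decode("model-b")
    return first, duplicate, second


first, duplicate, second = asyncio.run(demo())
```

Both orderings resolve to the same endpoint, and a second decode registration for the same model name gets `None` rather than a second waiter. Keying the state by model name is what enforces the "same model name" condition in the list above.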
## Overview ## Overview
The KV-aware router operates on two key principles to optimize request routing: The KV-aware router operates on two key principles to optimize request routing:
......
...@@ -929,6 +929,7 @@ pub async fn build_worker_selection_pipeline_chat( ...@@ -929,6 +929,7 @@ pub async fn build_worker_selection_pipeline_chat(
busy_threshold, busy_threshold,
chooser, chooser,
hf_tokenizer, hf_tokenizer,
None,
) )
.await?; .await?;
...@@ -991,12 +992,7 @@ pub async fn create_worker_selection_pipeline_chat( ...@@ -991,12 +992,7 @@ pub async fn create_worker_selection_pipeline_chat(
let model_manager = std::sync::Arc::new(ModelManager::new()); let model_manager = std::sync::Arc::new(ModelManager::new());
Some( Some(
model_manager model_manager
.kv_chooser_for( .kv_chooser_for(&component, card.kv_cache_block_size, kv_router_config)
&card.display_name,
&component,
card.kv_cache_block_size,
kv_router_config,
)
.await?, .await?,
) )
} else { } else {
......
...@@ -779,12 +779,7 @@ async fn create_kv_router_from_endpoint( ...@@ -779,12 +779,7 @@ async fn create_kv_router_from_endpoint(
// Create ModelManager and use it to create KvRouter (ensures etcd registration) // Create ModelManager and use it to create KvRouter (ensures etcd registration)
let model_manager = Arc::new(llm_rs::discovery::ModelManager::new()); let model_manager = Arc::new(llm_rs::discovery::ModelManager::new());
let kv_router = model_manager let kv_router = model_manager
.kv_chooser_for( .kv_chooser_for(component, block_size as u32, kv_router_config)
"dummy_name", // does not matter, never cached
component,
block_size as u32,
kv_router_config,
)
.await .await
.map_err(to_pyerr)?; .map_err(to_pyerr)?;
......
...@@ -7,23 +7,33 @@ use std::{ ...@@ -7,23 +7,33 @@ use std::{
}; };
use parking_lot::{Mutex, RwLock}; use parking_lot::{Mutex, RwLock};
use tokio::sync::oneshot;
use dynamo_runtime::component::Component; use dynamo_runtime::component::{Component, Endpoint};
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{ use crate::{
kv_router::KvRouter, discovery::KV_ROUTERS_ROOT_PATH,
types::generic::tensor::TensorStreamingEngine, kv_router::{KvRouter, KvRouterConfig, scheduler::DefaultWorkerSelector},
types::openai::{ model_card::ModelDeploymentCard,
model_type::ModelType,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine,
},
}, },
}; };
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector}, /// State for prefill router activation rendezvous
model_type::ModelType, enum PrefillActivationState {
}; /// Decode model registered, waiting for prefill endpoint
DecodeWaiting(oneshot::Sender<Endpoint>),
/// Prefill endpoint arrived, waiting for decode model to register
PrefillReady(oneshot::Receiver<Endpoint>),
}
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ModelManagerError { pub enum ModelManagerError {
...@@ -41,11 +51,13 @@ pub struct ModelManager { ...@@ -41,11 +51,13 @@ pub struct ModelManager {
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>, chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
prefill_engines: RwLock<ModelEngines<TensorStreamingEngine>>, // Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
// These are Mutex because we read and write rarely and equally // These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, // Key: component service_name
prefill_router_activators: Mutex<HashMap<String, PrefillActivationState>>,
} }
impl Default for ModelManager { impl Default for ModelManager {
...@@ -64,6 +76,7 @@ impl ModelManager { ...@@ -64,6 +76,7 @@ impl ModelManager {
prefill_engines: RwLock::new(ModelEngines::default()), prefill_engines: RwLock::new(ModelEngines::default()),
cards: Mutex::new(HashMap::new()), cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
prefill_router_activators: Mutex::new(HashMap::new()),
} }
} }
...@@ -188,10 +201,9 @@ impl ModelManager { ...@@ -188,10 +201,9 @@ impl ModelManager {
&self, &self,
model: &str, model: &str,
card_checksum: &str, card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write(); let mut clients = self.prefill_engines.write();
clients.add(model, card_checksum, engine) clients.add(model, card_checksum, ())
} }
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
...@@ -263,17 +275,6 @@ impl ModelManager { ...@@ -263,17 +275,6 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_prefill_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.prefill_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted. /// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
@@ -288,19 +289,20 @@ impl ModelManager {
 pub async fn kv_chooser_for(
 &self,
-model_name: &str,
 component: &Component,
 kv_cache_block_size: u32,
 kv_router_config: Option<KvRouterConfig>,
 ) -> anyhow::Result<Arc<KvRouter>> {
-if let Some(kv_chooser) = self.get_kv_chooser(model_name) {
+let service_name = component.service_name();
+if let Some(kv_chooser) = self.get_kv_chooser(&service_name) {
 // Check if the existing router has a different block size
 if kv_chooser.block_size() != kv_cache_block_size {
 tracing::warn!(
-model_name = %model_name,
+component = %service_name,
 existing_block_size = %kv_chooser.block_size(),
 requested_block_size = %kv_cache_block_size,
-"KV Router block size mismatch! Model is requesting a different kv_cache_block_size than the existing router. \
+"KV Router block size mismatch! Component is requesting a different kv_cache_block_size than the existing router. \
 This will cause routing to fail silently. Consider using the same block size or restarting the router."
 );
 }
@@ -339,12 +341,109 @@ impl ModelManager {
 let new_kv_chooser = Arc::new(chooser);
 self.kv_choosers
 .lock()
-.insert(model_name.to_string(), new_kv_chooser.clone());
+.insert(service_name, new_kv_chooser.clone());
 Ok(new_kv_chooser)
 }
-fn get_kv_chooser(&self, model_name: &str) -> Option<Arc<KvRouter>> {
-self.kv_choosers.lock().get(model_name).cloned()
+fn get_kv_chooser(&self, service_name: &str) -> Option<Arc<KvRouter>> {
+self.kv_choosers.lock().get(service_name).cloned()
+}
+/// Register a prefill router for a decode model. Returns a receiver that will be
+/// activated when the corresponding prefill model is discovered.
+/// Returns None if the decode model was already registered.
+pub fn register_prefill_router(
+&self,
+model_name: String,
+) -> Option<oneshot::Receiver<Endpoint>> {
+let mut activators = self.prefill_router_activators.lock();
+match activators.remove(&model_name) {
+Some(PrefillActivationState::PrefillReady(rx)) => {
+// Prefill endpoint already arrived - rx will immediately resolve
+tracing::debug!(
+model_name = %model_name,
+"Prefill endpoint already available, returning receiver with endpoint"
+);
+Some(rx)
+}
+Some(PrefillActivationState::DecodeWaiting(tx)) => {
+// Decode already registered - this shouldn't happen, restore state and return None
+tracing::error!(
+model_name = %model_name,
+"Decode model already registered for this prefill router"
+);
+activators.insert(model_name, PrefillActivationState::DecodeWaiting(tx));
+None
+}
+None => {
+// New registration: create tx/rx pair, store sender and return receiver
+let (tx, rx) = oneshot::channel();
+activators.insert(
+model_name.clone(),
+PrefillActivationState::DecodeWaiting(tx),
+);
+tracing::debug!(
+model_name = %model_name,
+"No prefill endpoint available yet, storing sender for future activation"
+);
+Some(rx)
+}
+}
+}
+/// Activate a prefill router by sending the endpoint through the oneshot channel.
+/// If no decode model has registered yet, stores the endpoint for future retrieval.
+pub fn activate_prefill_router(
+&self,
+model_name: &str,
+endpoint: Endpoint,
+) -> anyhow::Result<()> {
+let mut activators = self.prefill_router_activators.lock();
+match activators.remove(model_name) {
+Some(PrefillActivationState::DecodeWaiting(sender)) => {
+// Decode model already registered
+sender.send(endpoint).map_err(|_| {
+anyhow::anyhow!(
+"Failed to send endpoint to prefill router activator for model: {}",
+model_name
+)
+})?;
+tracing::info!(
+model_name = %model_name,
+"Activated prefill router for already-registered decode model"
+);
+Ok(())
+}
+Some(PrefillActivationState::PrefillReady(_)) => {
+// Prefill already activated - this shouldn't happen
+anyhow::bail!("Prefill router for model {} already activated", model_name);
+}
+None => {
+// Decode model not registered yet - create pair and immediately send endpoint
+let (tx, rx) = oneshot::channel();
+tx.send(endpoint).map_err(|_| {
+anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
+})?;
+// Store the receiver for when decode model registers
+activators.insert(
+model_name.to_string(),
+PrefillActivationState::PrefillReady(rx),
+);
+tracing::info!(
+model_name = %model_name,
+"Stored prefill endpoint for future decode model registration"
+);
+Ok(())
+}
+}
 }
 pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
......
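The `register_prefill_router` / `activate_prefill_router` pair in the `ModelManager` hunks above behaves like an order-independent rendezvous: whichever side arrives first parks one half of a channel in the map, and the other side picks it up. The following is a minimal, std-only sketch of that handshake, not the crate's code. `Activators`, the `String` endpoint, and `std::sync::mpsc` (standing in for `tokio::sync::oneshot`) are all assumptions made so the example runs on its own.

```rust
use std::collections::HashMap;
use std::sync::mpsc::{self, Receiver, Sender};
use std::sync::Mutex;

/// Stand-in for the runtime's `Endpoint` type (assumption: a plain string).
type Endpoint = String;

/// Mirrors the two states named in the diff.
enum PrefillActivationState {
    /// Decode registered first; the sender is parked until prefill shows up.
    DecodeWaiting(Sender<Endpoint>),
    /// Prefill showed up first; the endpoint is already queued in the receiver.
    PrefillReady(Receiver<Endpoint>),
}

/// Hypothetical container for the per-model activation states.
#[derive(Default)]
struct Activators {
    inner: Mutex<HashMap<String, PrefillActivationState>>,
}

impl Activators {
    /// Decode side: returns a receiver, or None on a duplicate registration.
    fn register(&self, model: &str) -> Option<Receiver<Endpoint>> {
        let mut map = self.inner.lock().unwrap();
        match map.remove(model) {
            Some(PrefillActivationState::PrefillReady(rx)) => Some(rx),
            Some(PrefillActivationState::DecodeWaiting(tx)) => {
                // Restore the parked sender and refuse the second registration.
                map.insert(model.to_string(), PrefillActivationState::DecodeWaiting(tx));
                None
            }
            None => {
                let (tx, rx) = mpsc::channel();
                map.insert(model.to_string(), PrefillActivationState::DecodeWaiting(tx));
                Some(rx)
            }
        }
    }

    /// Prefill side: deliver the endpoint, or park it until decode registers.
    fn activate(&self, model: &str, endpoint: Endpoint) -> Result<(), String> {
        let mut map = self.inner.lock().unwrap();
        match map.remove(model) {
            Some(PrefillActivationState::DecodeWaiting(tx)) => tx
                .send(endpoint)
                .map_err(|_| format!("receiver dropped for model {model}")),
            Some(PrefillActivationState::PrefillReady(_)) => {
                Err(format!("prefill router for model {model} already activated"))
            }
            None => {
                let (tx, rx) = mpsc::channel();
                tx.send(endpoint)
                    .map_err(|_| format!("send failed for model {model}"))?;
                map.insert(model.to_string(), PrefillActivationState::PrefillReady(rx));
                Ok(())
            }
        }
    }
}

fn main() {
    // Prefill first: the endpoint is parked, then handed over on registration.
    let activators = Activators::default();
    activators.activate("m", "ns/prefill/generate".to_string()).unwrap();
    let rx = activators.register("m").expect("first registration");
    println!("prefill-first endpoint: {}", rx.try_recv().unwrap());
}
```

Either arrival order ends with the decode side holding a receiver that yields the endpoint, which is why the watcher can call the two functions without coordinating worker start-up order.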
@@ -20,7 +20,7 @@ use dynamo_runtime::{
 use crate::{
 backend::Backend,
 entrypoint,
-kv_router::KvRouterConfig,
+kv_router::{KvRouterConfig, PrefillRouter},
 model_card::{self, ModelDeploymentCard},
 model_type::{ModelInput, ModelType},
 preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
@@ -318,15 +318,26 @@ impl ModelWatcher {
 .drt
 .namespace(&endpoint_id.namespace)?
 .component(&endpoint_id.component)?;
-let client = component.endpoint(&endpoint_id.name).client().await?;
+let endpoint = component.endpoint(&endpoint_id.name);
+let client = endpoint.client().await?;
 tracing::debug!(model_name = card.name(), "adding model");
 self.manager.save_model_card(key, card.clone())?;
-if self.manager.has_model_any(card.name()) {
+// Check if we should skip registration:
+// - Skip if a model with this name already exists
+// - UNLESS this is a prefill model and no prefill model exists yet for this name
+let is_new_prefill = card.model_type.supports_prefill()
+&& !self
+.manager
+.list_prefill_models()
+.contains(&card.name().to_string());
+if self.manager.has_model_any(card.name()) && !is_new_prefill {
 tracing::debug!(
 model_name = card.name(),
 namespace = endpoint_id.namespace,
-"New endpoint for existing model"
+model_type = %card.model_type,
+"New endpoint for existing model, skipping"
 );
 return Ok(());
 }
@@ -346,12 +357,7 @@ impl ModelWatcher {
 let kv_chooser = if self.router_mode == RouterMode::KV {
 Some(
 self.manager
-.kv_chooser_for(
-card.name(),
-&component,
-card.kv_cache_block_size,
-self.kv_router_config,
-)
+.kv_chooser_for(&component, card.kv_cache_block_size, self.kv_router_config)
 .await?,
 )
 } else {
@@ -361,6 +367,25 @@ impl ModelWatcher {
 // This is expensive, we are loading ~10MiB JSON, so only do it once
 let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
+// Create prefill chooser once if we're building pipelines
+// Both chat and completions will share the same prefill chooser instance
+let prefill_chooser = self
+.manager
+.register_prefill_router(card.name().to_string())
+.map(|rx| {
+// Create prefill-specific config with track_active_blocks disabled
+let mut prefill_config = self.kv_router_config.unwrap_or_default();
+prefill_config.router_track_active_blocks = false;
+PrefillRouter::new(
+rx,
+self.manager.clone(),
+self.router_mode,
+card.kv_cache_block_size,
+Some(prefill_config),
+)
+});
 // Add chat engine only if the model supports chat
 if card.model_type.supports_chat() {
 let chat_engine = entrypoint::build_routed_pipeline::<
@@ -373,6 +398,7 @@ impl ModelWatcher {
 self.busy_threshold,
 kv_chooser.clone(),
 tokenizer_hf.clone(),
+prefill_chooser.clone(),
 )
 .await
 .context("build_routed_pipeline")?;
@@ -403,6 +429,7 @@ impl ModelWatcher {
 kv_chooser,
 preprocessor,
 tokenizer_hf,
+prefill_chooser,
 )
 .await
 .context("build_routed_pipeline_with_preprocessor")?;
@@ -503,11 +530,28 @@ impl ModelWatcher {
 );
 }
-// This is effectively a guardrail + passthrough for now
-// TODO: Build proper prefill pipeline with KV router (track_active_blocks=false)
 tracing::info!(
 model_name = card.name(),
-"Prefill model registered (passthrough, not yet functional)"
+"Prefill model detected, registering and activating prefill router"
+);
+// Register prefill model for tracking (no engine needed, just lifecycle)
+self.manager
+.add_prefill_model(card.name(), checksum)
+.context("add_prefill_model")?;
+// Activate the prefill router with the endpoint for this prefill model
+let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
+tracing::warn!(
+model_name = card.name(),
+"Failed to activate prefill router - prefill model may already be activated"
+);
+return Ok(());
+};
+tracing::info!(
+model_name = card.name(),
+"Prefill model registered and router activated successfully"
 );
 } else {
 // Reject unsupported combinations
......
@@ -8,7 +8,7 @@ use crate::{
 discovery::{ModelManager, ModelWatcher},
 engines::StreamingEngineAdapter,
 entrypoint::{self, EngineConfig},
-kv_router::{KvPushRouter, KvRouter},
+kv_router::{KvPushRouter, KvRouter, PrefillRouter},
 migration::Migration,
 model_card::{self, ModelDeploymentCard},
 preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
@@ -122,7 +122,6 @@ pub async fn prepare_engine(
 Some(
 model_manager
 .kv_chooser_for(
-local_model.display_name(),
 &component,
 card.kv_cache_block_size,
 Some(local_model.router_config().kv_router_config),
@@ -144,6 +143,7 @@ pub async fn prepare_engine(
 None,
 kv_chooser.clone(),
 hf_tokenizer,
+None, // No prefill chooser in static mode
 )
 .await?;
@@ -232,6 +232,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
 busy_threshold: Option<f64>,
 chooser: Option<Arc<KvRouter>>,
 hf_tokenizer: tokenizers::Tokenizer,
+prefill_chooser: Option<Arc<PrefillRouter>>,
 ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
 where
 Req: Data,
@@ -254,10 +255,12 @@ where
 chooser,
 preprocessor,
 hf_tokenizer,
+prefill_chooser,
 )
 .await
 }
+#[allow(clippy::too_many_arguments)]
 pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
 card: &ModelDeploymentCard,
 client: &Client,
@@ -266,6 +269,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
 chooser: Option<Arc<KvRouter>>,
 preprocessor: Arc<OpenAIPreprocessor>,
 hf_tokenizer: tokenizers::Tokenizer,
+prefill_chooser: Option<Arc<PrefillRouter>>,
 ) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
 where
 Req: Data,
@@ -298,6 +302,7 @@ where
 worker_monitor,
 )
 .await?;
 let service_backend = match router_mode {
 RouterMode::Random | RouterMode::RoundRobin | RouterMode::Direct(_) => {
 ServiceBackend::from_engine(Arc::new(router))
@@ -311,14 +316,22 @@ where
 }
 };
+// Use the provided prefill chooser, or create a disabled one if not provided
+let prefill_chooser = prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode));
+let prefill_op = prefill_chooser.into_operator();
+// Link with prefill chooser including backward edge for response flow
 let engine = frontend
 .link(preprocessor_op.forward_edge())?
 .link(backend.forward_edge())?
 .link(migration.forward_edge())?
+.link(prefill_op.forward_edge())?
 .link(service_backend)?
+.link(prefill_op.backward_edge())?
 .link(migration.backward_edge())?
 .link(backend.backward_edge())?
 .link(preprocessor_op.backward_edge())?
 .link(frontend)?;
 Ok(engine)
 }
@@ -81,7 +81,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 Some(
 manager
 .kv_chooser_for(
-local_model.display_name(),
 &component,
 card.kv_cache_block_size,
 Some(local_model.router_config().kv_router_config),
@@ -103,6 +102,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 None,
 kv_chooser.clone(),
 tokenizer_hf.clone(),
+None, // No prefill chooser in grpc static mode
 )
 .await?;
 manager.add_chat_completions_model(
@@ -111,11 +111,18 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 chat_engine,
 )?;
-let completions_engine =
-entrypoint::build_routed_pipeline::<
+let completions_engine = entrypoint::build_routed_pipeline::<
 NvCreateCompletionRequest,
 NvCreateCompletionResponse,
->(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
+>(
+card,
+&client,
+router_mode,
+None,
+kv_chooser,
+tokenizer_hf,
+None, // No prefill chooser in grpc static mode
+)
 .await?;
 manager.add_completions_model(
 local_model.display_name(),
......
@@ -121,7 +121,6 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 Some(
 manager
 .kv_chooser_for(
-local_model.display_name(),
 &component,
 card.kv_cache_block_size,
 Some(local_model.router_config().kv_router_config),
@@ -143,6 +142,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 None,
 kv_chooser.clone(),
 tokenizer_hf.clone(),
+None, // No prefill chooser in http static mode
 )
 .await?;
 manager.add_chat_completions_model(
@@ -151,11 +151,18 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
 chat_engine,
 )?;
-let completions_engine =
-entrypoint::build_routed_pipeline::<
+let completions_engine = entrypoint::build_routed_pipeline::<
 NvCreateCompletionRequest,
 NvCreateCompletionResponse,
->(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
+>(
+card,
+&client,
+router_mode,
+None,
+kv_chooser,
+tokenizer_hf,
+None, // No prefill chooser in http static mode
+)
 .await?;
 manager.add_completions_model(
 local_model.display_name(),
......
@@ -22,6 +22,7 @@ use serde::{Deserialize, Serialize};
 pub mod approx;
 pub mod indexer;
+pub mod prefill_router;
 pub mod protocols;
 pub mod publisher;
 pub mod recorder;
@@ -30,6 +31,8 @@ pub mod scoring;
 pub mod sequence;
 pub mod subscriber;
+pub use prefill_router::PrefillRouter;
 use crate::{
 kv_router::{
 approx::ApproxKvIndexer,
@@ -247,7 +250,7 @@ impl KvRouter {
 let runtime_configs_watcher = watch_prefix_with_extraction(
 etcd_client,
-model_card::ROOT_PATH,
+&format!("{}/{}", model_card::ROOT_PATH, component.path()),
 key_extractors::lease_id,
 |card: ModelDeploymentCard| Some(card.runtime_config),
 cancellation_token.clone(),
......
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::{Arc, OnceLock};
use anyhow::{Result, bail};
use futures::StreamExt;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
use dynamo_runtime::{
component::Endpoint,
pipeline::{
AsyncEngine, AsyncEngineContext, AsyncEngineContextProvider, Context, ManyOut, Operator,
PushRouter, RouterMode, ServerStreamingEngine, SingleIn, async_trait,
},
protocols::{annotated::Annotated, maybe_error::MaybeError},
};
use crate::{
discovery::ModelManager,
kv_router::{KvPushRouter, KvRouterConfig, RouterConfigOverride},
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};
/// The inner router used by PrefillRouter
enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
KvRouter(Arc<KvPushRouter>),
/// Simple routing (RoundRobin, Random, Direct)
SimpleRouter(Arc<PushRouter<PreprocessedRequest, Annotated<LLMEngineOutput>>>),
}
/// PrefillRouter is a forward-only operator that sits between Migration and the decode router.
/// It optionally calls a prefill worker before routing to decode, extracting disaggregated_params
/// from the prefill response and injecting them into the decode request.
pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken,
router_mode: RouterMode,
}
impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(router_mode: RouterMode) -> Arc<Self> {
Arc::new(Self {
prefill_router: OnceLock::new(),
cancel_token: CancellationToken::new(),
router_mode,
})
}
pub fn new(
activation_rx: oneshot::Receiver<Endpoint>,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
let router = Arc::new(Self {
prefill_router,
cancel_token: cancel_token.clone(),
router_mode,
});
// Spawn background task to wait for activation
let router_clone = router.clone();
tokio::spawn(async move {
tokio::select! {
result = activation_rx => {
let Ok(endpoint) = result else {
tracing::debug!("Prefill router activation channel closed without receiving endpoint");
return;
};
if let Err(e) = router_clone.activate(
endpoint,
model_manager,
kv_cache_block_size,
kv_router_config,
).await {
tracing::error!(error = %e, "Failed to activate prefill router");
}
}
_ = cancel_token.cancelled() => {
tracing::debug!("Prefill router activation cancelled");
}
}
});
router
}
/// Activate the prefill router with the provided endpoint
async fn activate(
&self,
endpoint: Endpoint,
model_manager: Arc<ModelManager>,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
) -> Result<()> {
tracing::info!(
router_mode = ?self.router_mode,
"Activating prefill router"
);
let client = endpoint.client().await?;
let inner_router = if self.router_mode.is_kv_routing() {
// Create KV chooser using the component from the endpoint
let kv_chooser = model_manager
.kv_chooser_for(endpoint.component(), kv_cache_block_size, kv_router_config)
.await?;
// Build the PushRouter for prefill with KV mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
RouterMode::KV,
None, // busy_threshold
None, // worker_monitor
)
.await?;
// Wrap it in KvPushRouter
InnerPrefillRouter::KvRouter(Arc::new(KvPushRouter::new(push_router, kv_chooser)))
} else {
// Create simple push router with the frontend's router mode
let push_router = PushRouter::<PreprocessedRequest, Annotated<LLMEngineOutput>>::from_client_with_threshold(
client,
self.router_mode,
None, // busy_threshold
None, // worker_monitor
)
.await?;
InnerPrefillRouter::SimpleRouter(Arc::new(push_router))
};
// Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router);
tracing::info!(
router_mode = ?self.router_mode,
"Prefill router activated successfully"
);
Ok(())
}
/// Call the prefill router and extract disaggregated_params
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<serde_json::Value> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
bail!("Prefill router not yet activated");
};
// Call the appropriate router based on the type
let mut prefill_response = match prefill_router {
InnerPrefillRouter::KvRouter(router) => router.generate(request).await?,
InnerPrefillRouter::SimpleRouter(router) => router.generate(request).await?,
};
let Some(first_output) = prefill_response.next().await else {
bail!("Prefill router returned no output (stream ended)");
};
if let Some(err) = first_output.err() {
while prefill_response.next().await.is_some() {}
bail!("Prefill router returned error in output: {:?}", err);
}
let Some(output) = &first_output.data else {
while prefill_response.next().await.is_some() {}
bail!("Prefill router output has no data field");
};
let Some(disaggregated_params) = output.disaggregated_params.clone() else {
while prefill_response.next().await.is_some() {}
bail!("Prefill router output missing disaggregated_params");
};
while prefill_response.next().await.is_some() {}
Ok(disaggregated_params)
}
}
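`PrefillRouter` keeps its inner router in a `OnceLock`, so request handling can check cheaply whether activation has happened yet and bail out if it has not. Below is a small std-only sketch of that lazy-activation pattern, assuming a background thread and `std::sync::mpsc` in place of the `tokio::spawn` task and oneshot receiver used in the diff. `LazyRouter`, `InnerRouter`, and `spawn_activation` are hypothetical names for illustration only.

```rust
use std::sync::{mpsc, Arc, OnceLock};
use std::thread;

/// Hypothetical stand-in for the inner router; it only remembers its target.
struct InnerRouter {
    target: String,
}

/// A router that starts inactive and can be activated at most once.
struct LazyRouter {
    inner: OnceLock<InnerRouter>,
}

impl LazyRouter {
    /// Create a router with nothing set yet (comparable to the disabled state).
    fn inactive() -> Arc<Self> {
        Arc::new(Self { inner: OnceLock::new() })
    }

    /// Returns true only for the first successful activation.
    fn activate(&self, target: &str) -> bool {
        self.inner
            .set(InnerRouter { target: target.to_string() })
            .is_ok()
    }

    /// Fails until activation, so callers can fall back to another path.
    fn route(&self, prompt: &str) -> Result<String, String> {
        match self.inner.get() {
            Some(router) => Ok(format!("{} <- {}", router.target, prompt)),
            None => Err("router not yet activated".to_string()),
        }
    }
}

/// Wait in the background for an endpoint and activate the router with it.
fn spawn_activation(
    router: Arc<LazyRouter>,
    rx: mpsc::Receiver<String>,
) -> thread::JoinHandle<()> {
    thread::spawn(move || {
        // A closed channel (sender dropped) simply leaves the router inactive.
        if let Ok(target) = rx.recv() {
            router.activate(&target);
        }
    })
}

fn main() {
    let router = LazyRouter::inactive();
    println!("{:?}", router.route("hello"));

    let (tx, rx) = mpsc::channel();
    let handle = spawn_activation(router.clone(), rx);
    tx.send("prefill-worker-0".to_string()).unwrap();
    handle.join().unwrap();
    println!("{:?}", router.route("hello"));
}
```

Because `OnceLock::set` rejects a second value, a repeated activation is harmless, which matches the "ignore error if already set" comment in `activate`.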
impl Drop for PrefillRouter {
fn drop(&mut self) {
tracing::debug!("Dropping PrefillRouter, cancelling background activation task");
self.cancel_token.cancel();
}
}
#[async_trait]
impl
Operator<
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
SingleIn<PreprocessedRequest>,
ManyOut<Annotated<LLMEngineOutput>>,
> for PrefillRouter
{
async fn generate(
&self,
request: SingleIn<PreprocessedRequest>,
next: ServerStreamingEngine<PreprocessedRequest, Annotated<LLMEngineOutput>>,
) -> Result<ManyOut<Annotated<LLMEngineOutput>>> {
// Extract request data while preserving context
let (req, context) = request.into_parts();
let request_id = context.id().to_string();
// Prepare prefill request with linked context for cancellation propagation
let prefill_req = req.clone();
let prefill_context = Context::with_id(prefill_req, request_id.clone());
// Link the prefill context as a child so that kill signals propagate
context.controller().link_child(prefill_context.context());
let prefill_request = prefill_context;
// Attempt prefill and handle results
match self.call_prefill(prefill_request).await {
Ok(disaggregated_params) => {
tracing::debug!("Prefill succeeded, using disaggregated params for decode");
// Update request with disaggregated_params and router config
let mut decode_req = req;
decode_req.disaggregated_params = Some(disaggregated_params);
// Set router_config_override for decode: overlap_score_weight = 0
let existing_override = decode_req.router_config_override.take();
decode_req.router_config_override = Some(RouterConfigOverride {
overlap_score_weight: Some(0.0),
..existing_override.unwrap_or_default()
});
// Map the modified request through with preserved context
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(e) => {
tracing::debug!(error = %e, "Remote prefill failed, falling back to decode-only");
next.generate(context.map(|_| req)).await
}
}
}
}
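The `generate` implementation above takes one of two paths: on a successful prefill it attaches `disaggregated_params` and forces `overlap_score_weight` to zero while keeping any other override fields, and on failure it forwards the request unchanged. A self-contained sketch of that request rewrite follows. The struct layouts are assumptions: `router_temperature` is included only to show that other fields survive the struct-update syntax, and a `String` stands in for the `serde_json::Value` used in the diff.

```rust
/// Hypothetical mirror of the override struct; only `overlap_score_weight`
/// appears in the diff, the second field is assumed for illustration.
#[derive(Debug, Default, Clone, PartialEq)]
struct RouterConfigOverride {
    overlap_score_weight: Option<f64>,
    router_temperature: Option<f64>,
}

/// Hypothetical, trimmed-down request type.
#[derive(Debug, Default, Clone, PartialEq)]
struct Request {
    prompt: String,
    disaggregated_params: Option<String>,
    router_config_override: Option<RouterConfigOverride>,
}

/// Build the decode request from the original request and the prefill outcome.
fn prepare_decode(mut req: Request, prefill: Result<String, String>) -> Request {
    match prefill {
        Ok(params) => {
            req.disaggregated_params = Some(params);
            // Keep whatever the caller already overrode, but zero the overlap weight.
            let existing = req.router_config_override.take();
            req.router_config_override = Some(RouterConfigOverride {
                overlap_score_weight: Some(0.0),
                ..existing.unwrap_or_default()
            });
            req
        }
        // Prefill failed: fall back to decode-only with the untouched request.
        Err(_) => req,
    }
}

fn main() {
    let req = Request {
        prompt: "hello".into(),
        ..Default::default()
    };
    let decode = prepare_decode(req, Ok("kv-transfer-handle".to_string()));
    println!("{decode:?}");
}
```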
@@ -306,8 +306,7 @@ pub async fn start_kv_router_background(
 let key = String::from_utf8_lossy(kv.key());
-// Extract the hex worker ID after the colon (e.g., "generate:694d99badb9f7c07" -> "694d99badb9f7c07")
-let Some(worker_id_str) = key.split(':').next_back() else {
+let Some(worker_id_str) = key.split(&['/', ':'][..]).next_back() else {
 tracing::warn!("Could not extract worker ID from instance key: {key}");
 continue;
 };
......
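The key-parsing change in the hunk above switches from splitting on `:` alone to splitting on either `/` or `:`, so the worker ID is taken as the last segment under both separators. A short sketch of that expression in isolation follows. The helper names and the slash-separated sample key are hypothetical, and parsing the suffix as hexadecimal is an assumption based on the removed comment.

```rust
/// Return the last segment of an instance key, splitting on '/' or ':'.
fn worker_id_suffix(key: &str) -> Option<&str> {
    key.split(&['/', ':'][..]).next_back()
}

/// Parse that segment as a hexadecimal worker ID (assumed encoding).
fn parse_worker_id(key: &str) -> Option<u64> {
    u64::from_str_radix(worker_id_suffix(key)?, 16).ok()
}

fn main() {
    for key in ["generate:694d99badb9f7c07", "ns/component/generate/694d99badb9f7c07"] {
        println!("{key} -> {:?}", parse_worker_id(key));
    }
}
```

Note that `str::split` always yields at least one segment, even for an empty key, so it is the parse step in this sketch that rejects malformed input.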
@@ -40,18 +40,22 @@ class DynamoWorkerProcess(ManagedProcess):
 "3",
 ]
-health_check_urls = [
-(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
-(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
-]
 # Set port based on worker type
 port = "8082" if is_prefill else "8081"
-# Add prefill worker flag if needed
+# Configure health check based on worker type
 if is_prefill:
+# Prefill workers check their own status endpoint
 command.append("--is-prefill-worker")
 health_check_urls = [(f"http://localhost:{port}/health", self.is_ready)]
+else:
+# Decode workers should also check their own status endpoint first,
+# then verify the frontend sees the model
+health_check_urls = [
+(f"http://localhost:{port}/health", self.is_ready),
+(f"http://localhost:{FRONTEND_PORT}/v1/models", check_models_api),
+(f"http://localhost:{FRONTEND_PORT}/health", check_health_generate),
+]
 # Set debug logging environment
 env = os.environ.copy()
@@ -155,10 +159,10 @@ def test_request_cancellation_vllm_aggregated(
 # Send the request (non-blocking)
 cancellable_req = send_cancellable_request(request_type)
-# Poll for "New Request ID" pattern
+# Poll for "Decode Request ID" pattern (vLLM v2 pattern)
 request_id, worker_log_offset = poll_for_pattern(
 process=worker,
-pattern="New Request ID: ",
+pattern="Decode Request ID: ",
 log_offset=worker_log_offset,
 match_type="contains",
 )
@@ -223,17 +227,17 @@ def test_request_cancellation_vllm_decode_cancel(
 # Send streaming request (non-blocking)
 cancellable_req = send_cancellable_request("chat_completion_stream")
-# Poll for "New Request ID" pattern in decode worker
+# Poll for "Decode Request ID" pattern in decode worker (vLLM v2 pattern)
 request_id, decode_log_offset = poll_for_pattern(
 process=decode_worker,
-pattern="New Request ID: ",
+pattern="Decode Request ID: ",
 match_type="contains",
 )
-# Verify same request ID reached prefill worker (as "New Prefill Request ID")
+# Verify same request ID reached prefill worker (as "Prefill Request ID")
 _, prefill_log_offset = poll_for_pattern(
 process=prefill_worker,
-pattern=f"New Prefill Request ID: {request_id}",
+pattern=f"Prefill Request ID: {request_id}",
 )
 # Read 5 streaming responses (decode phase)
@@ -288,9 +292,11 @@ def test_request_cancellation_vllm_remote_prefill_cancel(
 with DynamoWorkerProcess(request, is_prefill=False) as decode_worker:
 logger.info(f"Decode Worker PID: {decode_worker.get_pid()}")
-# Step 4: Test request cancellation during remote prefill phase
+# Step 4: Test request cancellation during prefill phase
+# Note: With the new architecture, prefill routing happens in the frontend,
+# so the request goes directly to the prefill worker first
 logger.info(
-"Testing completion request cancellation during remote prefill phase..."
+"Testing completion request cancellation during prefill phase..."
 )
 # Send request with long prompt (non-blocking)
@@ -298,37 +304,25 @@ def test_request_cancellation_vllm_remote_prefill_cancel(
 "completion", use_long_prompt=True
 )
-# Poll for "New Request ID" pattern in decode worker
-request_id, decode_log_offset = poll_for_pattern(
-process=decode_worker,
-pattern="New Request ID: ",
-match_type="contains",
-)
-# Poll for same request ID in prefill worker (as "New Prefill Request ID")
-_, prefill_log_offset = poll_for_pattern(
+# Poll for "Prefill Request ID" pattern in prefill worker (vLLM v2 pattern)
+# With new architecture, prefill is routed by frontend's internal router
+request_id, prefill_log_offset = poll_for_pattern(
 process=prefill_worker,
-pattern=f"New Prefill Request ID: {request_id}",
+pattern="Prefill Request ID: ",
+match_type="contains",
 )
 # Cancel during prefill phase
 cancellable_req.cancel()
-logger.info(f"Cancelled request ID: {request_id} during remote prefill")
+logger.info(f"Cancelled request ID: {request_id} during prefill")
-# Poll for "Aborted Prefill Request ID" in prefill worker first (where cancellation happens)
+# Poll for "Aborted Prefill Request ID" in prefill worker (where cancellation happens)
 _, prefill_log_offset = poll_for_pattern(
 process=prefill_worker,
 pattern=f"Aborted Prefill Request ID: {request_id}",
 log_offset=prefill_log_offset,
 )
-# Then poll for "Aborted Remote Prefill Request ID" in decode worker
-_, decode_log_offset = poll_for_pattern(
-process=decode_worker,
-pattern=f"Aborted Remote Prefill Request ID: {request_id}",
-log_offset=decode_log_offset,
-)
 # Verify frontend log has kill message
 _, frontend_log_offset = poll_for_pattern(
 process=frontend,
......
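The test above relies on polling a worker log for a marker string and carrying a log offset forward between polls. The offset matters because "Aborted Prefill Request ID: " contains "Prefill Request ID: " as a substring, so a later search has to start past the earlier match. The following is a minimal sketch of that idea, not the repository's `poll_for_pattern`: the name `poll_log_for_pattern`, its parameters, and its return shape are all hypothetical, and it reads a plain log file instead of a managed worker process.

```python
import time
from pathlib import Path


def poll_log_for_pattern(log_path, pattern, log_offset=0, timeout=5.0, interval=0.05):
    """Hypothetical helper: wait until `pattern` appears in the log at or after `log_offset`.

    Returns (rest_of_line, new_offset). rest_of_line is the text following the
    pattern on the matching line (for example a request ID), and new_offset is
    the character offset where the next search should start.
    """
    path = Path(log_path)
    deadline = time.monotonic() + timeout
    while True:
        # Re-read the whole file on each attempt; fine for small test logs.
        content = path.read_text() if path.exists() else ""
        idx = content.find(pattern, log_offset)
        if idx != -1:
            start = idx + len(pattern)
            end = content.find("\n", start)
            if end == -1:
                end = len(content)
            return content[start:end].strip(), end
        if time.monotonic() >= deadline:
            raise TimeoutError(f"pattern {pattern!r} not found in {log_path}")
        time.sleep(interval)
```

A caller would first poll for `"Prefill Request ID: "` to capture the ID, then pass the returned offset when polling for the abort message, so the second search cannot match an earlier line.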
@@ -173,13 +173,23 @@ def validate_openai_response(response: requests.Response) -> None:
 def check_worker_received_request(worker_process: DynamoWorkerProcess) -> bool:
-    """Check if the worker logs contain 'New Request ID:' message indicating it received a request"""
+    """Check if the worker logs contain request ID message indicating it received a request.
+    Supports multiple backend patterns:
+    - vLLM: "Decode Request ID:" or "Prefill Request ID:"
+    - SGLang/TensorRT-LLM: "New Request ID:"
+    """
     log_path = worker_process._log_path
     if log_path and os.path.exists(log_path):
         try:
             with open(log_path, "r") as f:
                 log_content = f.read()
-                return "New Request ID: " in log_content
+                # Check for any of the supported patterns
+                return (
+                    "New Request ID: " in log_content
+                    or "Decode Request ID: " in log_content
+                    or "Prefill Request ID: " in log_content
+                )
         except Exception as e:
             logger.warning(f"Could not read worker log file {log_path}: {e}")
     return False
......
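The multi-pattern check in `check_worker_received_request` could also be expressed with a single tuple of markers, which keeps the list of backend patterns in one place. This is only a sketch of an alternative shape, not code from the commit: `REQUEST_ID_PATTERNS` and `log_shows_request` are hypothetical names, and the marker strings are copied from the diff above.

```python
# Marker strings taken from the diff: vLLM logs "Decode Request ID: " or
# "Prefill Request ID: ", while SGLang/TensorRT-LLM log "New Request ID: ".
REQUEST_ID_PATTERNS = (
    "New Request ID: ",
    "Decode Request ID: ",
    "Prefill Request ID: ",
)


def log_shows_request(log_content: str) -> bool:
    """Hypothetical helper: True if any known request ID marker is in the log text."""
    return any(pattern in log_content for pattern in REQUEST_ID_PATTERNS)
```

Adding support for another backend would then mean appending one string to the tuple instead of extending the boolean expression.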