Unverified Commit ce833983 authored by Yan Ru Pei's avatar Yan Ru Pei Committed by GitHub
Browse files

chore: better error handling in prefill router (#4286)


Signed-off-by: default avatarPeaBrane <yanrpei@gmail.com>
parent 0f5dd2b7
......@@ -162,6 +162,12 @@ def parse_args():
default=True,
help="KV Router: Disable tracking of active blocks (blocks being used for ongoing generation). By default, active blocks are tracked for load balancing.",
)
parser.add_argument(
"--enforce-disagg",
action="store_true",
default=False,
help="Enforce disaggregated prefill-decode. When set, unactivated prefill router will return an error instead of falling back to decode-only mode.",
)
parser.add_argument(
"--busy-threshold",
type=float,
......@@ -278,7 +284,7 @@ async def async_main():
"http_port": flags.http_port,
"kv_cache_block_size": flags.kv_cache_block_size,
"router_config": RouterConfig(
router_mode, kv_router_config, flags.busy_threshold
router_mode, kv_router_config, flags.busy_threshold, flags.enforce_disagg
),
}
......
......@@ -959,12 +959,16 @@ pub async fn create_worker_selection_pipeline_chat(
use dynamo_llm::discovery::ModelWatcher;
let model_manager = std::sync::Arc::new(ModelManager::new());
let router_config = dynamo_llm::entrypoint::RouterConfig {
router_mode,
kv_router_config: kv_router_config.unwrap_or_default(),
busy_threshold,
enforce_disagg: false,
};
let watcher = ModelWatcher::new(
component.drt().clone(),
model_manager.clone(),
router_mode,
kv_router_config,
busy_threshold,
router_config,
);
let cards = watcher
.cards_for_model(model_name, Some(namespace), false)
......@@ -1031,7 +1035,8 @@ pub async fn create_worker_selection_pipeline_chat(
busy_threshold,
chooser,
hf_tokenizer,
None,
None, // prefill_chooser
false, // enforce_disagg
)
.await?;
......
......@@ -71,21 +71,24 @@ pub struct RouterConfig {
router_mode: RouterMode,
kv_router_config: KvRouterConfig,
busy_threshold: Option<f64>,
enforce_disagg: bool,
}
#[pymethods]
impl RouterConfig {
#[new]
#[pyo3(signature = (mode, config=None, busy_threshold=None))]
#[pyo3(signature = (mode, config=None, busy_threshold=None, enforce_disagg=false))]
pub fn new(
mode: RouterMode,
config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
enforce_disagg: bool,
) -> Self {
Self {
router_mode: mode,
kv_router_config: config.unwrap_or_default(),
busy_threshold,
enforce_disagg,
}
}
}
......@@ -96,6 +99,7 @@ impl From<RouterConfig> for RsRouterConfig {
router_mode: rc.router_mode.into(),
kv_router_config: rc.kv_router_config.inner,
busy_threshold: rc.busy_threshold,
enforce_disagg: rc.enforce_disagg,
}
}
}
......
......@@ -20,8 +20,8 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
entrypoint,
kv_router::{KvRouterConfig, PrefillRouter},
entrypoint::{self, RouterConfig},
kv_router::PrefillRouter,
model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
......@@ -50,11 +50,9 @@ pub enum ModelUpdate {
pub struct ModelWatcher {
manager: Arc<ModelManager>,
drt: DistributedRuntime,
router_mode: RouterMode,
router_config: RouterConfig,
notify_on_model: Notify,
model_update_tx: Option<Sender<ModelUpdate>>,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
}
const ALL_MODEL_TYPES: &[ModelType] = &[
......@@ -69,18 +67,14 @@ impl ModelWatcher {
pub fn new(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
router_config: RouterConfig,
) -> ModelWatcher {
Self {
manager: model_manager,
drt: runtime,
router_mode,
router_config,
notify_on_model: Notify::new(),
model_update_tx: None,
kv_router_config,
busy_threshold,
}
}
......@@ -371,10 +365,14 @@ impl ModelWatcher {
// handle Chat or Completions requests, so handle whatever the model supports.
let endpoint = component.endpoint(&endpoint_id.name);
let kv_chooser = if self.router_mode == RouterMode::KV {
let kv_chooser = if self.router_config.router_mode == RouterMode::KV {
Some(
self.manager
.kv_chooser_for(&endpoint, card.kv_cache_block_size, self.kv_router_config)
.kv_chooser_for(
&endpoint,
card.kv_cache_block_size,
Some(self.router_config.kv_router_config),
)
.await?,
)
} else {
......@@ -391,15 +389,16 @@ impl ModelWatcher {
.register_prefill_router(card.name().to_string())
.map(|rx| {
// Create prefill-specific config with track_active_blocks disabled
let mut prefill_config = self.kv_router_config.unwrap_or_default();
let mut prefill_config = self.router_config.kv_router_config;
prefill_config.router_track_active_blocks = false;
PrefillRouter::new(
rx,
self.manager.clone(),
self.router_mode,
self.router_config.router_mode,
card.kv_cache_block_size,
Some(prefill_config),
self.router_config.enforce_disagg,
)
});
......@@ -411,11 +410,12 @@ impl ModelWatcher {
>(
card,
&client,
self.router_mode,
self.busy_threshold,
self.router_config.router_mode,
self.router_config.busy_threshold,
kv_chooser.clone(),
tokenizer_hf.clone(),
prefill_chooser.clone(),
self.router_config.enforce_disagg,
)
.await
.context("build_routed_pipeline")?;
......@@ -441,12 +441,13 @@ impl ModelWatcher {
>(
card,
&client,
self.router_mode,
self.busy_threshold,
self.router_config.router_mode,
self.router_config.busy_threshold,
kv_chooser,
preprocessor,
tokenizer_hf,
prefill_chooser,
self.router_config.enforce_disagg,
)
.await
.context("build_routed_pipeline_with_preprocessor")?;
......@@ -461,7 +462,7 @@ impl ModelWatcher {
NvCreateEmbeddingRequest,
Annotated<NvCreateEmbeddingResponse>,
>::from_client_with_threshold(
client, self.router_mode, None, None
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......@@ -469,11 +470,12 @@ impl ModelWatcher {
.add_embeddings_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
// Case 3: Text + Chat
let push_router =
PushRouter::<
let push_router = PushRouter::<
NvCreateChatCompletionRequest,
Annotated<NvCreateChatCompletionStreamResponse>,
>::from_client_with_threshold(client, self.router_mode, None, None)
>::from_client_with_threshold(
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
......@@ -484,7 +486,7 @@ impl ModelWatcher {
NvCreateCompletionRequest,
Annotated<NvCreateCompletionResponse>,
>::from_client_with_threshold(
client, self.router_mode, None, None
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......@@ -506,7 +508,7 @@ impl ModelWatcher {
PreprocessedEmbeddingRequest,
Annotated<EmbeddingsEngineOutput>,
>::from_client_with_threshold(
client, self.router_mode, None, None
client, self.router_config.router_mode, None, None
)
.await?;
......@@ -531,7 +533,7 @@ impl ModelWatcher {
NvCreateTensorRequest,
Annotated<NvCreateTensorResponse>,
>::from_client_with_threshold(
client, self.router_mode, None, None
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
......
......@@ -22,6 +22,7 @@ pub struct RouterConfig {
pub router_mode: RouterMode,
pub kv_router_config: KvRouterConfig,
pub busy_threshold: Option<f64>,
pub enforce_disagg: bool,
}
impl RouterConfig {
......@@ -30,6 +31,7 @@ impl RouterConfig {
router_mode,
kv_router_config,
busy_threshold: None,
enforce_disagg: false,
}
}
......@@ -37,6 +39,11 @@ impl RouterConfig {
self.busy_threshold = threshold;
self
}
pub fn with_enforce_disagg(mut self, enforce_disagg: bool) -> Self {
self.enforce_disagg = enforce_disagg;
self
}
}
#[derive(Clone)]
......
......@@ -7,7 +7,7 @@ use crate::{
backend::{Backend, ExecutionContext},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::EngineConfig,
entrypoint::{EngineConfig, RouterConfig},
kv_router::{KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::ModelDeploymentCard,
......@@ -63,9 +63,7 @@ pub async fn prepare_engine(
let watch_obj = Arc::new(ModelWatcher::new(
distributed_runtime.clone(),
model_manager.clone(),
dynamo_runtime::pipeline::RouterMode::RoundRobin,
None,
None,
RouterConfig::default(),
));
let discovery = distributed_runtime.discovery();
let discovery_stream = discovery
......@@ -163,6 +161,7 @@ where
.link(frontend)?)
}
#[allow(clippy::too_many_arguments)]
pub async fn build_routed_pipeline<Req, Resp>(
card: &ModelDeploymentCard,
client: &Client,
......@@ -171,6 +170,7 @@ pub async fn build_routed_pipeline<Req, Resp>(
chooser: Option<Arc<KvRouter>>,
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -194,6 +194,7 @@ where
preprocessor,
hf_tokenizer,
prefill_chooser,
enforce_disagg,
)
.await
}
......@@ -208,6 +209,7 @@ pub async fn build_routed_pipeline_with_preprocessor<Req, Resp>(
preprocessor: Arc<OpenAIPreprocessor>,
hf_tokenizer: tokenizers::Tokenizer,
prefill_chooser: Option<Arc<PrefillRouter>>,
enforce_disagg: bool,
) -> anyhow::Result<ServiceEngine<SingleIn<Req>, ManyOut<Annotated<Resp>>>>
where
Req: Data,
......@@ -265,7 +267,8 @@ where
};
// Use the provided prefill chooser, or create a disabled one if not provided
let prefill_chooser = prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode));
let prefill_chooser =
prefill_chooser.unwrap_or_else(|| PrefillRouter::disabled(router_mode, enforce_disagg));
let prefill_op = prefill_chooser.into_operator();
// Link with prefill chooser including backward edge for response flow
......
......@@ -6,9 +6,8 @@ use std::sync::Arc;
use crate::{
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, input::common},
entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -16,7 +15,6 @@ use crate::{
},
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
/// Build and run an KServe gRPC service
pub async fn run(
......@@ -41,9 +39,7 @@ pub async fn run(
run_watcher(
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
router_config.clone(),
target_namespace,
)
.await?;
......@@ -97,18 +93,10 @@ pub async fn run(
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
router_config: RouterConfig,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
let watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config);
tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery();
let discovery_stream = discovery
......
......@@ -7,9 +7,8 @@ use crate::{
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{EngineConfig, input::common},
entrypoint::{EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -17,7 +16,6 @@ use crate::{
},
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
/// Build and run an HTTP service
pub async fn run(
......@@ -81,9 +79,7 @@ pub async fn run(
run_watcher(
distributed_runtime.clone(),
http_service.state().manager_clone(),
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
router_config.clone(),
target_namespace,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
......@@ -191,24 +187,15 @@ pub async fn run(
/// Spawns a task that watches for new models in store,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
router_config: RouterConfig,
target_namespace: Option<String>,
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(
runtime.clone(),
model_manager,
router_mode,
kv_router_config,
busy_threshold,
);
let mut watch_obj = ModelWatcher::new(runtime.clone(), model_manager, router_config);
tracing::debug!("Waiting for remote model");
let discovery = runtime.discovery();
let discovery_stream = discovery
......
......@@ -3,7 +3,7 @@
use std::sync::{Arc, OnceLock};
use anyhow::{Result, bail};
use anyhow::Result;
use futures::StreamExt;
use tokio::sync::oneshot;
use tokio_util::sync::CancellationToken;
......@@ -23,6 +23,23 @@ use crate::{
protocols::common::llm_backend::{LLMEngineOutput, PreprocessedRequest},
};
/// Errors that can occur during prefill routing
#[derive(Debug, thiserror::Error)]
pub enum PrefillError {
/// Prefill router has not been activated yet
#[error("Prefill router not yet activated")]
NotActivated,
/// Error during prefill execution
/// TODO: Separate prefill worker error from prefill router error
#[error("Prefill execution failed: {0}")]
PrefillError(String),
/// Disaggregated params not found in prefill response
#[error("No disaggregated params in prefill response: {0}")]
NoDisaggregatedParams(String),
}
/// The inner router used by PrefillRouter
enum InnerPrefillRouter {
/// KV-aware routing using KvPushRouter
......@@ -38,15 +55,17 @@ pub struct PrefillRouter {
prefill_router: OnceLock<InnerPrefillRouter>,
cancel_token: CancellationToken,
router_mode: RouterMode,
enforce_disagg: bool,
}
impl PrefillRouter {
/// Create a disabled prefill router that will never activate (passthrough only)
pub fn disabled(router_mode: RouterMode) -> Arc<Self> {
pub fn disabled(router_mode: RouterMode, enforce_disagg: bool) -> Arc<Self> {
Arc::new(Self {
prefill_router: OnceLock::new(),
cancel_token: CancellationToken::new(),
router_mode,
enforce_disagg,
})
}
......@@ -56,6 +75,7 @@ impl PrefillRouter {
router_mode: RouterMode,
kv_cache_block_size: u32,
kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
......@@ -64,6 +84,7 @@ impl PrefillRouter {
prefill_router,
cancel_token: cancel_token.clone(),
router_mode,
enforce_disagg,
});
// Spawn background task to wait for activation
......@@ -158,34 +179,48 @@ impl PrefillRouter {
async fn call_prefill(
&self,
request: SingleIn<PreprocessedRequest>,
) -> Result<serde_json::Value> {
) -> Result<serde_json::Value, PrefillError> {
// Get the prefill router, error if not activated
let Some(prefill_router) = self.prefill_router.get() else {
bail!("Prefill router not yet activated");
return Err(PrefillError::NotActivated);
};
// Call the appropriate router based on the type
let mut prefill_response = match prefill_router {
InnerPrefillRouter::KvRouter(router) => router.generate(request).await?,
InnerPrefillRouter::SimpleRouter(router) => router.generate(request).await?,
InnerPrefillRouter::KvRouter(router) => router
.generate(request)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?,
InnerPrefillRouter::SimpleRouter(router) => router
.generate(request)
.await
.map_err(|e| PrefillError::PrefillError(e.to_string()))?,
};
let Some(first_output) = prefill_response.next().await else {
bail!("Prefill router returned no output (stream ended)");
return Err(PrefillError::PrefillError(
"Prefill router returned no output (stream ended)".to_string(),
));
};
while prefill_response.next().await.is_some() {}
if let Some(err) = first_output.err() {
bail!("Prefill router returned error in output: {err:?}");
return Err(PrefillError::PrefillError(format!(
"Prefill router returned error in output: {err:?}"
)));
}
let Some(output) = &first_output.data else {
bail!("Prefill router output has no data field");
return Err(PrefillError::NoDisaggregatedParams(
"Prefill router output has no data field".to_string(),
));
};
let Some(disaggregated_params) = output.disaggregated_params.clone() else {
bail!("Prefill router output missing disaggregated_params");
return Err(PrefillError::NoDisaggregatedParams(
"Prefill router output missing disaggregated_params".to_string(),
));
};
Ok(disaggregated_params)
......@@ -252,7 +287,24 @@ impl
let decode_request = context.map(|_| decode_req);
next.generate(decode_request).await
}
Err(PrefillError::NotActivated) => {
if self.enforce_disagg {
tracing::error!(
"Prefill router not activated, but disaggregated mode is enforced. Failing request."
);
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
tracing::debug!("Prefill router not activated, falling back to decode-only");
next.generate(context.map(|_| req)).await
}
Err(e) => {
if self.enforce_disagg {
tracing::error!(
error = %e,
"Remote prefill failed, but disaggregated mode is enforced. Failing request."
);
return Err(anyhow::anyhow!(e));
}
tracing::warn!(
error = %e,
"Remote prefill failed, falling back to decode-only. This may impact performance in disaggregated deployments. Verify prefill workers are healthy and accessible."
......
......@@ -339,9 +339,7 @@ mod integration_tests {
let model_watcher = ModelWatcher::new(
distributed_runtime.clone(),
service.state().manager_clone(),
RouterMode::RoundRobin,
None,
None,
dynamo_llm::entrypoint::RouterConfig::default(),
);
// Start watching for model registrations via discovery interface
let discovery = distributed_runtime.discovery();
......@@ -512,9 +510,7 @@ mod integration_tests {
let watcher = ModelWatcher::new(
distributed_runtime.clone(),
service.state().manager_clone(),
RouterMode::RoundRobin,
None,
None,
dynamo_llm::entrypoint::RouterConfig::default(),
);
// Get all model entries for our test model
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment