Unverified Commit d2aad651 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix(runtime): graceful removal of disagg model from /v1/models when prefill engine dies (#7131)

parent 49feb284
...@@ -176,7 +176,7 @@ impl Model { ...@@ -176,7 +176,7 @@ impl Model {
self.worker_sets.iter().any(|entry| { self.worker_sets.iter().any(|entry| {
let ws = entry.value(); let ws = entry.value();
if ws.worker_count() == 0 { if ws.worker_count() == 0 || !ws.can_serve_requests() {
return false; return false;
} }
has_serving_engine(ws.as_ref()) || (!has_any_serving_engine && ws.is_prefill_set()) has_serving_engine(ws.as_ref()) || (!has_any_serving_engine && ws.is_prefill_set())
...@@ -189,41 +189,41 @@ impl Model { ...@@ -189,41 +189,41 @@ impl Model {
&self, &self,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> { ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone()) self.select_worker_set_with(|ws| ws.chat_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_chat_engine()))
} }
pub fn get_completions_engine( pub fn get_completions_engine(
&self, &self,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> { ) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.completions_engine.clone()) self.select_worker_set_with(|ws| ws.completions_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_completions_engine()))
} }
pub fn get_embeddings_engine( pub fn get_embeddings_engine(
&self, &self,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> { ) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.embeddings_engine.clone()) self.select_worker_set_with(|ws| ws.embeddings_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_embeddings_engine()))
} }
pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> { pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.images_engine.clone()) self.select_worker_set_with(|ws| ws.images_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_images_engine()))
} }
pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> { pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.videos_engine.clone()) self.select_worker_set_with(|ws| ws.videos_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_videos_engine()))
} }
pub fn get_audios_engine(&self) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> { pub fn get_audios_engine(&self) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.audios_engine.clone()) self.select_worker_set_with(|ws| ws.audios_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_audios_engine()))
} }
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> { pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.tensor_engine.clone()) self.select_worker_set_with(|ws| ws.tensor_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_tensor_engine()))
} }
// -- Combined engine + parsing options (atomically from one WorkerSet) -- // -- Combined engine + parsing options (atomically from one WorkerSet) --
...@@ -232,7 +232,7 @@ impl Model { ...@@ -232,7 +232,7 @@ impl Model {
&self, &self,
) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> { ) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone().map(|e| (e, ws.parsing_options()))) self.select_worker_set_with(|ws| ws.chat_engine.clone().map(|e| (e, ws.parsing_options())))
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_chat_engine()))
} }
pub fn get_completions_engine_with_parsing( pub fn get_completions_engine_with_parsing(
...@@ -243,7 +243,7 @@ impl Model { ...@@ -243,7 +243,7 @@ impl Model {
.clone() .clone()
.map(|e| (e, ws.parsing_options())) .map(|e| (e, ws.parsing_options()))
}) })
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone())) .ok_or_else(|| self.engine_error(self.has_completions_engine()))
} }
// -- Worker monitoring (aggregated across WorkerSets) -- // -- Worker monitoring (aggregated across WorkerSets) --
...@@ -283,6 +283,19 @@ impl Model { ...@@ -283,6 +283,19 @@ impl Model {
.sum() .sum()
} }
// -- Internal helpers --
/// Build the error reported when engine selection found no servable WorkerSet.
///
/// * `engine_exists == true`: the engine is registered but no WorkerSet can serve right
///   now (zero workers, prefill not activated, etc.) -> `ModelUnavailable` (maps to 503).
/// * `engine_exists == false`: -> `ModelNotFound` (maps to 404).
fn engine_error(&self, engine_exists: bool) -> ModelManagerError {
    let model_name = self.name.clone();
    if !engine_exists {
        return ModelManagerError::ModelNotFound(model_name);
    }
    ModelManagerError::ModelUnavailable(model_name)
}
// -- Internal selection -- // -- Internal selection --
/// Select a WorkerSet and extract a value from it. /// Select a WorkerSet and extract a value from it.
...@@ -298,19 +311,18 @@ impl Model { ...@@ -298,19 +311,18 @@ impl Model {
F: Fn(&WorkerSet) -> Option<T>, F: Fn(&WorkerSet) -> Option<T>,
{ {
// Fast path: single set (same zero-worker filtering as the multi-set path below) // Fast path: single set (same zero-worker filtering as the multi-set path below)
// TODO: When the single set has 0 workers, this returns None which maps to
// ModelNotFound (404). Ideally should be 503 "no available workers" — see follow-up.
if self.worker_sets.len() == 1 { if self.worker_sets.len() == 1 {
return self.worker_sets.iter().next().and_then(|entry| { return self.worker_sets.iter().next().and_then(|entry| {
let ws = entry.value(); let ws = entry.value();
if ws.worker_count() == 0 { if ws.worker_count() == 0 || !ws.can_serve_requests() {
return None; return None;
} }
extract(ws) extract(ws)
}); });
} }
// Collect eligible sets with their worker counts, skipping sets with no workers. // Collect eligible sets with their worker counts, skipping sets with no workers
// or sets whose prefill router has died under enforce_disagg.
// In-process models (no discovery watcher) return count=1, so they always participate. // In-process models (no discovery watcher) return count=1, so they always participate.
// Discovery models with count=0 have no available workers and are skipped. // Discovery models with count=0 have no available workers and are skipped.
let eligible: Vec<(T, usize)> = self let eligible: Vec<(T, usize)> = self
...@@ -319,7 +331,7 @@ impl Model { ...@@ -319,7 +331,7 @@ impl Model {
.filter_map(|entry| { .filter_map(|entry| {
let ws = entry.value(); let ws = entry.value();
let count = ws.worker_count(); let count = ws.worker_count();
if count == 0 { if count == 0 || !ws.can_serve_requests() {
return None; return None;
} }
extract(ws).map(|val| (val, count)) extract(ws).map(|val| (val, count))
...@@ -600,4 +612,112 @@ mod tests { ...@@ -600,4 +612,112 @@ mod tests {
// Both have 0 workers → all filtered → Err // Both have 0 workers → all filtered → Err
assert!(model.get_chat_engine().is_err()); assert!(model.get_chat_engine().is_err());
} }
// -- Disaggregated prefill death tests --
use crate::kv_router::PrefillRouter;
/// Test helper: a WorkerSet whose PrefillRouter has been deactivated, standing in for a
/// prefill engine that has died.
///
/// No instance watcher is installed on the set, so the worker count stays at the
/// in-process default of 1.
fn make_worker_set_with_dead_prefill(namespace: &str, enforce_disagg: bool) -> Arc<WorkerSet> {
    let card = crate::model_card::ModelDeploymentCard::default();
    let mut worker_set = WorkerSet::new(namespace.to_string(), "abc".to_string(), card);

    let manager = std::sync::Arc::new(crate::discovery::ModelManager::new());
    let dead_router = PrefillRouter::disabled(
        manager,
        dynamo_runtime::pipeline::RouterMode::RoundRobin,
        enforce_disagg,
    );
    // Flip the router into the deactivated state before attaching it.
    dead_router.deactivate();

    worker_set.prefill_router = Some(dead_router);
    Arc::new(worker_set)
}
/// Baseline: a model whose only WorkerSet has no PrefillRouter attached is displayable.
#[test]
fn test_is_displayable_true_basic() {
    let llama = Model::new("llama".to_string());
    let plain_set = make_worker_set("ns1", "abc");
    llama.add_worker_set("ns1".to_string(), plain_set);

    let displayable = llama.is_displayable();
    assert!(
        displayable,
        "model with an unconstrained WorkerSet must be displayable"
    );
}
/// With enforce_disagg set, a dead prefill engine must remove the model from /v1/models.
#[test]
fn test_is_displayable_false_when_prefill_dies_enforce_disagg() {
    let llama = Model::new("llama".to_string());
    let dead_set = make_worker_set_with_dead_prefill("ns1", true);
    llama.add_worker_set("ns1".to_string(), dead_set);

    let displayable = llama.is_displayable();
    assert!(
        !displayable,
        "model must be hidden when prefill died and enforce_disagg=true"
    );
}
/// Without enforce_disagg the deployment may fall back to aggregated mode, so a dead
/// prefill engine must not hide the model from /v1/models.
#[test]
fn test_is_displayable_true_when_prefill_dies_no_enforce() {
    let llama = Model::new("llama".to_string());
    let dead_set = make_worker_set_with_dead_prefill("ns1", false);
    llama.add_worker_set("ns1".to_string(), dead_set);

    let displayable = llama.is_displayable();
    assert!(
        displayable,
        "model must remain visible when prefill died but enforce_disagg=false (fallback)"
    );
}
/// A lone WorkerSet with a deactivated prefill router (enforce_disagg=true) must not be
/// selectable, so the engine accessors return Err.
///
/// NOTE(review): the helper builds its set via `WorkerSet::new` and never attaches an
/// engine. If `new` leaves the engine fields as `None`, these assertions would pass even
/// without the prefill gating -- confirm, and consider attaching an engine and asserting
/// the specific error variant.
#[test]
fn test_dead_prefill_single_set_not_selectable() {
    let llama = Model::new("llama".to_string());
    let dead_set = make_worker_set_with_dead_prefill("ns1", true);
    llama.add_worker_set("ns1".to_string(), dead_set);

    let chat = llama.get_chat_engine();
    assert!(chat.is_err());
    let completions = llama.get_completions_engine();
    assert!(completions.is_err());
}
/// Two WorkerSets, one healthy and one with dead prefill: the healthy one keeps the model
/// displayable, and once it is removed the model is hidden.
#[test]
fn test_dead_prefill_multi_set_skips_dead_namespace() {
    let llama = Model::new("llama".to_string());

    // "healthy" has no prefill constraint; "dead" has a deactivated prefill router
    // with enforce_disagg set.
    let healthy_set = make_worker_set("healthy", "abc");
    let dead_set = make_worker_set_with_dead_prefill("dead", true);
    llama.add_worker_set("healthy".to_string(), healthy_set);
    llama.add_worker_set("dead".to_string(), dead_set);

    let visible_with_healthy = llama.is_displayable();
    assert!(
        visible_with_healthy,
        "model must be displayable when at least one healthy set exists"
    );

    // Drop the healthy set so that only the dead one is left.
    llama.remove_worker_set("healthy");
    let visible_with_dead_only = llama.is_displayable();
    assert!(
        !visible_with_dead_only,
        "model must be hidden when only the dead prefill set remains"
    );
}
} }
...@@ -46,6 +46,9 @@ pub enum ModelManagerError { ...@@ -46,6 +46,9 @@ pub enum ModelManagerError {
#[error("Model not found: {0}")] #[error("Model not found: {0}")]
ModelNotFound(String), ModelNotFound(String),
#[error("Model unavailable: {0}")]
ModelUnavailable(String),
#[error("Model already exists: {0}")] #[error("Model already exists: {0}")]
ModelAlreadyExists(String), ModelAlreadyExists(String),
} }
...@@ -703,6 +706,39 @@ impl ModelManager { ...@@ -703,6 +706,39 @@ impl ModelManager {
); );
} }
None => { None => {
// Try to reactivate an existing deactivated router first.
// This handles prefill rejoin after a transient failure: the decode
// WorkerSet's PrefillRouter already exists but is deactivated.
if let Some(model) = self.get_model(model_name)
&& let Some(ws) = model.get_worker_set(namespace)
&& let Some(ref pr) = ws.prefill_router
&& pr.is_deactivated()
{
pr.reactivate();
// Store the endpoint so that if the decode WorkerSet is rebuilt
// (removed and re-added), a subsequent register_prefill_router call
// finds PrefillReady instead of falling back to DecodeWaiting and
// stalling.
let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| {
anyhow::anyhow!(
"Failed to send endpoint for prefill model {}:{}",
model_name,
namespace
)
})?;
self.prefill_router_activators
.insert(key, PrefillActivationState::PrefillReady(rx));
tracing::info!(
model_name = %model_name,
namespace = %namespace,
"Reactivated existing prefill router for decode WorkerSet (prefill rejoin)"
);
return Ok(());
}
// No existing deactivated router -- store endpoint for a future decode
// registration.
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| { tx.send(endpoint).map_err(|_| {
anyhow::anyhow!( anyhow::anyhow!(
...@@ -723,6 +759,18 @@ impl ModelManager { ...@@ -723,6 +759,18 @@ impl ModelManager {
} }
} }
/// Deactivate the prefill router on the decode WorkerSet for the given model/namespace.
/// Called by the watcher when all prefill workers in a namespace are removed.
/// After deactivation, requests fall back to aggregated mode (or fail if enforce_disagg).
pub fn deactivate_prefill_router_for_decode(&self, model_name: &str, namespace: &str) {
if let Some(model) = self.get_model(model_name)
&& let Some(ws) = model.get_worker_set(namespace)
&& let Some(ref pr) = ws.prefill_router
{
pr.deactivate();
}
}
/// Remove the prefill router activator for a (model, namespace) pair. /// Remove the prefill router activator for a (model, namespace) pair.
/// Called when a WorkerSet is removed to prevent stale activators. /// Called when a WorkerSet is removed to prevent stale activators.
pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) { pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
...@@ -1090,4 +1138,162 @@ mod tests { ...@@ -1090,4 +1138,162 @@ mod tests {
"gpt-4:default-abc" "gpt-4:default-abc"
); );
} }
// -- deactivate_prefill_router_for_decode tests --
use crate::kv_router::PrefillRouter;
/// Test helper: a WorkerSet carrying a PrefillRouter that has been marked as activated.
///
/// This models a deployment in which the prefill endpoint has already rendezvoused with
/// the decode side.
fn make_worker_set_with_prefill_router(
    namespace: &str,
    mdcsum: &str,
    enforce_disagg: bool,
) -> WorkerSet {
    let mut worker_set = make_worker_set(namespace, mdcsum);

    let manager = std::sync::Arc::new(ModelManager::new());
    let router = PrefillRouter::disabled(
        manager,
        dynamo_runtime::pipeline::RouterMode::RoundRobin,
        enforce_disagg,
    );
    router.mark_activated_for_test();

    worker_set.prefill_router = Some(router);
    worker_set
}
/// Deactivating for a model that was never registered must be a silent no-op (no panic).
#[test]
fn test_deactivate_prefill_router_for_decode_noop_missing_model() {
    let manager = ModelManager::new();
    manager.deactivate_prefill_router_for_decode("nonexistent", "ns1");
}
/// Deactivating for a WorkerSet that has no prefill_router must be a silent no-op
/// (no panic).
#[test]
fn test_deactivate_prefill_router_for_decode_noop_no_router() {
    let manager = ModelManager::new();
    let routerless_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", routerless_set);

    manager.deactivate_prefill_router_for_decode("llama", "ns1");
}
/// End-to-end through the manager: deactivation reaches the WorkerSet's PrefillRouter
/// and, with enforce_disagg=true, the model drops out of model_display_names().
#[test]
fn test_deactivate_prefill_router_for_decode_hides_model() {
    let manager = ModelManager::new();
    let llama_listed = || manager.model_display_names().contains("llama");

    let decode_set = make_worker_set_with_prefill_router("ns1", "abc", true);
    manager.add_worker_set("llama", "ns1", decode_set);

    // Listed while the prefill router is still active.
    assert!(llama_listed());

    manager.deactivate_prefill_router_for_decode("llama", "ns1");

    // Gone once the router is deactivated under enforce_disagg=true.
    assert!(
        !llama_listed(),
        "model must be hidden after prefill deactivation with enforce_disagg=true"
    );

    // A second deactivation must neither panic nor change the outcome.
    manager.deactivate_prefill_router_for_decode("llama", "ns1");
    assert!(!llama_listed());
}
/// Disaggregated lifecycle under enforce_disagg=true:
/// decode registers -> prefill registers -> prefill dies -> model is hidden.
#[test]
fn test_disagg_lifecycle_prefill_death_hides_model() {
    let manager = ModelManager::new();
    let llama_listed = || manager.model_display_names().contains("llama");

    // Step 1: the decode WorkerSet arrives with a PrefillRouter that is still active.
    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", true);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(
        llama_listed(),
        "step 1: model must be visible with active prefill router"
    );

    // Step 2: the prefill WorkerSet registers for the same model under its own namespace key.
    let prefill_set = make_worker_set("prefill-ns", "abc");
    manager.add_worker_set("llama", "prefill-ns", prefill_set);
    assert!(
        llama_listed(),
        "step 2: model must be visible with both decode and prefill"
    );

    // Step 3: the prefill engine dies and its WorkerSet is removed.
    manager.remove_worker_set("llama", "prefill-ns");

    // Step 4: the decode-side prefill router is deactivated in response.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(
        !llama_listed(),
        "step 4: model must be hidden after prefill death with enforce_disagg=true"
    );
}
/// Disaggregated lifecycle under enforce_disagg=false: fallback is allowed, so the model
/// stays listed after the prefill router is deactivated.
#[test]
fn test_disagg_lifecycle_prefill_death_keeps_model_no_enforce() {
    let manager = ModelManager::new();
    let llama_listed = || manager.model_display_names().contains("llama");

    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", false);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(llama_listed());

    // Deactivation must not hide the model when enforce_disagg is off.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(
        llama_listed(),
        "model must remain visible (enforce_disagg=false, fallback allowed)"
    );
}
/// Disaggregated lifecycle with a prefill rejoin after a transient failure:
/// decode registers -> prefill dies -> model hidden -> prefill rejoins -> model visible.
#[test]
fn test_disagg_lifecycle_prefill_rejoin_restores_model() {
    let manager = ModelManager::new();
    let llama_listed = || manager.model_display_names().contains("llama");

    // The decode WorkerSet is registered with enforce_disagg=true.
    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", true);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(llama_listed());

    // Prefill death: the decode-side router is deactivated.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(!llama_listed(), "model must be hidden after prefill death");

    // Prefill rejoin: reactivate directly through the WorkerSet's PrefillRouter.
    let reactivated = if let Some(model) = manager.get_model("llama")
        && let Some(worker_set) = model.get_worker_set("decode-ns")
        && let Some(ref router) = worker_set.prefill_router
    {
        router.reactivate();
        true
    } else {
        false
    };
    assert!(reactivated, "decode WorkerSet or prefill_router not found");

    assert!(
        llama_listed(),
        "model must be visible again after prefill rejoin"
    );
}
} }
...@@ -336,6 +336,15 @@ impl ModelWatcher { ...@@ -336,6 +336,15 @@ impl ModelWatcher {
"Removed WorkerSet (no remaining instances in namespace)" "Removed WorkerSet (no remaining instances in namespace)"
); );
} }
// If the removed component was a prefill worker, deactivate the decode-side
// prefill router so requests fall back to aggregated mode (or fail cleanly
// with enforce_disagg). The decode WorkerSet's namespace matches the
// deployment namespace, not the ws_key.
if card.model_type.supports_prefill() {
self.manager
.deactivate_prefill_router_for_decode(&model_name, worker_namespace);
}
} }
// Check if the Model still has instances in any namespace // Check if the Model still has instances in any namespace
...@@ -542,9 +551,12 @@ impl ModelWatcher { ...@@ -542,9 +551,12 @@ impl ModelWatcher {
self.router_config.load_threshold_config.clone(), self.router_config.load_threshold_config.clone(),
)); ));
// Store KV router and worker monitor on the WorkerSet // Store KV router, worker monitor, and prefill router on the WorkerSet.
// The prefill router is stored so the watcher can deactivate/reactivate it
// when prefill workers die or rejoin.
worker_set.kv_router = kv_chooser.clone(); worker_set.kv_router = kv_chooser.clone();
worker_set.worker_monitor = worker_monitor.clone(); worker_set.worker_monitor = worker_monitor.clone();
worker_set.prefill_router = prefill_chooser.clone();
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
......
...@@ -11,7 +11,7 @@ use tokio::sync::watch; ...@@ -11,7 +11,7 @@ use tokio::sync::watch;
use crate::{ use crate::{
discovery::KvWorkerMonitor, discovery::KvWorkerMonitor,
kv_router::KvRouter, kv_router::{KvRouter, PrefillRouter},
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
types::{ types::{
generic::tensor::TensorStreamingEngine, generic::tensor::TensorStreamingEngine,
...@@ -51,6 +51,10 @@ pub struct WorkerSet { ...@@ -51,6 +51,10 @@ pub struct WorkerSet {
/// Worker monitor for load-based rejection /// Worker monitor for load-based rejection
pub(crate) worker_monitor: Option<KvWorkerMonitor>, pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Prefill router for disaggregated serving. Stored here so the watcher can
/// deactivate it when all prefill workers die, and reactivate when they rejoin.
pub(crate) prefill_router: Option<Arc<PrefillRouter>>,
/// Watcher for available instance IDs (from the Client's discovery watch). /// Watcher for available instance IDs (from the Client's discovery watch).
/// None for in-process models (http/grpc) which don't have a discovery client. /// None for in-process models (http/grpc) which don't have a discovery client.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>, instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
...@@ -71,6 +75,7 @@ impl WorkerSet { ...@@ -71,6 +75,7 @@ impl WorkerSet {
tensor_engine: None, tensor_engine: None,
kv_router: None, kv_router: None,
worker_monitor: None, worker_monitor: None,
prefill_router: None,
instance_count_rx: None, instance_count_rx: None,
} }
} }
...@@ -152,6 +157,16 @@ impl WorkerSet { ...@@ -152,6 +157,16 @@ impl WorkerSet {
pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) { pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
self.instance_count_rx = Some(rx); self.instance_count_rx = Some(rx);
} }
/// Report whether this WorkerSet is currently able to serve requests.
///
/// A set without a prefill router can always serve. Otherwise the answer is whatever the
/// prefill router reports: when that router is deactivated and enforce_disagg is set it
/// reports false, which hides the model from /v1/models and causes requests to be rejected.
pub fn can_serve_requests(&self) -> bool {
    match self.prefill_router.as_ref() {
        Some(router) => router.can_serve_requests(),
        None => true,
    }
}
} }
#[cfg(test)] #[cfg(test)]
......
...@@ -136,7 +136,26 @@ pub async fn prepare_engine( ...@@ -136,7 +136,26 @@ pub async fn prepare_engine(
let model_service_name = watch_obj.wait_for_chat_model().await; let model_service_name = watch_obj.wait_for_chat_model().await;
tracing::info!("Connected to {model_service_name}"); tracing::info!("Connected to {model_service_name}");
let engine = model_manager.get_chat_completions_engine(&model_service_name)?; // In disaggregated deployments the model may be listed before the prefill
// router is fully activated, causing a transient ModelUnavailable. Retry
// with a timeout so the startup path doesn't fail during this cold-start
// window, but also doesn't hang indefinitely on misconfiguration.
let deadline = tokio::time::Instant::now() + Duration::from_secs(120);
let engine = loop {
match model_manager.get_chat_completions_engine(&model_service_name) {
Ok(engine) => break engine,
Err(crate::discovery::ModelManagerError::ModelUnavailable(_))
if tokio::time::Instant::now() < deadline =>
{
tracing::debug!(
model = %model_service_name,
"Model listed but not yet servable, waiting for prefill activation"
);
tokio::time::sleep(Duration::from_millis(500)).await;
}
Err(e) => return Err(e.into()),
}
};
Ok(PreparedEngine { Ok(PreparedEngine {
service_name: model_service_name, service_name: model_service_name,
engine, engine,
......
...@@ -85,7 +85,12 @@ pub async fn completion_response_stream( ...@@ -85,7 +85,12 @@ pub async fn completion_response_stream(
let (engine, parsing_options) = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine_with_parsing(model) .get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?; .map_err(|e| match e {
crate::discovery::ModelManagerError::ModelUnavailable(_) => {
Status::unavailable("model temporarily unavailable")
}
_ => Status::not_found("model not found"),
})?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......
...@@ -86,7 +86,12 @@ pub async fn tensor_response_stream( ...@@ -86,7 +86,12 @@ pub async fn tensor_response_stream(
let engine = state let engine = state
.manager() .manager()
.get_tensor_engine(model) .get_tensor_engine(model)
.map_err(|_| Status::not_found("model not found"))?; .map_err(|e| match e {
crate::discovery::ModelManagerError::ModelUnavailable(_) => {
Status::unavailable("model temporarily unavailable")
}
_ => Status::not_found("model not found"),
})?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......
...@@ -143,6 +143,29 @@ impl ErrorMessage { ...@@ -143,6 +143,29 @@ impl ErrorMessage {
) )
} }
/// Response for a model that exists but cannot serve at the moment (e.g., prefill not
/// activated, no available workers).
///
/// Uses HTTP 503 so that clients can retry.
pub fn model_unavailable() -> ErrorResponse {
    let status = StatusCode::SERVICE_UNAVAILABLE;
    let body = ErrorMessage {
        message: "Model temporarily unavailable".to_string(),
        error_type: map_error_code_to_error_type(status),
        code: status.as_u16(),
    };
    (status, Json(body))
}
/// Translate a `ModelManagerError` into the HTTP response to send back.
///
/// `ModelUnavailable` becomes the 503 "temporarily unavailable" response; every other
/// variant is reported as "model not found".
pub fn from_model_error(e: &crate::discovery::ModelManagerError) -> ErrorResponse {
    if matches!(e, crate::discovery::ModelManagerError::ModelUnavailable(_)) {
        Self::model_unavailable()
    } else {
        Self::model_not_found()
    }
}
/// Service Unavailable /// Service Unavailable
/// This is returned when the service is live, but not ready. /// This is returned when the service is live, but not ready.
pub fn _service_unavailable() -> ErrorResponse { pub fn _service_unavailable() -> ErrorResponse {
...@@ -469,8 +492,8 @@ async fn completions_single( ...@@ -469,8 +492,8 @@ async fn completions_single(
let (engine, parsing_options) = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine_with_parsing(&model) .get_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|e| {
let err_response = ErrorMessage::model_not_found(); let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response err_response
})?; })?;
...@@ -609,8 +632,8 @@ async fn completions_batch( ...@@ -609,8 +632,8 @@ async fn completions_batch(
let (engine, parsing_options) = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine_with_parsing(&model) .get_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|e| {
let err_response = ErrorMessage::model_not_found(); let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response err_response
})?; })?;
...@@ -790,8 +813,8 @@ async fn embeddings( ...@@ -790,8 +813,8 @@ async fn embeddings(
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state.manager().get_embeddings_engine(model).map_err(|_| { let engine = state.manager().get_embeddings_engine(model).map_err(|e| {
let err_response = ErrorMessage::model_not_found(); let err_response = ErrorMessage::from_model_error(&e);
inflight.mark_error(extract_error_type_from_response(&err_response)); inflight.mark_error(extract_error_type_from_response(&err_response));
err_response err_response
})?; })?;
...@@ -1200,8 +1223,8 @@ async fn chat_completions( ...@@ -1200,8 +1223,8 @@ async fn chat_completions(
let (engine, parsing_options) = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine_with_parsing(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|e| {
let err_response = ErrorMessage::model_not_found(); let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response err_response
})?; })?;
...@@ -1612,8 +1635,8 @@ async fn responses( ...@@ -1612,8 +1635,8 @@ async fn responses(
let (engine, parsing_options) = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine_with_parsing(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|e| {
let err_response = ErrorMessage::model_not_found(); let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response)); inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response err_response
})?; })?;
...@@ -2065,7 +2088,7 @@ async fn images( ...@@ -2065,7 +2088,7 @@ async fn images(
let engine = state let engine = state
.manager() .manager()
.get_images_engine(&model) .get_images_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|e| ErrorMessage::from_model_error(&e))?;
// this will increment the inflight gauge for the model // this will increment the inflight gauge for the model
let mut inflight = state.metrics_clone().create_inflight_guard( let mut inflight = state.metrics_clone().create_inflight_guard(
...@@ -2183,7 +2206,7 @@ async fn videos( ...@@ -2183,7 +2206,7 @@ async fn videos(
let engine = state let engine = state
.manager() .manager()
.get_videos_engine(&model) .get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|e| ErrorMessage::from_model_error(&e))?;
// this will increment the inflight gauge for the model // this will increment the inflight gauge for the model
let mut inflight = state.metrics_clone().create_inflight_guard( let mut inflight = state.metrics_clone().create_inflight_guard(
...@@ -2256,7 +2279,7 @@ async fn video_stream( ...@@ -2256,7 +2279,7 @@ async fn video_stream(
let engine = state let engine = state
.manager() .manager()
.get_videos_engine(&model) .get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|e| ErrorMessage::from_model_error(&e))?;
let mut inflight = let mut inflight =
state state
...@@ -2432,7 +2455,7 @@ async fn audio_speech( ...@@ -2432,7 +2455,7 @@ async fn audio_speech(
let engine = state let engine = state
.manager() .manager()
.get_audios_engine(&model) .get_audios_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|e| ErrorMessage::from_model_error(&e))?;
let mut inflight = state.metrics_clone().create_inflight_guard( let mut inflight = state.metrics_clone().create_inflight_guard(
&model, &model,
...@@ -3397,6 +3420,30 @@ mod tests { ...@@ -3397,6 +3420,30 @@ mod tests {
); );
} }
/// The 503 "model unavailable" response must be classified as an overload error.
#[test]
fn test_extract_error_type_from_response_unavailable() {
    let unavailable = ErrorMessage::model_unavailable();
    let error_type = extract_error_type_from_response(&unavailable);
    assert_eq!(error_type, ErrorType::Overload);
}
/// Each `ModelManagerError` variant exercised here must map to its own HTTP status:
/// not-found -> 404, unavailable -> 503.
#[test]
fn test_from_model_error_maps_correctly() {
    // (error under test, status code the HTTP layer must report for it)
    let cases = [
        (
            ModelManagerError::ModelNotFound("x".to_string()),
            StatusCode::NOT_FOUND,
        ),
        (
            ModelManagerError::ModelUnavailable("x".to_string()),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
    ];

    for (error, expected_status) in cases.iter() {
        assert_eq!(ErrorMessage::from_model_error(error).0, *expected_status);
    }
}
#[test] #[test]
fn test_extract_error_type_from_response_internal() { fn test_extract_error_type_from_response_internal() {
let response = ErrorMessage::internal_server_error("Something went wrong"); let response = ErrorMessage::internal_server_error("Something went wrong");
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::Ordering;
use anyhow::Result; use anyhow::Result;
use tokio::sync::oneshot; use tokio::sync::oneshot;
...@@ -41,6 +42,8 @@ impl PrefillRouter { ...@@ -41,6 +42,8 @@ impl PrefillRouter {
model_name: String::new(), // Not used for disabled router model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router namespace: String::new(), // Not used for disabled router
is_eagle: false, is_eagle: false,
deactivated: std::sync::atomic::AtomicBool::new(false),
activated: std::sync::atomic::AtomicBool::new(false),
}) })
} }
...@@ -71,6 +74,8 @@ impl PrefillRouter { ...@@ -71,6 +74,8 @@ impl PrefillRouter {
model_name, model_name,
namespace, namespace,
is_eagle, is_eagle,
deactivated: std::sync::atomic::AtomicBool::new(false),
activated: std::sync::atomic::AtomicBool::new(false),
}); });
// Spawn background task to wait for activation // Spawn background task to wait for activation
...@@ -175,6 +180,7 @@ impl PrefillRouter { ...@@ -175,6 +180,7 @@ impl PrefillRouter {
// Set the router (ignore error if already set) // Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router); let _ = self.prefill_router.set(inner_router);
self.activated.store(true, Ordering::Release);
tracing::info!( tracing::info!(
router_mode = ?self.router_mode, router_mode = ?self.router_mode,
...@@ -191,4 +197,70 @@ impl PrefillRouter { ...@@ -191,4 +197,70 @@ impl PrefillRouter {
monitor.set_prefill_client(client.clone()); monitor.set_prefill_client(client.clone());
} }
} }
// -- Prefill death handling --
/// Mark the prefill router as deactivated; intended to be called once every prefill
/// worker has been removed.
///
/// While the flag is set, requests either fall back to aggregated mode or, when
/// `enforce_disagg` is on, fail. The inner router is left in place so that workers
/// rejoining under the same endpoint/discovery identity are picked up again by the
/// Client's discovery subscription.
pub fn deactivate(&self) {
    // Publish the flag before logging so the log line describes the visible state.
    self.deactivated.store(true, Ordering::Release);

    let (model_name, namespace) = (&self.model_name, &self.namespace);
    tracing::info!(
        model_name = %model_name,
        namespace = %namespace,
        enforce_disagg = self.enforce_disagg,
        "Prefill router deactivated (all prefill workers removed)"
    );
}
/// Clear the deactivated flag; intended to be called when prefill workers rejoin.
///
/// The existing inner router is reused, and its Client is expected to re-discover
/// the workers through its discovery subscription.
///
/// Known limitations:
/// - Race window: `can_serve_requests()` may report true as soon as the flag is
///   cleared, before the Client has actually rediscovered any worker. Requests that
///   arrive in that window may fail at prefill resolution. The window is bounded by
///   discovery propagation time (typically sub-second).
/// - Stale endpoint: the inner router was built from the original endpoint. If
///   prefill comes back under a different endpoint identity (e.g. a reconfigured
///   deployment), the old Client will not find the new workers. This is accepted
///   for ordinary restarts, where the endpoint identity is stable.
pub fn reactivate(&self) {
    self.deactivated.store(false, Ordering::Release);

    let (model_name, namespace) = (&self.model_name, &self.namespace);
    tracing::info!(
        model_name = %model_name,
        namespace = %namespace,
        "Prefill router reactivated (prefill workers rejoined)"
    );
}
/// Returns `true` while the deactivated flag is set, i.e. after `deactivate()` and
/// before a subsequent `reactivate()`.
pub fn is_deactivated(&self) -> bool {
    let flag = &self.deactivated;
    flag.load(Ordering::Acquire)
}
/// Whether this router can serve requests in its current state.
///
/// | `enforce_disagg` | deactivated | result                      |
/// |------------------|-------------|-----------------------------|
/// | false            | any         | `true` (aggregated fallback) |
/// | true             | true        | `false`                     |
/// | true             | false       | value of the `activated` flag |
///
/// Without `enforce_disagg` the router is always servable, including while
/// deactivated, because requests can fall back to aggregated mode. With
/// `enforce_disagg` it is servable only once prefill has activated and is not
/// currently deactivated, so a cold-started strict-disagg model isn't listed before
/// prefill has rendezvoused.
pub fn can_serve_requests(&self) -> bool {
    match (self.enforce_disagg, self.is_deactivated()) {
        // Fallback to aggregated mode is allowed regardless of prefill state.
        (false, _) => true,
        // Strict disagg with no live prefill workers.
        (true, true) => false,
        // Strict disagg, not deactivated: ready only after activation.
        (true, false) => self.activated.load(Ordering::Acquire),
    }
}
/// Test-only hook that sets the `activated` flag directly.
/// Production code is expected to set this flag from `activate()` once the inner
/// router has been populated.
#[cfg(test)]
pub(crate) fn mark_activated_for_test(&self) {
    let flag = &self.activated;
    flag.store(true, Ordering::Release);
}
} }
...@@ -312,9 +312,10 @@ impl PrefillRouter { ...@@ -312,9 +312,10 @@ impl PrefillRouter {
} }
} }
/// Check if disaggregated mode is currently active (prefill router activated) /// Check if disaggregated mode is currently active (prefill router activated).
/// Uses the same `activated` flag as `can_serve_requests()` for consistency.
pub fn is_activated(&self) -> bool { pub fn is_activated(&self) -> bool {
self.prefill_router.get().is_some() self.activated.load(std::sync::atomic::Ordering::Acquire)
} }
/// Whether disaggregated mode is strictly enforced (fail if no prefill workers). /// Whether disaggregated mode is strictly enforced (fail if no prefill workers).
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock}; use std::sync::{Arc, OnceLock};
use anyhow::Result; use anyhow::Result;
...@@ -53,6 +54,13 @@ pub struct PrefillRouter { ...@@ -53,6 +54,13 @@ pub struct PrefillRouter {
/// Namespace used to look up the correct WorkerSet's worker monitor /// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String, namespace: String,
is_eagle: bool, is_eagle: bool,
/// Set to true when all prefill workers die. Checked in generate() to prevent
/// routing to dead workers. Cleared on reactivation when workers rejoin.
deactivated: AtomicBool,
/// Set to true when the prefill router has been activated (inner router populated).
/// Used by `can_serve_requests()` to gate enforce_disagg readiness so a cold-started
/// strict-disagg model isn't listed before the prefill has rendezvoused.
activated: AtomicBool,
} }
impl Drop for PrefillRouter { impl Drop for PrefillRouter {
...@@ -84,10 +92,10 @@ impl ...@@ -84,10 +92,10 @@ impl
// Save original max_tokens for decode // Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens; let original_max_tokens = req.stop_conditions.max_tokens;
// If prefill router is not activated (no prefill workers discovered), // If prefill router is not activated (no prefill workers discovered) or has been
// this is aggregated mode route directly to decode. // deactivated (all prefill workers died), this is aggregated mode -- route directly
// With --enforce-disagg, fail instead of falling back. // to decode. With --enforce-disagg, fail instead of falling back.
if self.prefill_router.get().is_none() { if self.prefill_router.get().is_none() || self.deactivated.load(Ordering::Relaxed) {
if self.enforce_disagg { if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated)); return Err(anyhow::anyhow!(PrefillError::NotActivated));
} }
...@@ -269,4 +277,102 @@ mod tests { ...@@ -269,4 +277,102 @@ mod tests {
assert_eq!(override_config.track_prefill_tokens, Some(false)); assert_eq!(override_config.track_prefill_tokens, Some(false));
assert_eq!(override_config.router_temperature, Some(0.7)); assert_eq!(override_config.router_temperature, Some(0.7));
} }
// -- Prefill death handling tests --
/// Test helper: builds a disabled `PrefillRouter` (round-robin mode, fresh
/// `ModelManager`) with the requested `enforce_disagg` setting, for exercising the
/// deactivation flags.
fn make_test_router(enforce_disagg: bool) -> Arc<PrefillRouter> {
    let manager = Arc::new(crate::discovery::ModelManager::new());
    PrefillRouter::disabled(manager, RouterMode::RoundRobin, enforce_disagg)
}
/// With `enforce_disagg`, the router refuses to serve both before activation and
/// after deactivation.
#[test]
fn test_deactivated_flag_blocks_when_enforce_disagg() {
    let strict = make_test_router(true);

    // Never activated: strict disagg already refuses to serve.
    let servable_before = strict.can_serve_requests();
    assert!(
        !servable_before,
        "enforce_disagg must block before prefill activation"
    );

    strict.deactivate();
    assert!(strict.is_deactivated());

    let servable_after = strict.can_serve_requests();
    assert!(!servable_after, "deactivated + enforce_disagg must block");
}
/// Without `enforce_disagg`, a deactivated router still reports itself servable.
#[test]
fn test_deactivated_flag_allows_fallback_no_enforce() {
    let lenient = make_test_router(false);
    lenient.deactivate();

    assert!(lenient.is_deactivated());

    let servable = lenient.can_serve_requests();
    assert!(servable, "deactivated + !enforce_disagg must allow fallback");
}
/// `reactivate()` clears the deactivated flag; a non-enforce router is servable
/// both during and after the deactivated period.
#[test]
fn test_reactivate_clears_deactivated_no_enforce() {
    let lenient = make_test_router(false);

    lenient.deactivate();
    // Fallback keeps the router servable even while deactivated.
    assert!(lenient.can_serve_requests());

    lenient.reactivate();
    assert!(!lenient.is_deactivated());

    let servable = lenient.can_serve_requests();
    assert!(servable, "reactivated non-enforce router must serve requests");
}
/// `reactivate()` resets the deactivated flag, but a strict-disagg router that was
/// never activated remains unservable.
#[test]
fn test_reactivate_clears_deactivated_enforce_needs_activation() {
    // A router built via `disabled()` never has its activated flag set, so with
    // enforce_disagg it stays blocked throughout. In a real deployment activation
    // happens before the first deactivate/reactivate cycle; this test therefore
    // only covers the flag reset itself.
    let strict = make_test_router(true);

    strict.deactivate();
    assert!(!strict.can_serve_requests());

    strict.reactivate();
    assert!(!strict.is_deactivated());

    let servable = strict.can_serve_requests();
    assert!(!servable, "enforce_disagg without activation still can't serve");
}
/// A freshly built strict-disagg router is not deactivated, yet not servable either.
#[test]
fn test_fresh_router_not_deactivated() {
    let strict = make_test_router(true);

    let deactivated = strict.is_deactivated();
    assert!(!deactivated);

    // No prefill activation has happened, so enforce_disagg keeps it unservable.
    let servable = strict.can_serve_requests();
    assert!(!servable);
}
/// A freshly built non-enforce router is servable without any prefill activation.
#[test]
fn test_fresh_router_no_enforce_disagg_can_serve() {
    let lenient = make_test_router(false);

    assert!(!lenient.is_deactivated());

    let servable = lenient.can_serve_requests();
    assert!(
        servable,
        "non-enforce_disagg router must be servable even without prefill activation"
    );
}
/// Calling `deactivate()` twice leaves the router in the same deactivated,
/// unservable state.
#[test]
fn test_deactivate_is_idempotent() {
    let strict = make_test_router(true);

    for _ in 0..2 {
        strict.deactivate();
    }

    assert!(strict.is_deactivated());
    assert!(!strict.can_serve_requests());
}
} }
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment