"tests/vscode:/vscode.git/clone" did not exist on "e54894fc85a9861fb38a49701b5844462c3d77e4"
Unverified Commit d2aad651 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

fix(runtime): graceful removal of disagg model from /v1/models when prefill engine dies (#7131)

parent 49feb284
......@@ -176,7 +176,7 @@ impl Model {
self.worker_sets.iter().any(|entry| {
let ws = entry.value();
if ws.worker_count() == 0 {
if ws.worker_count() == 0 || !ws.can_serve_requests() {
return false;
}
has_serving_engine(ws.as_ref()) || (!has_any_serving_engine && ws.is_prefill_set())
......@@ -189,41 +189,41 @@ impl Model {
&self,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_chat_engine()))
}
pub fn get_completions_engine(
&self,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.completions_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_completions_engine()))
}
pub fn get_embeddings_engine(
&self,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.embeddings_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_embeddings_engine()))
}
pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.images_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_images_engine()))
}
pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.videos_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_videos_engine()))
}
pub fn get_audios_engine(&self) -> Result<OpenAIAudiosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.audios_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_audios_engine()))
}
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.tensor_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_tensor_engine()))
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
......@@ -232,7 +232,7 @@ impl Model {
&self,
) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone().map(|e| (e, ws.parsing_options())))
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_chat_engine()))
}
pub fn get_completions_engine_with_parsing(
......@@ -243,7 +243,7 @@ impl Model {
.clone()
.map(|e| (e, ws.parsing_options()))
})
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
.ok_or_else(|| self.engine_error(self.has_completions_engine()))
}
// -- Worker monitoring (aggregated across WorkerSets) --
......@@ -283,6 +283,19 @@ impl Model {
.sum()
}
// -- Internal helpers --
/// Choose the error reported when selection produced no servable WorkerSet.
///
/// * `engine_exists == true`: the engine exists but no WorkerSet can serve it
///   (zero workers, prefill not activated, ...), so the result is
///   `ModelUnavailable`, which maps to 503.
/// * `engine_exists == false`: the result is `ModelNotFound`, which maps to 404.
fn engine_error(&self, engine_exists: bool) -> ModelManagerError {
    if !engine_exists {
        return ModelManagerError::ModelNotFound(self.name.clone());
    }
    ModelManagerError::ModelUnavailable(self.name.clone())
}
// -- Internal selection --
/// Select a WorkerSet and extract a value from it.
......@@ -298,19 +311,18 @@ impl Model {
F: Fn(&WorkerSet) -> Option<T>,
{
// Fast path: single set (same zero-worker filtering as the multi-set path below)
// TODO: When the single set has 0 workers, this returns None which maps to
// ModelNotFound (404). Ideally should be 503 "no available workers" — see follow-up.
if self.worker_sets.len() == 1 {
return self.worker_sets.iter().next().and_then(|entry| {
let ws = entry.value();
if ws.worker_count() == 0 {
if ws.worker_count() == 0 || !ws.can_serve_requests() {
return None;
}
extract(ws)
});
}
// Collect eligible sets with their worker counts, skipping sets with no workers.
// Collect eligible sets with their worker counts, skipping sets with no workers
// or sets whose prefill router has died under enforce_disagg.
// In-process models (no discovery watcher) return count=1, so they always participate.
// Discovery models with count=0 have no available workers and are skipped.
let eligible: Vec<(T, usize)> = self
......@@ -319,7 +331,7 @@ impl Model {
.filter_map(|entry| {
let ws = entry.value();
let count = ws.worker_count();
if count == 0 {
if count == 0 || !ws.can_serve_requests() {
return None;
}
extract(ws).map(|val| (val, count))
......@@ -600,4 +612,112 @@ mod tests {
// Both have 0 workers → all filtered → Err
assert!(model.get_chat_engine().is_err());
}
// -- Disaggregated prefill death tests --
use crate::kv_router::PrefillRouter;
/// Test helper: a WorkerSet whose PrefillRouter has been deactivated, standing in
/// for a router that "was activated, now dead".
/// No instance watcher is installed, so worker_count stays at the in-process
/// default of 1.
fn make_worker_set_with_dead_prefill(namespace: &str, enforce_disagg: bool) -> Arc<WorkerSet> {
    let mut worker_set = WorkerSet::new(
        namespace.to_string(),
        "abc".to_string(),
        crate::model_card::ModelDeploymentCard::default(),
    );

    let dead_router = PrefillRouter::disabled(
        std::sync::Arc::new(crate::discovery::ModelManager::new()),
        dynamo_runtime::pipeline::RouterMode::RoundRobin,
        enforce_disagg,
    );
    // Flip the router into the "prefill workers gone" state before attaching it.
    dead_router.deactivate();

    worker_set.prefill_router = Some(dead_router);
    Arc::new(worker_set)
}
/// Baseline case: with no PrefillRouter attached there is nothing for
/// `can_serve_requests` to veto, so the model is displayable.
#[test]
fn test_is_displayable_true_basic() {
    let llama = Model::new("llama".to_string());
    let plain_set = make_worker_set("ns1", "abc");
    llama.add_worker_set("ns1".to_string(), plain_set);

    let displayable = llama.is_displayable();
    assert!(
        displayable,
        "model with an unconstrained WorkerSet must be displayable"
    );
}
/// A dead prefill router combined with enforce_disagg=true must hide the model
/// from /v1/models.
#[test]
fn test_is_displayable_false_when_prefill_dies_enforce_disagg() {
    let llama = Model::new("llama".to_string());
    let dead_set = make_worker_set_with_dead_prefill("ns1", true);
    llama.add_worker_set("ns1".to_string(), dead_set);

    let displayable = llama.is_displayable();
    assert!(
        !displayable,
        "model must be hidden when prefill died and enforce_disagg=true"
    );
}
/// With enforce_disagg=false the deployment may fall back to aggregated mode,
/// so a dead prefill router must not hide the model from /v1/models.
#[test]
fn test_is_displayable_true_when_prefill_dies_no_enforce() {
    let llama = Model::new("llama".to_string());
    let fallback_set = make_worker_set_with_dead_prefill("ns1", false);
    llama.add_worker_set("ns1".to_string(), fallback_set);

    let displayable = llama.is_displayable();
    assert!(
        displayable,
        "model must remain visible when prefill died but enforce_disagg=false (fallback)"
    );
}
/// The only WorkerSet has a deactivated prefill router under enforce_disagg=true,
/// so selection must skip it and the engine accessors must return Err.
#[test]
fn test_dead_prefill_single_set_not_selectable() {
    let llama = Model::new("llama".to_string());
    let dead_set = make_worker_set_with_dead_prefill("ns1", true);
    llama.add_worker_set("ns1".to_string(), dead_set);

    let chat = llama.get_chat_engine();
    assert!(chat.is_err());
    let completions = llama.get_completions_engine();
    assert!(completions.is_err());
}
/// Two WorkerSets, one healthy and one with a dead prefill router: the healthy
/// one keeps the model displayable, and removing it hides the model.
#[test]
fn test_dead_prefill_multi_set_skips_dead_namespace() {
    let llama = Model::new("llama".to_string());

    // "healthy" carries no prefill constraint at all.
    let healthy_set = make_worker_set("healthy", "abc");
    llama.add_worker_set("healthy".to_string(), healthy_set);

    // "dead" has a deactivated prefill router with enforce_disagg=true.
    let dead_set = make_worker_set_with_dead_prefill("dead", true);
    llama.add_worker_set("dead".to_string(), dead_set);

    let displayable_with_healthy = llama.is_displayable();
    assert!(
        displayable_with_healthy,
        "model must be displayable when at least one healthy set exists"
    );

    // Only the dead set is left once the healthy one is removed.
    llama.remove_worker_set("healthy");
    let displayable_dead_only = llama.is_displayable();
    assert!(
        !displayable_dead_only,
        "model must be hidden when only the dead prefill set remains"
    );
}
}
......@@ -46,6 +46,9 @@ pub enum ModelManagerError {
#[error("Model not found: {0}")]
ModelNotFound(String),
#[error("Model unavailable: {0}")]
ModelUnavailable(String),
#[error("Model already exists: {0}")]
ModelAlreadyExists(String),
}
......@@ -703,6 +706,39 @@ impl ModelManager {
);
}
None => {
// Try to reactivate an existing deactivated router first.
// This handles prefill rejoin after a transient failure: the decode
// WorkerSet's PrefillRouter already exists but is deactivated.
if let Some(model) = self.get_model(model_name)
&& let Some(ws) = model.get_worker_set(namespace)
&& let Some(ref pr) = ws.prefill_router
&& pr.is_deactivated()
{
pr.reactivate();
// Store the endpoint so that if the decode WorkerSet is rebuilt
// (removed and re-added), a subsequent register_prefill_router call
// finds PrefillReady instead of falling back to DecodeWaiting and
// stalling.
let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| {
anyhow::anyhow!(
"Failed to send endpoint for prefill model {}:{}",
model_name,
namespace
)
})?;
self.prefill_router_activators
.insert(key, PrefillActivationState::PrefillReady(rx));
tracing::info!(
model_name = %model_name,
namespace = %namespace,
"Reactivated existing prefill router for decode WorkerSet (prefill rejoin)"
);
return Ok(());
}
// No existing deactivated router -- store endpoint for a future decode
// registration.
let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| {
anyhow::anyhow!(
......@@ -723,6 +759,18 @@ impl ModelManager {
}
}
/// Deactivate the prefill router on the decode WorkerSet for the given model/namespace.
/// Called by the watcher when all prefill workers in a namespace are removed.
/// After deactivation, requests fall back to aggregated mode (or fail if enforce_disagg).
pub fn deactivate_prefill_router_for_decode(&self, model_name: &str, namespace: &str) {
if let Some(model) = self.get_model(model_name)
&& let Some(ws) = model.get_worker_set(namespace)
&& let Some(ref pr) = ws.prefill_router
{
pr.deactivate();
}
}
/// Remove the prefill router activator for a (model, namespace) pair.
/// Called when a WorkerSet is removed to prevent stale activators.
pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
......@@ -1090,4 +1138,162 @@ mod tests {
"gpt-4:default-abc"
);
}
// -- deactivate_prefill_router_for_decode tests --
use crate::kv_router::PrefillRouter;
/// Test helper: a WorkerSet that carries a PrefillRouter already flagged as
/// activated, mimicking a deployment in which the prefill endpoint has
/// rendezvoused with the decode side.
fn make_worker_set_with_prefill_router(
    namespace: &str,
    mdcsum: &str,
    enforce_disagg: bool,
) -> WorkerSet {
    let mut worker_set = make_worker_set(namespace, mdcsum);

    let active_router = PrefillRouter::disabled(
        std::sync::Arc::new(ModelManager::new()),
        dynamo_runtime::pipeline::RouterMode::RoundRobin,
        enforce_disagg,
    );
    // `disabled()` starts un-activated; set the flag by hand for the test.
    active_router.mark_activated_for_test();

    worker_set.prefill_router = Some(active_router);
    worker_set
}
/// Deactivating for a model that was never registered must be a harmless no-op.
#[test]
fn test_deactivate_prefill_router_for_decode_noop_missing_model() {
    let manager = ModelManager::new();
    // Nothing registered: the call must simply return without panicking.
    manager.deactivate_prefill_router_for_decode("nonexistent", "ns1");
}
/// Deactivating for a WorkerSet that has no prefill_router must not panic.
#[test]
fn test_deactivate_prefill_router_for_decode_noop_no_router() {
    let manager = ModelManager::new();
    let routerless_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", routerless_set);

    manager.deactivate_prefill_router_for_decode("llama", "ns1");
}
/// End-to-end check: the manager locates the WorkerSet and deactivates its
/// PrefillRouter, and with enforce_disagg=true the model then disappears from
/// model_display_names(). A second call must be harmless.
#[test]
fn test_deactivate_prefill_router_for_decode_hides_model() {
    let manager = ModelManager::new();
    let decode_set = make_worker_set_with_prefill_router("ns1", "abc", true);
    manager.add_worker_set("llama", "ns1", decode_set);
    let listed = || manager.model_display_names().contains("llama");

    // Visible while the prefill router is active.
    assert!(listed());

    manager.deactivate_prefill_router_for_decode("llama", "ns1");

    // Hidden once the router is deactivated under enforce_disagg=true.
    assert!(
        !listed(),
        "model must be hidden after prefill deactivation with enforce_disagg=true"
    );

    // Repeating the call must neither panic nor change the outcome.
    manager.deactivate_prefill_router_for_decode("llama", "ns1");
    assert!(!listed());
}
/// Lifecycle under enforce_disagg=true:
/// decode registers -> prefill registers -> prefill dies -> model hidden.
#[test]
fn test_disagg_lifecycle_prefill_death_hides_model() {
    let manager = ModelManager::new();
    let listed = || manager.model_display_names().contains("llama");

    // Step 1: decode WorkerSet whose PrefillRouter is still active.
    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", true);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(
        listed(),
        "step 1: model must be visible with active prefill router"
    );

    // Step 2: prefill WorkerSet joins the same model under another namespace key.
    let prefill_set = make_worker_set("prefill-ns", "abc");
    manager.add_worker_set("llama", "prefill-ns", prefill_set);
    assert!(
        listed(),
        "step 2: model must be visible with both decode and prefill"
    );

    // Step 3: the prefill engine dies and its WorkerSet is removed.
    manager.remove_worker_set("llama", "prefill-ns");

    // Step 4: the decode-side prefill router is deactivated.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(
        !listed(),
        "step 4: model must be hidden after prefill death with enforce_disagg=true"
    );
}
/// Lifecycle under enforce_disagg=false: prefill death must not hide the model
/// because fallback is allowed.
#[test]
fn test_disagg_lifecycle_prefill_death_keeps_model_no_enforce() {
    let manager = ModelManager::new();
    let listed = || manager.model_display_names().contains("llama");

    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", false);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(listed());

    // Deactivation leaves the model listed when enforce_disagg=false.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(
        listed(),
        "model must remain visible (enforce_disagg=false, fallback allowed)"
    );
}
/// Lifecycle with a transient prefill failure:
/// decode registers -> prefill dies -> model hidden -> prefill rejoins -> model visible.
#[test]
fn test_disagg_lifecycle_prefill_rejoin_restores_model() {
    let manager = ModelManager::new();
    let listed = || manager.model_display_names().contains("llama");

    // Decode WorkerSet registered with enforce_disagg=true.
    let decode_set = make_worker_set_with_prefill_router("decode-ns", "abc", true);
    manager.add_worker_set("llama", "decode-ns", decode_set);
    assert!(listed());

    // Prefill death: the decode-side router is deactivated.
    manager.deactivate_prefill_router_for_decode("llama", "decode-ns");
    assert!(!listed(), "model must be hidden after prefill death");

    // Prefill rejoin: reactivate through the WorkerSet's own PrefillRouter.
    let Some(model) = manager.get_model("llama") else {
        panic!("decode WorkerSet or prefill_router not found");
    };
    let Some(worker_set) = model.get_worker_set("decode-ns") else {
        panic!("decode WorkerSet or prefill_router not found");
    };
    let Some(router) = worker_set.prefill_router.as_ref() else {
        panic!("decode WorkerSet or prefill_router not found");
    };
    router.reactivate();

    assert!(
        listed(),
        "model must be visible again after prefill rejoin"
    );
}
}
......@@ -336,6 +336,15 @@ impl ModelWatcher {
"Removed WorkerSet (no remaining instances in namespace)"
);
}
// If the removed component was a prefill worker, deactivate the decode-side
// prefill router so requests fall back to aggregated mode (or fail cleanly
// with enforce_disagg). The decode WorkerSet's namespace matches the
// deployment namespace, not the ws_key.
if card.model_type.supports_prefill() {
self.manager
.deactivate_prefill_router_for_decode(&model_name, worker_namespace);
}
}
// Check if the Model still has instances in any namespace
......@@ -542,9 +551,12 @@ impl ModelWatcher {
self.router_config.load_threshold_config.clone(),
));
// Store KV router and worker monitor on the WorkerSet
// Store KV router, worker monitor, and prefill router on the WorkerSet.
// The prefill router is stored so the watcher can deactivate/reactivate it
// when prefill workers die or rejoin.
worker_set.kv_router = kv_chooser.clone();
worker_set.worker_monitor = worker_monitor.clone();
worker_set.prefill_router = prefill_chooser.clone();
// Add chat engine only if the model supports chat
if card.model_type.supports_chat() {
......
......@@ -11,7 +11,7 @@ use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
kv_router::{KvRouter, PrefillRouter},
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
......@@ -51,6 +51,10 @@ pub struct WorkerSet {
/// Worker monitor for load-based rejection
pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Prefill router for disaggregated serving. Stored here so the watcher can
/// deactivate it when all prefill workers die, and reactivate when they rejoin.
pub(crate) prefill_router: Option<Arc<PrefillRouter>>,
/// Watcher for available instance IDs (from the Client's discovery watch).
/// None for in-process models (http/grpc) which don't have a discovery client.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
......@@ -71,6 +75,7 @@ impl WorkerSet {
tensor_engine: None,
kv_router: None,
worker_monitor: None,
prefill_router: None,
instance_count_rx: None,
}
}
......@@ -152,6 +157,16 @@ impl WorkerSet {
pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
self.instance_count_rx = Some(rx);
}
/// Report whether this WorkerSet is currently able to serve requests.
///
/// With no prefill router attached the answer is always `true`. Otherwise the
/// decision is delegated to the router: a deactivated router with
/// enforce_disagg set yields `false`, which hides the model from /v1/models and
/// causes requests to be rejected.
pub fn can_serve_requests(&self) -> bool {
    match &self.prefill_router {
        Some(router) => router.can_serve_requests(),
        None => true,
    }
}
}
#[cfg(test)]
......
......@@ -136,7 +136,26 @@ pub async fn prepare_engine(
let model_service_name = watch_obj.wait_for_chat_model().await;
tracing::info!("Connected to {model_service_name}");
let engine = model_manager.get_chat_completions_engine(&model_service_name)?;
// In disaggregated deployments the model may be listed before the prefill
// router is fully activated, causing a transient ModelUnavailable. Retry
// with a timeout so the startup path doesn't fail during this cold-start
// window, but also doesn't hang indefinitely on misconfiguration.
let deadline = tokio::time::Instant::now() + Duration::from_secs(120);
let engine = loop {
match model_manager.get_chat_completions_engine(&model_service_name) {
Ok(engine) => break engine,
Err(crate::discovery::ModelManagerError::ModelUnavailable(_))
if tokio::time::Instant::now() < deadline =>
{
tracing::debug!(
model = %model_service_name,
"Model listed but not yet servable, waiting for prefill activation"
);
tokio::time::sleep(Duration::from_millis(500)).await;
}
Err(e) => return Err(e.into()),
}
};
Ok(PreparedEngine {
service_name: model_service_name,
engine,
......
......@@ -85,7 +85,12 @@ pub async fn completion_response_stream(
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?;
.map_err(|e| match e {
crate::discovery::ModelManagerError::ModelUnavailable(_) => {
Status::unavailable("model temporarily unavailable")
}
_ => Status::not_found("model not found"),
})?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......
......@@ -86,7 +86,12 @@ pub async fn tensor_response_stream(
let engine = state
.manager()
.get_tensor_engine(model)
.map_err(|_| Status::not_found("model not found"))?;
.map_err(|e| match e {
crate::discovery::ModelManagerError::ModelUnavailable(_) => {
Status::unavailable("model temporarily unavailable")
}
_ => Status::not_found("model not found"),
})?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......
......@@ -143,6 +143,29 @@ impl ErrorMessage {
)
}
/// Build the response for a model that exists but cannot serve right now
/// (e.g., prefill not activated, no available workers).
/// The status is 503 Service Unavailable so that clients can retry.
pub fn model_unavailable() -> ErrorResponse {
    let status = StatusCode::SERVICE_UNAVAILABLE;
    let body = ErrorMessage {
        message: String::from("Model temporarily unavailable"),
        error_type: map_error_code_to_error_type(status),
        code: status.as_u16(),
    };
    (status, Json(body))
}
/// Translate a `ModelManagerError` into its HTTP error response.
///
/// `ModelUnavailable` becomes `model_unavailable()`; every other variant is
/// reported through `model_not_found()`.
pub fn from_model_error(e: &crate::discovery::ModelManagerError) -> ErrorResponse {
    let temporarily_down = matches!(
        e,
        crate::discovery::ModelManagerError::ModelUnavailable(_)
    );
    if temporarily_down {
        Self::model_unavailable()
    } else {
        Self::model_not_found()
    }
}
/// Service Unavailable
/// This is returned when the service is live, but not ready.
pub fn _service_unavailable() -> ErrorResponse {
......@@ -469,8 +492,8 @@ async fn completions_single(
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
.map_err(|e| {
let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
......@@ -609,8 +632,8 @@ async fn completions_batch(
let (engine, parsing_options) = state
.manager()
.get_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
.map_err(|e| {
let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
......@@ -790,8 +813,8 @@ async fn embeddings(
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
// todo - error handling should be more robust
let engine = state.manager().get_embeddings_engine(model).map_err(|_| {
let err_response = ErrorMessage::model_not_found();
let engine = state.manager().get_embeddings_engine(model).map_err(|e| {
let err_response = ErrorMessage::from_model_error(&e);
inflight.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
......@@ -1200,8 +1223,8 @@ async fn chat_completions(
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
.map_err(|e| {
let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
......@@ -1612,8 +1635,8 @@ async fn responses(
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
let err_response = ErrorMessage::model_not_found();
.map_err(|e| {
let err_response = ErrorMessage::from_model_error(&e);
inflight_guard.mark_error(extract_error_type_from_response(&err_response));
err_response
})?;
......@@ -2065,7 +2088,7 @@ async fn images(
let engine = state
.manager()
.get_images_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
.map_err(|e| ErrorMessage::from_model_error(&e))?;
// this will increment the inflight gauge for the model
let mut inflight = state.metrics_clone().create_inflight_guard(
......@@ -2183,7 +2206,7 @@ async fn videos(
let engine = state
.manager()
.get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
.map_err(|e| ErrorMessage::from_model_error(&e))?;
// this will increment the inflight gauge for the model
let mut inflight = state.metrics_clone().create_inflight_guard(
......@@ -2256,7 +2279,7 @@ async fn video_stream(
let engine = state
.manager()
.get_videos_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
.map_err(|e| ErrorMessage::from_model_error(&e))?;
let mut inflight =
state
......@@ -2432,7 +2455,7 @@ async fn audio_speech(
let engine = state
.manager()
.get_audios_engine(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
.map_err(|e| ErrorMessage::from_model_error(&e))?;
let mut inflight = state.metrics_clone().create_inflight_guard(
&model,
......@@ -3397,6 +3420,30 @@ mod tests {
);
}
/// The model-unavailable response must be classified as `ErrorType::Overload`.
#[test]
fn test_extract_error_type_from_response_unavailable() {
    let unavailable = ErrorMessage::model_unavailable();
    let classified = extract_error_type_from_response(&unavailable);
    assert_eq!(classified, ErrorType::Overload);
}
/// `from_model_error` must map ModelNotFound to 404 and ModelUnavailable to 503.
#[test]
fn test_from_model_error_maps_correctly() {
    let cases = [
        (
            ModelManagerError::ModelNotFound("x".to_string()),
            StatusCode::NOT_FOUND,
        ),
        (
            ModelManagerError::ModelUnavailable("x".to_string()),
            StatusCode::SERVICE_UNAVAILABLE,
        ),
    ];
    for (error, expected_status) in cases {
        let response = ErrorMessage::from_model_error(&error);
        assert_eq!(response.0, expected_status);
    }
}
#[test]
fn test_extract_error_type_from_response_internal() {
let response = ErrorMessage::internal_server_error("Something went wrong");
......
......@@ -2,6 +2,7 @@
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::sync::atomic::Ordering;
use anyhow::Result;
use tokio::sync::oneshot;
......@@ -41,6 +42,8 @@ impl PrefillRouter {
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
is_eagle: false,
deactivated: std::sync::atomic::AtomicBool::new(false),
activated: std::sync::atomic::AtomicBool::new(false),
})
}
......@@ -71,6 +74,8 @@ impl PrefillRouter {
model_name,
namespace,
is_eagle,
deactivated: std::sync::atomic::AtomicBool::new(false),
activated: std::sync::atomic::AtomicBool::new(false),
});
// Spawn background task to wait for activation
......@@ -175,6 +180,7 @@ impl PrefillRouter {
// Set the router (ignore error if already set)
let _ = self.prefill_router.set(inner_router);
self.activated.store(true, Ordering::Release);
tracing::info!(
router_mode = ?self.router_mode,
......@@ -191,4 +197,70 @@ impl PrefillRouter {
monitor.set_prefill_client(client.clone());
}
}
// -- Prefill death handling --
/// Deactivate the prefill router. Called when all prefill workers are removed.
///
/// After deactivation, requests fall back to aggregated mode (or fail if
/// enforce_disagg). The inner router is left in place so that it can be reused
/// when workers rejoin.
///
/// The call is idempotent: the flag is set unconditionally, but the info log is
/// emitted only on the actual active -> deactivated transition, so repeated
/// calls do not produce misleading duplicate "deactivated" entries.
pub fn deactivate(&self) {
    // `swap` returns the previous value; `true` means we were already deactivated.
    let already_deactivated = self.deactivated.swap(true, Ordering::AcqRel);
    if already_deactivated {
        return;
    }
    tracing::info!(
        model_name = %self.model_name,
        namespace = %self.namespace,
        enforce_disagg = self.enforce_disagg,
        "Prefill router deactivated (all prefill workers removed)"
    );
}
/// Reactivate a deactivated router. Called when prefill workers rejoin.
/// The inner router's Client is expected to re-discover workers via its
/// discovery subscription.
///
/// Note: there is a brief race between clearing `deactivated` (which lets
/// `can_serve_requests()` return true again) and the Client actually
/// rediscovering workers. Requests arriving in this window may fail at prefill
/// resolution. This is bounded by discovery propagation time (typically
/// sub-second).
///
/// Also note: reactivation reuses the existing inner router built from the
/// original endpoint. If prefill rejoins under a different endpoint identity
/// (e.g., reconfigured deployment), the stale Client would not discover the
/// new workers. This is acceptable for normal restart scenarios where the
/// endpoint identity is stable.
///
/// The info log is emitted only on the actual deactivated -> active transition,
/// so calling this on a router that was never deactivated stays silent.
pub fn reactivate(&self) {
    // `swap` returns the previous value; `false` means there was nothing to undo.
    let was_deactivated = self.deactivated.swap(false, Ordering::AcqRel);
    if !was_deactivated {
        return;
    }
    tracing::info!(
        model_name = %self.model_name,
        namespace = %self.namespace,
        "Prefill router reactivated (prefill workers rejoined)"
    );
}
/// Report the current value of the `deactivated` flag, i.e. whether the router
/// has been switched off because its prefill workers died.
/// The flag is read with `Acquire` ordering.
pub fn is_deactivated(&self) -> bool {
    let flag = &self.deactivated;
    flag.load(Ordering::Acquire)
}
/// Decide whether requests may be served given the router's current state.
///
/// | enforce_disagg | deactivated | result                    |
/// |----------------|-------------|---------------------------|
/// | false          | any         | `true` (fallback allowed) |
/// | true           | true        | `false`                   |
/// | true           | false       | value of `activated`      |
///
/// The last row keeps a cold-started strict-disagg model from being listed
/// before prefill has rendezvoused.
pub fn can_serve_requests(&self) -> bool {
    let deactivated = self.is_deactivated();
    match (self.enforce_disagg, deactivated) {
        (false, _) => true,
        (true, true) => false,
        (true, false) => self.activated.load(Ordering::Acquire),
    }
}
/// Test-only shortcut that sets the `activated` flag directly.
/// Outside of tests the flag is set by `activate()` once the inner router has
/// been populated.
#[cfg(test)]
pub(crate) fn mark_activated_for_test(&self) {
    let flag = &self.activated;
    flag.store(true, Ordering::Release);
}
}
......@@ -312,9 +312,10 @@ impl PrefillRouter {
}
}
/// Check if disaggregated mode is currently active (prefill router activated)
/// Check if disaggregated mode is currently active (prefill router activated).
/// Uses the same `activated` flag as `can_serve_requests()` for consistency.
pub fn is_activated(&self) -> bool {
self.prefill_router.get().is_some()
self.activated.load(std::sync::atomic::Ordering::Acquire)
}
/// Whether disaggregated mode is strictly enforced (fail if no prefill workers).
......
// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, OnceLock};
use anyhow::Result;
......@@ -53,6 +54,13 @@ pub struct PrefillRouter {
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
is_eagle: bool,
/// Set to true when all prefill workers die. Checked in generate() to prevent
/// routing to dead workers. Cleared on reactivation when workers rejoin.
deactivated: AtomicBool,
/// Set to true when the prefill router has been activated (inner router populated).
/// Used by `can_serve_requests()` to gate enforce_disagg readiness so a cold-started
/// strict-disagg model isn't listed before the prefill has rendezvoused.
activated: AtomicBool,
}
impl Drop for PrefillRouter {
......@@ -84,10 +92,10 @@ impl
// Save original max_tokens for decode
let original_max_tokens = req.stop_conditions.max_tokens;
// If prefill router is not activated (no prefill workers discovered),
// this is aggregated mode route directly to decode.
// With --enforce-disagg, fail instead of falling back.
if self.prefill_router.get().is_none() {
// If prefill router is not activated (no prefill workers discovered) or has been
// deactivated (all prefill workers died), this is aggregated mode -- route directly
// to decode. With --enforce-disagg, fail instead of falling back.
if self.prefill_router.get().is_none() || self.deactivated.load(Ordering::Relaxed) {
if self.enforce_disagg {
return Err(anyhow::anyhow!(PrefillError::NotActivated));
}
......@@ -269,4 +277,102 @@ mod tests {
assert_eq!(override_config.track_prefill_tokens, Some(false));
assert_eq!(override_config.router_temperature, Some(0.7));
}
// -- Prefill death handling tests --
/// Test helper: build a disabled PrefillRouter (round-robin mode, fresh
/// ModelManager) for exercising the deactivation logic.
fn make_test_router(enforce_disagg: bool) -> Arc<PrefillRouter> {
    let manager = Arc::new(crate::discovery::ModelManager::new());
    PrefillRouter::disabled(manager, RouterMode::RoundRobin, enforce_disagg)
}
/// Under enforce_disagg the router refuses to serve both before activation and
/// after deactivation.
#[test]
fn test_deactivated_flag_blocks_when_enforce_disagg() {
    let strict = make_test_router(true);

    // Never activated: enforce_disagg already blocks at this point.
    let servable_before = strict.can_serve_requests();
    assert!(
        !servable_before,
        "enforce_disagg must block before prefill activation"
    );

    strict.deactivate();
    assert!(strict.is_deactivated());
    let servable_after = strict.can_serve_requests();
    assert!(
        !servable_after,
        "deactivated + enforce_disagg must block"
    );
}
/// Without enforce_disagg a deactivated router still reports servable (fallback).
#[test]
fn test_deactivated_flag_allows_fallback_no_enforce() {
    let lenient = make_test_router(false);
    lenient.deactivate();

    assert!(lenient.is_deactivated());
    let servable = lenient.can_serve_requests();
    assert!(
        servable,
        "deactivated + !enforce_disagg must allow fallback"
    );
}
/// Reactivation clears the deactivated flag; a non-enforce router is servable
/// both while deactivated and after reactivation.
#[test]
fn test_reactivate_clears_deactivated_no_enforce() {
    let lenient = make_test_router(false);

    lenient.deactivate();
    // Fallback is permitted without enforce_disagg, even while deactivated.
    assert!(lenient.can_serve_requests());

    lenient.reactivate();
    assert!(!lenient.is_deactivated());
    let servable = lenient.can_serve_requests();
    assert!(
        servable,
        "reactivated non-enforce router must serve requests"
    );
}
/// Reactivation resets the deactivated flag, but an enforce_disagg router that
/// was never activated remains unable to serve.
#[test]
fn test_reactivate_clears_deactivated_enforce_needs_activation() {
    // `disabled()` leaves the activated flag unset, so enforce_disagg keeps
    // blocking. In a real deployment `activate()` sets the flag before any
    // deactivate/reactivate cycle; here only the flag reset is exercised.
    let strict = make_test_router(true);

    strict.deactivate();
    assert!(!strict.can_serve_requests());

    strict.reactivate();
    assert!(!strict.is_deactivated());
    let servable = strict.can_serve_requests();
    assert!(
        !servable,
        "enforce_disagg without activation still can't serve"
    );
}
/// A freshly built enforce_disagg router is not deactivated, yet cannot serve
/// because prefill has not been activated.
#[test]
fn test_fresh_router_not_deactivated() {
    let strict = make_test_router(true);

    let deactivated = strict.is_deactivated();
    assert!(!deactivated);
    // enforce_disagg without prefill activation means not servable.
    let servable = strict.can_serve_requests();
    assert!(!servable);
}
/// A freshly built router without enforce_disagg is servable right away.
#[test]
fn test_fresh_router_no_enforce_disagg_can_serve() {
    let lenient = make_test_router(false);

    assert!(!lenient.is_deactivated());
    let servable = lenient.can_serve_requests();
    assert!(
        servable,
        "non-enforce_disagg router must be servable even without prefill activation"
    );
}
/// Calling `deactivate()` twice leaves the router in the same deactivated,
/// non-servable state.
#[test]
fn test_deactivate_is_idempotent() {
    let strict = make_test_router(true);
    for _ in 0..2 {
        strict.deactivate();
    }

    assert!(strict.is_deactivated());
    assert!(!strict.can_serve_requests());
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment