// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use std::{collections::HashSet, sync::Arc}; use dashmap::{DashMap, mapref::entry::Entry}; use dynamo_kv_router::{PrefillLoadEstimator, config::KvRouterConfig, protocols::WorkerId}; use tokio::sync::oneshot; use super::worker_monitor::LoadThresholdConfig; use super::{KvWorkerMonitor, Model, RuntimeConfigWatch, WorkerSet, runtime_config_watch}; use dynamo_runtime::{ component::{Endpoint, build_transport_type}, discovery::DiscoverySpec, prelude::DistributedRuntimeProvider, protocols::EndpointId, }; use crate::{ kv_router::{KvRouter, router_endpoint_id, scheduler::DefaultWorkerSelector}, local_model::runtime_config::DisaggregatedEndpoint, model_card::ModelDeploymentCard, types::{ generic::tensor::TensorStreamingEngine, openai::{ audios::OpenAIAudiosStreamingEngine, chat_completions::OpenAIChatCompletionsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine, }, }, }; /// State for prefill router activation rendezvous enum PrefillActivationState { /// Decode model registered, waiting for prefill endpoint DecodeWaiting(oneshot::Sender), /// Prefill endpoint arrived, waiting for decode model to register PrefillReady(oneshot::Receiver), } #[derive(Debug, thiserror::Error)] pub enum ModelManagerError { #[error("Model not found: {0}")] ModelNotFound(String), #[error("Model unavailable: {0}")] ModelUnavailable(String), #[error("Model already exists: {0}")] ModelAlreadyExists(String), } /// Central manager for model engines, routing, and configuration. /// /// Models are stored hierarchically: ModelManager → Model → WorkerSet. /// Each WorkerSet owns a complete pipeline built from its specific configuration. /// /// Note: Don't implement Clone for this, put it in an Arc instead. pub struct ModelManager { /// Model name → Model (which contains WorkerSets with engines) models: DashMap>, /// Per-instance model cards, keyed by instance path. Used for cleanup on worker removal. cards: DashMap, /// Prefill router activation rendezvous, keyed by "model_name:namespace". prefill_router_activators: DashMap, /// Per-endpoint runtime config watchers. Keyed by EndpointId (includes namespace). runtime_configs: DashMap, } impl Default for ModelManager { fn default() -> Self { Self::new() } } impl ModelManager { pub fn new() -> Self { Self { models: DashMap::new(), cards: DashMap::new(), prefill_router_activators: DashMap::new(), runtime_configs: DashMap::new(), } } // -- Model access -- /// Get or create a Model for the given name. pub fn get_or_create_model(&self, model_name: &str) -> Arc { self.models .entry(model_name.to_string()) .or_insert_with(|| Arc::new(Model::new(model_name.to_string()))) .clone() } /// Get an existing Model, if it exists. pub fn get_model(&self, model_name: &str) -> Option> { self.models .get(model_name) .map(|entry| entry.value().clone()) } /// Remove a Model if it has no remaining WorkerSets. /// Uses atomic remove_if to avoid TOCTOU race between checking is_empty and removing. pub fn remove_model_if_empty(&self, model_name: &str) { if self .models .remove_if(model_name, |_, model| model.is_empty()) .is_some() { tracing::info!(model_name, "Removed empty model from manager"); } } /// Add a WorkerSet to a Model. Creates the Model if it doesn't exist. pub fn add_worker_set(&self, model_name: &str, namespace: &str, worker_set: WorkerSet) { let model = self.get_or_create_model(model_name); model.add_worker_set(namespace.to_string(), Arc::new(worker_set)); } /// Remove a WorkerSet from a Model. Removes the Model if it becomes empty. pub fn remove_worker_set(&self, model_name: &str, namespace: &str) -> Option> { let model = self.models.get(model_name)?; let removed = model.remove_worker_set(namespace); drop(model); self.remove_model_if_empty(model_name); removed } // -- Model cards -- pub fn get_model_cards(&self) -> Vec { self.cards.iter().map(|r| r.value().clone()).collect() } /// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is /// deleted. pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { self.cards.insert(key.to_string(), card); Ok(()) } /// Remove and return model card for this instance's key. We do this when the instance stops. pub fn remove_model_card(&self, key: &str) -> Option { self.cards.remove(key).map(|(_, v)| v) } // -- Engine accessors (delegate through Model → WorkerSet) -- /// Check if a decode model (chat or completions) is registered pub fn has_decode_model(&self, model: &str) -> bool { self.models .get(model) .is_some_and(|m| m.has_decode_engine()) } /// Check if a prefill model is registered pub fn has_prefill_model(&self, model: &str) -> bool { self.models.get(model).is_some_and(|m| m.has_prefill()) } /// Check if any model (decode or prefill) is registered. pub fn has_model_any(&self, model: &str) -> bool { self.has_decode_model(model) || self.has_prefill_model(model) } pub fn model_display_names(&self) -> HashSet { self.models .iter() .filter(|entry| entry.value().is_displayable()) .map(|entry| entry.key().clone()) .collect() } pub fn list_chat_completions_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_chat_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_completions_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_completions_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_embeddings_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_embeddings_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_tensor_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_tensor_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_images_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_images_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_videos_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_videos_engine()) .map(|entry| entry.key().clone()) .collect() } pub fn list_prefill_models(&self) -> Vec { self.models .iter() .filter(|entry| entry.value().has_prefill()) .map(|entry| entry.key().clone()) .collect() } pub fn get_embeddings_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_embeddings_engine() } pub fn get_completions_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_completions_engine() } pub fn get_chat_completions_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_chat_engine() } pub fn get_tensor_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_tensor_engine() } pub fn get_images_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_images_engine() } pub fn get_videos_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_videos_engine() } pub fn get_audios_engine( &self, model: &str, ) -> Result { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_audios_engine() } // -- Combined engine + parsing options (atomically from one WorkerSet) -- pub fn get_chat_completions_engine_with_parsing( &self, model: &str, ) -> Result< ( OpenAIChatCompletionsStreamingEngine, crate::protocols::openai::ParsingOptions, ), ModelManagerError, > { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_chat_engine_with_parsing() } pub fn get_completions_engine_with_parsing( &self, model: &str, ) -> Result< ( OpenAICompletionsStreamingEngine, crate::protocols::openai::ParsingOptions, ), ModelManagerError, > { self.models .get(model) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))? .get_completions_engine_with_parsing() } // -- Convenience methods for in-process models (http.rs, grpc.rs) -- // These create a WorkerSet with a default namespace for local models. // TODO: These methods use ModelDeploymentCard::default() for the WorkerSet, which means // parsing_options() returns defaults (no tool_call_parser/reasoning_parser). Pass the real // MDC from callers so ParsingOptions reflect the model's actual configuration. pub fn add_chat_completions_model( &self, model: &str, card_checksum: &str, engine: OpenAIChatCompletionsStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_chat_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_chat_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.chat_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_completions_model( &self, model: &str, card_checksum: &str, engine: OpenAICompletionsStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_completions_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_completions_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.completions_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_embeddings_model( &self, model: &str, card_checksum: &str, engine: OpenAIEmbeddingsStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_embeddings_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_embeddings_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.embeddings_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_tensor_model( &self, model: &str, card_checksum: &str, engine: TensorStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_tensor_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_tensor_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.tensor_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_images_model( &self, model: &str, card_checksum: &str, engine: OpenAIImagesStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_images_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_images_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.images_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_videos_model( &self, model: &str, card_checksum: &str, engine: OpenAIVideosStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_videos_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_videos_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.videos_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_audios_model( &self, model: &str, card_checksum: &str, engine: OpenAIAudiosStreamingEngine, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_audios_engine() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_audios_{}", model); let mut ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); ws.audios_engine = Some(engine); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } pub fn add_prefill_model( &self, model: &str, card_checksum: &str, ) -> Result<(), ModelManagerError> { let model_entry = self.get_or_create_model(model); if model_entry.has_prefill() { return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); } let namespace = format!("__local_prefill_{}", model); let ws = WorkerSet::new( namespace.clone(), card_checksum.to_string(), ModelDeploymentCard::default(), ); model_entry.add_worker_set(namespace, Arc::new(ws)); Ok(()) } // -- Model removal -- /// Remove a model entirely (all its WorkerSets). /// Returns the removed Model, or None if not found. pub fn remove_model(&self, model: &str) -> Option> { self.models.remove(model).map(|(_, m)| m) } // Per-type remove methods for in-process models (used by Python bindings). // These remove the specific synthetic WorkerSet created by the corresponding add_*_model method. pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_chat_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_completions_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_tensor_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_embeddings_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_images_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_videos_{}", model); self.remove_worker_set(model, &namespace) .map(|_| ()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string())) } // -- KV Router creation -- #[allow(clippy::too_many_arguments)] pub async fn kv_chooser_for( &self, endpoint: &Endpoint, kv_cache_block_size: u32, kv_router_config: Option, prefill_load_estimator: Option>, worker_type: &'static str, model_name: Option, is_eagle: bool, ) -> anyhow::Result> { let client = endpoint.client().await?; // Register router via discovery mechanism. let discovery = endpoint.component().drt().discovery(); let instance_id = discovery.instance_id(); // Build transport for router endpoint based on request plane mode // Use the worker's component name so each target pool gets its own router discovery group let router_endpoint_id = router_endpoint_id(endpoint.id().namespace, endpoint.id().component); let transport = build_transport_type(endpoint, &router_endpoint_id, instance_id).await?; let discovery_spec = DiscoverySpec::Endpoint { namespace: router_endpoint_id.namespace.clone(), component: router_endpoint_id.component.clone(), endpoint: router_endpoint_id.name.clone(), transport, device_type: None, }; discovery.register(discovery_spec).await?; // Get of create runtime config watcher for this endpoint let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?; let selector = DefaultWorkerSelector::new(kv_router_config.clone(), worker_type); let chooser = KvRouter::new( endpoint.clone(), client, workers_with_configs, kv_cache_block_size, selector, kv_router_config, prefill_load_estimator, worker_type, model_name, is_eagle, ) .await?; Ok(Arc::new(chooser)) } // -- Prefill router coordination -- // Keyed by "model_name:namespace" so each namespace's decode WorkerSet gets its own // prefill router activated by same-namespace prefill workers. /// Build a key for a (model, namespace) pair. Used for prefill router activators /// and registration guards. pub(crate) fn model_namespace_key(model_name: &str, namespace: &str) -> String { format!("{}:{}", model_name, namespace) } /// Register a prefill router for a decode WorkerSet. Returns a receiver that will be /// activated when the corresponding prefill model in the same namespace is discovered. /// Returns None if a decode WorkerSet in this namespace was already registered. pub fn register_prefill_router( &self, model_name: &str, namespace: &str, ) -> Option> { let key = Self::model_namespace_key(model_name, namespace); match self.prefill_router_activators.remove(&key) { Some((_, PrefillActivationState::PrefillReady(rx))) => { // Prefill endpoint already arrived - rx will immediately resolve tracing::debug!( model_name = %model_name, namespace = %namespace, "Prefill endpoint already available for namespace, returning receiver" ); Some(rx) } Some((key, PrefillActivationState::DecodeWaiting(tx))) => { // Decode already registered - this shouldn't happen, restore state and return None tracing::error!( model_name = %model_name, namespace = %namespace, "Decode WorkerSet already registered for this prefill router" ); self.prefill_router_activators .insert(key, PrefillActivationState::DecodeWaiting(tx)); None } None => { // New registration: create tx/rx pair, store sender and return receiver let (tx, rx) = oneshot::channel(); self.prefill_router_activators .insert(key, PrefillActivationState::DecodeWaiting(tx)); tracing::debug!( model_name = %model_name, namespace = %namespace, "No prefill endpoint for namespace yet, storing sender for future activation" ); Some(rx) } } } /// Activate a prefill router by sending the endpoint through the oneshot channel. /// The namespace must match the decode WorkerSet's namespace. pub fn activate_prefill_router( &self, model_name: &str, namespace: &str, endpoint: Endpoint, ) -> anyhow::Result<()> { let key = Self::model_namespace_key(model_name, namespace); match self.prefill_router_activators.remove(&key) { Some((_, PrefillActivationState::DecodeWaiting(sender))) => { sender.send(endpoint).map_err(|_| { anyhow::anyhow!( "Failed to send endpoint to prefill router activator for {}:{}", model_name, namespace ) })?; tracing::info!( model_name = %model_name, namespace = %namespace, "Activated prefill router for decode WorkerSet" ); Ok(()) } Some((_, PrefillActivationState::PrefillReady(_))) => { anyhow::bail!( "Prefill router for {}:{} already activated", model_name, namespace ); } None => { // Try to reactivate an existing deactivated router first. // This handles prefill rejoin after a transient failure: the decode // WorkerSet's PrefillRouter already exists but is deactivated. if let Some(model) = self.get_model(model_name) && let Some(ws) = model.get_worker_set(namespace) && let Some(ref pr) = ws.prefill_router && pr.is_deactivated() { pr.reactivate(); // Store the endpoint so that if the decode WorkerSet is rebuilt // (removed and re-added), a subsequent register_prefill_router call // finds PrefillReady instead of falling back to DecodeWaiting and // stalling. let (tx, rx) = oneshot::channel(); tx.send(endpoint).map_err(|_| { anyhow::anyhow!( "Failed to send endpoint for prefill model {}:{}", model_name, namespace ) })?; self.prefill_router_activators .insert(key, PrefillActivationState::PrefillReady(rx)); tracing::info!( model_name = %model_name, namespace = %namespace, "Reactivated existing prefill router for decode WorkerSet (prefill rejoin)" ); return Ok(()); } // No existing deactivated router -- store endpoint for a future decode // registration. let (tx, rx) = oneshot::channel(); tx.send(endpoint).map_err(|_| { anyhow::anyhow!( "Failed to send endpoint for prefill model {}:{}", model_name, namespace ) })?; self.prefill_router_activators .insert(key, PrefillActivationState::PrefillReady(rx)); tracing::info!( model_name = %model_name, namespace = %namespace, "Stored prefill endpoint for future decode WorkerSet registration" ); Ok(()) } } } /// Deactivate the prefill router on the decode WorkerSet for the given model/namespace. /// Called by the watcher when all prefill workers in a namespace are removed. /// After deactivation, requests fall back to aggregated mode (or fail if enforce_disagg). pub fn deactivate_prefill_router_for_decode(&self, model_name: &str, namespace: &str) { if let Some(model) = self.get_model(model_name) && let Some(ws) = model.get_worker_set(namespace) && let Some(ref pr) = ws.prefill_router { pr.deactivate(); } } /// Remove the prefill router activator for a (model, namespace) pair. /// Called when a WorkerSet is removed to prevent stale activators. pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) { let key = Self::model_namespace_key(model_name, namespace); if self.prefill_router_activators.remove(&key).is_some() { tracing::debug!( model_name = %model_name, namespace = %namespace, "Cleaned up prefill router activator for removed WorkerSet" ); } } // -- Worker monitoring -- /// Gets or sets the load threshold config for a model's worker monitor. /// Checks across all WorkerSets for the model. pub fn load_threshold_config( &self, model: &str, config: Option<&LoadThresholdConfig>, ) -> Option { let model_entry = self.models.get(model)?; model_entry.load_threshold_config(config) } /// Gets an existing worker monitor for a specific namespace of a model. pub fn get_worker_monitor_for_namespace( &self, model: &str, namespace: &str, ) -> Option { let model_entry = self.models.get(model)?; model_entry.get_worker_monitor_for_namespace(namespace) } /// Lists all models with worker monitors configured. pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> { let mut result = Vec::new(); for entry in self.models.iter() { if let Some(config) = entry.value().load_threshold_config(None) { result.push((entry.key().clone(), config)); } } result } // -- Runtime configs -- /// Get or create a runtime config watcher for an endpoint. /// Spawns a background task that joins instance availability and config discovery. /// Returns a `watch::Receiver` with the latest `HashMap`. pub async fn get_or_create_runtime_config_watcher( &self, endpoint: &Endpoint, ) -> anyhow::Result { let endpoint_id = endpoint.id(); if let Some(existing) = self.runtime_configs.get(&endpoint_id) { return Ok(existing.clone()); } // Slow path: create the watch (spawns a background task). // If another caller raced us, the entry() below picks up the winner; // the loser's background task stops once its receivers are dropped. let rx = runtime_config_watch(endpoint).await?; let result = match self.runtime_configs.entry(endpoint_id) { Entry::Occupied(e) => e.get().clone(), Entry::Vacant(e) => { e.insert(rx.clone()); rx } }; Ok(result) } /// Get disaggregated endpoint for a specific worker. pub fn get_disaggregated_endpoint( &self, endpoint_id: &EndpointId, worker_id: WorkerId, ) -> Option { let rx = self.runtime_configs.get(endpoint_id)?; let configs = rx.borrow(); configs.get(&worker_id)?.disaggregated_endpoint.clone() } } #[cfg(test)] mod tests { use super::*; use crate::model_card::ModelDeploymentCard; fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet { WorkerSet::new( namespace.to_string(), mdcsum.to_string(), ModelDeploymentCard::default(), ) } // -- CRUD delegation tests -- #[test] fn test_add_and_get_worker_set() { let mm = ModelManager::new(); let ws = make_worker_set("ns1", "abc"); mm.add_worker_set("llama", "ns1", ws); let model = mm.get_model("llama"); assert!(model.is_some()); let model = model.unwrap(); assert!(model.has_worker_set("ns1")); assert_eq!(model.worker_set_count(), 1); } #[test] fn test_add_worker_set_creates_model() { let mm = ModelManager::new(); assert!(mm.get_model("llama").is_none()); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(mm.get_model("llama").is_some()); } #[test] fn test_remove_worker_set_removes_empty_model() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(mm.get_model("llama").is_some()); let removed = mm.remove_worker_set("llama", "ns1"); assert!(removed.is_some()); assert_eq!(removed.unwrap().namespace(), "ns1"); // Model should be auto-removed since it's now empty assert!(mm.get_model("llama").is_none()); } #[test] fn test_remove_worker_set_keeps_model_with_remaining() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc")); mm.remove_worker_set("llama", "ns1"); // Model should still exist with ns2 let model = mm.get_model("llama").unwrap(); assert!(!model.has_worker_set("ns1")); assert!(model.has_worker_set("ns2")); assert_eq!(model.worker_set_count(), 1); } #[test] fn test_remove_worker_set_nonexistent_model() { let mm = ModelManager::new(); assert!(mm.remove_worker_set("llama", "ns1").is_none()); } #[test] fn test_remove_worker_set_nonexistent_namespace() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(mm.remove_worker_set("llama", "ns2").is_none()); // Model should still exist (ns1 still there) assert!(mm.get_model("llama").is_some()); } #[test] fn test_remove_model_if_empty_noop_when_not_empty() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); mm.remove_model_if_empty("llama"); assert!(mm.get_model("llama").is_some()); // Still has ns1 } #[test] fn test_remove_model_if_empty_noop_when_missing() { let mm = ModelManager::new(); mm.remove_model_if_empty("nonexistent"); // Should not panic } #[test] fn test_remove_model() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc")); let removed = mm.remove_model("llama"); assert!(removed.is_some()); assert!(mm.get_model("llama").is_none()); } #[test] fn test_get_or_create_model_idempotent() { let mm = ModelManager::new(); let m1 = mm.get_or_create_model("llama"); let m2 = mm.get_or_create_model("llama"); // Both should point to the same Model (same Arc) assert!(Arc::ptr_eq(&m1, &m2)); } // -- Model listing and filtering tests -- #[test] fn test_has_decode_model() { let mm = ModelManager::new(); // No model → false assert!(!mm.has_decode_model("llama")); // Prefill-only set (no engines) → false mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(!mm.has_decode_model("llama")); } #[test] fn test_has_prefill_model() { let mm = ModelManager::new(); // Prefill set = no engines mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(mm.has_prefill_model("llama")); } #[test] fn test_has_model_any() { let mm = ModelManager::new(); assert!(!mm.has_model_any("llama")); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); assert!(mm.has_model_any("llama")); // has prefill } #[test] fn test_model_display_names_includes_prefill() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); let names = mm.model_display_names(); assert!(names.contains("llama")); } #[test] fn test_model_display_names_empty() { let mm = ModelManager::new(); assert!(mm.model_display_names().is_empty()); } #[test] fn test_list_prefill_models() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def")); let prefill = mm.list_prefill_models(); assert_eq!(prefill.len(), 2); assert!(prefill.contains(&"llama".to_string())); assert!(prefill.contains(&"gpt".to_string())); } // -- Model card tests -- #[test] fn test_save_and_remove_model_card() { let mm = ModelManager::new(); let card = ModelDeploymentCard::default(); mm.save_model_card("instance/key/1", card.clone()).unwrap(); let cards = mm.get_model_cards(); assert_eq!(cards.len(), 1); let removed = mm.remove_model_card("instance/key/1"); assert!(removed.is_some()); assert!(mm.get_model_cards().is_empty()); } #[test] fn test_remove_model_card_nonexistent() { let mm = ModelManager::new(); assert!(mm.remove_model_card("nonexistent").is_none()); } // -- Prefill router rendezvous tests -- // Note: activate_prefill_router requires an Endpoint (needs DistributedRuntime), // so we test the registration state machine and cleanup only. #[test] fn test_prefill_router_register_new() { let mm = ModelManager::new(); // First registration for a (model, namespace) returns Some(rx) let rx = mm.register_prefill_router("llama", "ns1"); assert!(rx.is_some()); } #[test] fn test_prefill_router_double_register_returns_none() { let mm = ModelManager::new(); let rx1 = mm.register_prefill_router("llama", "ns1"); assert!(rx1.is_some()); // Second registration for the same (model, namespace) returns None let rx2 = mm.register_prefill_router("llama", "ns1"); assert!(rx2.is_none()); } #[test] fn test_prefill_router_different_namespaces_independent() { let mm = ModelManager::new(); // Different namespaces should be independent let rx1 = mm.register_prefill_router("llama", "ns1"); let rx2 = mm.register_prefill_router("llama", "ns2"); assert!(rx1.is_some()); assert!(rx2.is_some()); } #[test] fn test_prefill_router_different_models_independent() { let mm = ModelManager::new(); // Different models should be independent let rx1 = mm.register_prefill_router("llama", "ns1"); let rx2 = mm.register_prefill_router("gpt", "ns1"); assert!(rx1.is_some()); assert!(rx2.is_some()); } #[test] fn test_prefill_router_remove_allows_reregister() { let mm = ModelManager::new(); let rx = mm.register_prefill_router("llama", "ns1"); assert!(rx.is_some()); // Remove the activator mm.remove_prefill_activator("llama", "ns1"); // Should be able to register again let rx2 = mm.register_prefill_router("llama", "ns1"); assert!(rx2.is_some()); } #[test] fn test_prefill_router_remove_nonexistent_noop() { let mm = ModelManager::new(); // Should not panic mm.remove_prefill_activator("llama", "ns1"); } #[test] fn test_model_namespace_key_format() { assert_eq!( ModelManager::model_namespace_key("llama", "ns1"), "llama:ns1" ); assert_eq!( ModelManager::model_namespace_key("gpt-4", "default-abc"), "gpt-4:default-abc" ); } // -- deactivate_prefill_router_for_decode tests -- use crate::kv_router::PrefillRouter; /// Helper: make a WorkerSet with an activated PrefillRouter attached. /// The router is marked as activated to simulate a real deployment where /// the prefill endpoint has already rendezvoused with the decode side. fn make_worker_set_with_prefill_router( namespace: &str, mdcsum: &str, enforce_disagg: bool, ) -> WorkerSet { let mut ws = make_worker_set(namespace, mdcsum); let pr = PrefillRouter::disabled( std::sync::Arc::new(ModelManager::new()), dynamo_runtime::pipeline::RouterMode::RoundRobin, enforce_disagg, ); pr.mark_activated_for_test(); ws.prefill_router = Some(pr); ws } /// Calling deactivate on a non-existent model must not panic. #[test] fn test_deactivate_prefill_router_for_decode_noop_missing_model() { let mm = ModelManager::new(); mm.deactivate_prefill_router_for_decode("nonexistent", "ns1"); } /// Calling deactivate on a WorkerSet without a prefill_router must not panic. #[test] fn test_deactivate_prefill_router_for_decode_noop_no_router() { let mm = ModelManager::new(); mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc")); mm.deactivate_prefill_router_for_decode("llama", "ns1"); } /// Full pipeline test: deactivate finds the WorkerSet, calls deactivate() on its /// PrefillRouter, and the model is hidden from model_display_names() when /// enforce_disagg=true. #[test] fn test_deactivate_prefill_router_for_decode_hides_model() { let mm = ModelManager::new(); mm.add_worker_set( "llama", "ns1", make_worker_set_with_prefill_router("ns1", "abc", true), ); // Model is visible before deactivation. assert!(mm.model_display_names().contains("llama")); mm.deactivate_prefill_router_for_decode("llama", "ns1"); // Model must be hidden after deactivation with enforce_disagg=true. assert!( !mm.model_display_names().contains("llama"), "model must be hidden after prefill deactivation with enforce_disagg=true" ); // Idempotent: calling again must not panic. mm.deactivate_prefill_router_for_decode("llama", "ns1"); assert!(!mm.model_display_names().contains("llama")); } /// Full disagg lifecycle with enforce_disagg=true: /// decode registers -> prefill registers -> prefill dies -> model hidden. #[test] fn test_disagg_lifecycle_prefill_death_hides_model() { let mm = ModelManager::new(); // Step 1: Decode WorkerSet with a PrefillRouter (not yet deactivated). mm.add_worker_set( "llama", "decode-ns", make_worker_set_with_prefill_router("decode-ns", "abc", true), ); assert!( mm.model_display_names().contains("llama"), "step 1: model must be visible with active prefill router" ); // Step 2: Prefill WorkerSet registers (same model, different namespace key). mm.add_worker_set("llama", "prefill-ns", make_worker_set("prefill-ns", "abc")); assert!( mm.model_display_names().contains("llama"), "step 2: model must be visible with both decode and prefill" ); // Step 3: Prefill WorkerSet removed (engine dies). mm.remove_worker_set("llama", "prefill-ns"); // Step 4: Deactivate the prefill router on the decode side. mm.deactivate_prefill_router_for_decode("llama", "decode-ns"); assert!( !mm.model_display_names().contains("llama"), "step 4: model must be hidden after prefill death with enforce_disagg=true" ); } /// Full disagg lifecycle with enforce_disagg=false (fallback allowed). #[test] fn test_disagg_lifecycle_prefill_death_keeps_model_no_enforce() { let mm = ModelManager::new(); mm.add_worker_set( "llama", "decode-ns", make_worker_set_with_prefill_router("decode-ns", "abc", false), ); assert!(mm.model_display_names().contains("llama")); // Deactivate -- model stays visible (enforce_disagg=false, fallback allowed). mm.deactivate_prefill_router_for_decode("llama", "decode-ns"); assert!( mm.model_display_names().contains("llama"), "model must remain visible (enforce_disagg=false, fallback allowed)" ); } /// Full disagg lifecycle including prefill rejoin after transient failure. /// decode registers -> prefill dies -> model hidden -> prefill rejoins -> model visible. #[test] fn test_disagg_lifecycle_prefill_rejoin_restores_model() { let mm = ModelManager::new(); // Decode WorkerSet with enforce_disagg=true. mm.add_worker_set( "llama", "decode-ns", make_worker_set_with_prefill_router("decode-ns", "abc", true), ); assert!(mm.model_display_names().contains("llama")); // Prefill dies -> deactivate. mm.deactivate_prefill_router_for_decode("llama", "decode-ns"); assert!( !mm.model_display_names().contains("llama"), "model must be hidden after prefill death" ); // Prefill rejoins -> reactivate via the WorkerSet's PrefillRouter. if let Some(model) = mm.get_model("llama") && let Some(ws) = model.get_worker_set("decode-ns") && let Some(ref pr) = ws.prefill_router { pr.reactivate(); } else { panic!("decode WorkerSet or prefill_router not found"); } assert!( mm.model_display_names().contains("llama"), "model must be visible again after prefill rejoin" ); } }