Unverified commit cb55766c, authored by Thomas Montfort, committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: tmontfort <tmontfort@nvidia.com>
parent 8dd6369e
...@@ -372,13 +372,11 @@ async fn completions_single( ...@@ -372,13 +372,11 @@ async fn completions_single(
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine(&model) .get_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
// prepare to process any annotations // prepare to process any annotations
...@@ -495,13 +493,11 @@ async fn completions_batch( ...@@ -495,13 +493,11 @@ async fn completions_batch(
// Create http_queue_guard early - tracks time waiting to be processed // Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine(&model) .get_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
// prepare to process any annotations // prepare to process any annotations
...@@ -916,13 +912,11 @@ async fn chat_completions( ...@@ -916,13 +912,11 @@ async fn chat_completions(
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
let annotations = request.annotations(); let annotations = request.annotations();
...@@ -1260,13 +1254,11 @@ async fn responses( ...@@ -1260,13 +1254,11 @@ async fn responses(
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?; .map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for responses"); tracing::trace!("Issuing generate call for responses");
......
...@@ -100,6 +100,8 @@ pub struct PrefillRouter { ...@@ -100,6 +100,8 @@ pub struct PrefillRouter {
enforce_disagg: bool, enforce_disagg: bool,
/// Model name used to look up the worker monitor for prefill client registration /// Model name used to look up the worker monitor for prefill client registration
model_name: String, model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
} }
impl PrefillRouter { impl PrefillRouter {
...@@ -117,9 +119,11 @@ impl PrefillRouter { ...@@ -117,9 +119,11 @@ impl PrefillRouter {
router_mode, router_mode,
enforce_disagg, enforce_disagg,
model_name: String::new(), // Not used for disabled router model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
}) })
} }
#[allow(clippy::too_many_arguments)]
pub fn new( pub fn new(
activation_rx: oneshot::Receiver<Endpoint>, activation_rx: oneshot::Receiver<Endpoint>,
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
...@@ -128,6 +132,7 @@ impl PrefillRouter { ...@@ -128,6 +132,7 @@ impl PrefillRouter {
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool, enforce_disagg: bool,
model_name: String, model_name: String,
namespace: String,
) -> Arc<Self> { ) -> Arc<Self> {
let prefill_router = OnceLock::new(); let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new(); let cancel_token = CancellationToken::new();
...@@ -140,6 +145,7 @@ impl PrefillRouter { ...@@ -140,6 +145,7 @@ impl PrefillRouter {
router_mode, router_mode,
enforce_disagg, enforce_disagg,
model_name, model_name,
namespace,
}); });
// Spawn background task to wait for activation // Spawn background task to wait for activation
...@@ -207,7 +213,9 @@ impl PrefillRouter { ...@@ -207,7 +213,9 @@ impl PrefillRouter {
let client = kv_chooser.client().clone(); let client = kv_chooser.client().clone();
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode // Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) = model_manager.get_worker_monitor(&self.model_name) { if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone()); monitor.set_prefill_client(client.clone());
} }
...@@ -227,7 +235,9 @@ impl PrefillRouter { ...@@ -227,7 +235,9 @@ impl PrefillRouter {
let client = endpoint.client().await?; let client = endpoint.client().await?;
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode // Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) = model_manager.get_worker_monitor(&self.model_name) { if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone()); monitor.set_prefill_client(client.clone());
} }
......
...@@ -56,6 +56,7 @@ pub struct LocalModelBuilder { ...@@ -56,6 +56,7 @@ pub struct LocalModelBuilder {
user_data: Option<serde_json::Value>, user_data: Option<serde_json::Value>,
custom_template_path: Option<PathBuf>, custom_template_path: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
media_decoder: Option<MediaDecoder>, media_decoder: Option<MediaDecoder>,
media_fetcher: Option<MediaFetcher>, media_fetcher: Option<MediaFetcher>,
} }
...@@ -83,6 +84,7 @@ impl Default for LocalModelBuilder { ...@@ -83,6 +84,7 @@ impl Default for LocalModelBuilder {
user_data: Default::default(), user_data: Default::default(),
custom_template_path: Default::default(), custom_template_path: Default::default(),
namespace: Default::default(), namespace: Default::default(),
namespace_prefix: Default::default(),
media_decoder: Default::default(), media_decoder: Default::default(),
media_fetcher: Default::default(), media_fetcher: Default::default(),
} }
...@@ -160,6 +162,11 @@ impl LocalModelBuilder { ...@@ -160,6 +162,11 @@ impl LocalModelBuilder {
self self
} }
/// Sets the optional namespace prefix on this builder.
///
/// The stored value is cloned into the `LocalModel` produced by the builder.
/// Passing `None` clears any prefix set earlier. Returns `&mut Self` so the
/// call can be chained with the other builder setters.
// NOTE(review): presumably this prefix ends up in
// `NamespaceFilter::from_namespace_and_prefix` during discovery — confirm against the caller.
pub fn namespace_prefix(&mut self, namespace_prefix: Option<String>) -> &mut Self {
self.namespace_prefix = namespace_prefix;
self
}
pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self { pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
self.template_file = template_file; self.template_file = template_file;
self self
...@@ -288,6 +295,7 @@ impl LocalModelBuilder { ...@@ -288,6 +295,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit, migration_limit: self.migration_limit,
}); });
} }
...@@ -340,6 +348,7 @@ impl LocalModelBuilder { ...@@ -340,6 +348,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(), router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(), runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(), namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit, migration_limit: self.migration_limit,
}) })
} }
...@@ -359,6 +368,7 @@ pub struct LocalModel { ...@@ -359,6 +368,7 @@ pub struct LocalModel {
router_config: RouterConfig, router_config: RouterConfig,
runtime_config: ModelRuntimeConfig, runtime_config: ModelRuntimeConfig,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
migration_limit: u32, migration_limit: u32,
} }
...@@ -431,6 +441,10 @@ impl LocalModel { ...@@ -431,6 +441,10 @@ impl LocalModel {
self.namespace.as_deref() self.namespace.as_deref()
} }
/// Returns the namespace prefix stored on this model as a borrowed `&str`,
/// or `None` when no prefix was configured.
pub fn namespace_prefix(&self) -> Option<&str> {
self.namespace_prefix.as_deref()
}
/// An endpoint to identify this model by. /// An endpoint to identify this model by.
pub fn endpoint_id(&self) -> &EndpointId { pub fn endpoint_id(&self) -> &EndpointId {
&self.endpoint_id &self.endpoint_id
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
/// The global namespace for all models.
///
/// [`is_global_namespace`] treats this value (and the empty string) as
/// "no namespace restriction".
pub const GLOBAL_NAMESPACE: &str = "dynamo";
/// Namespace selection policy applied while discovering models.
///
/// In the hierarchical model architecture a single Model may be backed by
/// several WorkerSets living in different namespaces (for example while a
/// rolling update is in progress). This filter decides which of those
/// namespaces are picked up together.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum NamespaceFilter {
    /// No filtering: models from every namespace are discovered.
    Global,
    /// Only models whose namespace equals the contained string are discovered.
    Exact(String),
    /// Models whose namespace starts with the contained string are discovered
    /// (e.g. prefix "ns" matches "ns", "ns-abc123", "ns-def456").
    Prefix(String),
}
impl NamespaceFilter {
    /// Builds a filter from an optional exact namespace and an optional prefix.
    ///
    /// Resolution order:
    /// 1. If `namespace_prefix` is present it decides the result on its own and
    ///    `namespace` is ignored: an empty or global prefix yields `Global`,
    ///    anything else yields `Prefix`.
    /// 2. Otherwise, if `namespace` is present: an empty or global namespace
    ///    yields `Global`, anything else yields `Exact`.
    /// 3. With neither argument the result is `Global`.
    pub fn from_namespace_and_prefix(
        namespace: Option<&str>,
        namespace_prefix: Option<&str>,
    ) -> Self {
        // The prefix is listed first in the tuple because it takes precedence.
        match (namespace_prefix, namespace) {
            (Some(prefix), _) if prefix.is_empty() || is_global_namespace(prefix) => Self::Global,
            (Some(prefix), _) => Self::Prefix(prefix.to_string()),
            (None, Some(ns)) if ns.is_empty() || is_global_namespace(ns) => Self::Global,
            (None, Some(ns)) => Self::Exact(ns.to_string()),
            (None, None) => Self::Global,
        }
    }

    /// Reports whether `namespace` is accepted by this filter.
    ///
    /// `Global` accepts everything, `Exact` requires string equality, and
    /// `Prefix` is a plain `starts_with` test (no separator is required after
    /// the prefix).
    pub fn matches(&self, namespace: &str) -> bool {
        match self {
            Self::Global => true,
            Self::Exact(expected) => expected.as_str() == namespace,
            Self::Prefix(prefix) => namespace.starts_with(prefix.as_str()),
        }
    }

    /// Returns `true` only for the `Global` variant, i.e. when no filtering is applied.
    pub fn is_global(&self) -> bool {
        match self {
            Self::Global => true,
            Self::Exact(_) | Self::Prefix(_) => false,
        }
    }
}
pub fn is_global_namespace(namespace: &str) -> bool { pub fn is_global_namespace(namespace: &str) -> bool {
namespace == GLOBAL_NAMESPACE || namespace.is_empty() namespace == GLOBAL_NAMESPACE || namespace.is_empty()
} }
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for the constructor under test.
    fn build(namespace: Option<&str>, prefix: Option<&str>) -> NamespaceFilter {
        NamespaceFilter::from_namespace_and_prefix(namespace, prefix)
    }

    /// Asserts that `filter` accepts every namespace in `accepted` and rejects
    /// every namespace in `rejected`, naming the offending input on failure.
    fn assert_matching(filter: &NamespaceFilter, accepted: &[&str], rejected: &[&str]) {
        for ns in accepted {
            assert!(filter.matches(ns), "{:?} should match {:?}", filter, ns);
        }
        for ns in rejected {
            assert!(!filter.matches(ns), "{:?} should not match {:?}", filter, ns);
        }
    }

    #[test]
    fn test_from_namespace_and_prefix_global() {
        // Missing, empty, and global namespaces all disable filtering.
        for namespace in [None, Some(""), Some(GLOBAL_NAMESPACE)] {
            assert_eq!(
                build(namespace, None),
                NamespaceFilter::Global,
                "namespace = {:?}",
                namespace
            );
        }
    }

    #[test]
    fn test_from_namespace_and_prefix_global_prefix() {
        // An empty or global prefix disables filtering too, and because the
        // prefix takes precedence this holds even when an exact namespace is given.
        for prefix in ["", GLOBAL_NAMESPACE] {
            for namespace in [None, Some("exact")] {
                assert_eq!(
                    build(namespace, Some(prefix)),
                    NamespaceFilter::Global,
                    "namespace = {:?}, prefix = {:?}",
                    namespace,
                    prefix
                );
            }
        }
    }

    #[test]
    fn test_from_namespace_and_prefix_exact() {
        assert_eq!(
            build(Some("my-namespace"), None),
            NamespaceFilter::Exact("my-namespace".to_string())
        );
    }

    #[test]
    fn test_from_namespace_and_prefix_prefix_only() {
        assert_eq!(
            build(None, Some("ns")),
            NamespaceFilter::Prefix("ns".to_string())
        );
    }

    #[test]
    fn test_from_namespace_and_prefix_prefix_takes_precedence() {
        assert_eq!(
            build(Some("exact"), Some("prefix")),
            NamespaceFilter::Prefix("prefix".to_string())
        );
        // A global exact namespace does not override a real prefix.
        assert_eq!(
            build(Some(GLOBAL_NAMESPACE), Some("prefix")),
            NamespaceFilter::Prefix("prefix".to_string())
        );
    }

    #[test]
    fn test_matches_global() {
        assert_matching(
            &NamespaceFilter::Global,
            &["anything", "", "default", "ns-abc123"],
            &[],
        );
    }

    #[test]
    fn test_matches_exact() {
        assert_matching(
            &NamespaceFilter::Exact("my-namespace".to_string()),
            &["my-namespace"],
            &["my-namespace-abc123", "other", ""],
        );
    }

    #[test]
    fn test_matches_prefix() {
        assert_matching(
            &NamespaceFilter::Prefix("ns".to_string()),
            &["ns", "ns-abc123", "ns-def456"],
            &["other-ns", ""],
        );
    }

    #[test]
    fn test_is_global() {
        assert!(NamespaceFilter::Global.is_global());
        assert!(!NamespaceFilter::Exact("ns".to_string()).is_global());
        assert!(!NamespaceFilter::Prefix("ns".to_string()).is_global());
    }
}
...@@ -295,7 +295,7 @@ mod integration_tests { ...@@ -295,7 +295,7 @@ mod integration_tests {
use super::*; use super::*;
use dynamo_llm::{ use dynamo_llm::{
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig, discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder, local_model::LocalModelBuilder, namespace::NamespaceFilter,
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::discovery::DiscoveryQuery; use dynamo_runtime::discovery::DiscoveryQuery;
...@@ -355,7 +355,9 @@ mod integration_tests { ...@@ -355,7 +355,9 @@ mod integration_tests {
// Spawn watcher task to discover models // Spawn watcher task to discover models
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
model_watcher.watch(discovery_stream, None).await; model_watcher
.watch(discovery_stream, NamespaceFilter::Global)
.await;
}); });
let EngineConfig::InProcessText { engine, model, .. } = engine_config else { let EngineConfig::InProcessText { engine, model, .. } = engine_config else {
...@@ -553,13 +555,8 @@ mod integration_tests { ...@@ -553,13 +555,8 @@ mod integration_tests {
if let Some(key) = key { if let Some(key) = key {
// Remove from ModelManager first (this returns the ModelEntry) // Remove from ModelManager first (this returns the ModelEntry)
if let Some(_removed_card) = manager.remove_model_card(&key) { if let Some(_removed_card) = manager.remove_model_card(&key) {
// Remove engines (following ModelWatcher::handle_delete pattern) // Remove entire model (following ModelWatcher::handle_delete pattern)
manager manager.remove_model(&model_entry.name);
.remove_chat_completions_model(&model_entry.name)
.ok();
manager.remove_completions_model(&model_entry.name).ok();
manager.remove_embeddings_model(&model_entry.name).ok();
manager.remove_tensor_model(&model_entry.name).ok();
// Then delete from etcd // Then delete from etcd
etcd_client.kv_delete(key.as_str(), None).await.unwrap(); etcd_client.kv_delete(key.as_str(), None).await.unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment