Unverified Commit cb55766c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 8dd6369e
......@@ -372,13 +372,11 @@ async fn completions_single(
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
// todo - error handling should be more robust
let engine = state
let (engine, parsing_options) = state
.manager()
.get_completions_engine(&model)
.get_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
// prepare to process any annotations
......@@ -495,13 +493,11 @@ async fn completions_batch(
// Create http_queue_guard early - tracks time waiting to be processed
let http_queue_guard = state.metrics_clone().create_http_queue_guard(&model);
let engine = state
let (engine, parsing_options) = state
.manager()
.get_completions_engine(&model)
.get_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
// prepare to process any annotations
......@@ -916,13 +912,11 @@ async fn chat_completions(
tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine(&model)
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
let annotations = request.annotations();
......@@ -1260,13 +1254,11 @@ async fn responses(
tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine(&model)
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| ErrorMessage::model_not_found())?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for responses");
......
......@@ -100,6 +100,8 @@ pub struct PrefillRouter {
enforce_disagg: bool,
/// Model name used to look up the worker monitor for prefill client registration
model_name: String,
/// Namespace used to look up the correct WorkerSet's worker monitor
namespace: String,
}
impl PrefillRouter {
......@@ -117,9 +119,11 @@ impl PrefillRouter {
router_mode,
enforce_disagg,
model_name: String::new(), // Not used for disabled router
namespace: String::new(), // Not used for disabled router
})
}
#[allow(clippy::too_many_arguments)]
pub fn new(
activation_rx: oneshot::Receiver<Endpoint>,
model_manager: Arc<ModelManager>,
......@@ -128,6 +132,7 @@ impl PrefillRouter {
kv_router_config: Option<KvRouterConfig>,
enforce_disagg: bool,
model_name: String,
namespace: String,
) -> Arc<Self> {
let prefill_router = OnceLock::new();
let cancel_token = CancellationToken::new();
......@@ -140,6 +145,7 @@ impl PrefillRouter {
router_mode,
enforce_disagg,
model_name,
namespace,
});
// Spawn background task to wait for activation
......@@ -207,7 +213,9 @@ impl PrefillRouter {
let client = kv_chooser.client().clone();
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) = model_manager.get_worker_monitor(&self.model_name) {
if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone());
}
......@@ -227,7 +235,9 @@ impl PrefillRouter {
let client = endpoint.client().await?;
// Register prefill client with worker monitor for TTFT metric cleanup in disaggregated mode
if let Some(monitor) = model_manager.get_worker_monitor(&self.model_name) {
if let Some(monitor) =
model_manager.get_worker_monitor_for_namespace(&self.model_name, &self.namespace)
{
monitor.set_prefill_client(client.clone());
}
......
......@@ -56,6 +56,7 @@ pub struct LocalModelBuilder {
user_data: Option<serde_json::Value>,
custom_template_path: Option<PathBuf>,
namespace: Option<String>,
namespace_prefix: Option<String>,
media_decoder: Option<MediaDecoder>,
media_fetcher: Option<MediaFetcher>,
}
......@@ -83,6 +84,7 @@ impl Default for LocalModelBuilder {
user_data: Default::default(),
custom_template_path: Default::default(),
namespace: Default::default(),
namespace_prefix: Default::default(),
media_decoder: Default::default(),
media_fetcher: Default::default(),
}
......@@ -160,6 +162,11 @@ impl LocalModelBuilder {
self
}
pub fn namespace_prefix(&mut self, namespace_prefix: Option<String>) -> &mut Self {
self.namespace_prefix = namespace_prefix;
self
}
pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
self.template_file = template_file;
self
......@@ -288,6 +295,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit,
});
}
......@@ -340,6 +348,7 @@ impl LocalModelBuilder {
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
namespace_prefix: self.namespace_prefix.clone(),
migration_limit: self.migration_limit,
})
}
......@@ -359,6 +368,7 @@ pub struct LocalModel {
router_config: RouterConfig,
runtime_config: ModelRuntimeConfig,
namespace: Option<String>,
namespace_prefix: Option<String>,
migration_limit: u32,
}
......@@ -431,6 +441,10 @@ impl LocalModel {
self.namespace.as_deref()
}
pub fn namespace_prefix(&self) -> Option<&str> {
self.namespace_prefix.as_deref()
}
/// An endpoint to identify this model by.
pub fn endpoint_id(&self) -> &EndpointId {
&self.endpoint_id
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// The global namespace for all models
pub const GLOBAL_NAMESPACE: &str = "dynamo";
/// Determines how namespaces are filtered during model discovery.
///
/// This supports the hierarchical model architecture where multiple WorkerSets
/// with different namespaces (e.g., during rolling updates) should be discovered
/// together under the same Model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NamespaceFilter {
/// Discover models from all namespaces (no filtering)
Global,
/// Discover models only from an exact namespace match
Exact(String),
/// Discover models from namespaces starting with the given prefix
/// (e.g., prefix "ns" matches "ns", "ns-abc123", "ns-def456")
Prefix(String),
}
impl NamespaceFilter {
/// Create a NamespaceFilter from optional namespace and namespace_prefix.
/// If prefix is provided, it takes precedence over exact namespace.
pub fn from_namespace_and_prefix(
namespace: Option<&str>,
namespace_prefix: Option<&str>,
) -> Self {
// Prefix takes precedence if both are specified
if let Some(prefix) = namespace_prefix {
if prefix.is_empty() || is_global_namespace(prefix) {
return NamespaceFilter::Global;
}
return NamespaceFilter::Prefix(prefix.to_string());
}
if let Some(ns) = namespace {
if ns.is_empty() || is_global_namespace(ns) {
return NamespaceFilter::Global;
}
return NamespaceFilter::Exact(ns.to_string());
}
NamespaceFilter::Global
}
/// Check if a given namespace matches this filter.
pub fn matches(&self, namespace: &str) -> bool {
match self {
NamespaceFilter::Global => true,
NamespaceFilter::Exact(target) => namespace == target,
NamespaceFilter::Prefix(prefix) => namespace.starts_with(prefix),
}
}
/// Returns true if this is global namespace filtering (no filtering).
pub fn is_global(&self) -> bool {
matches!(self, NamespaceFilter::Global)
}
}
pub fn is_global_namespace(namespace: &str) -> bool {
namespace == GLOBAL_NAMESPACE || namespace.is_empty()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_from_namespace_and_prefix_global() {
assert_eq!(
NamespaceFilter::from_namespace_and_prefix(None, None),
NamespaceFilter::Global
);
assert_eq!(
NamespaceFilter::from_namespace_and_prefix(Some(""), None),
NamespaceFilter::Global
);
assert_eq!(
NamespaceFilter::from_namespace_and_prefix(Some(GLOBAL_NAMESPACE), None),
NamespaceFilter::Global
);
}
#[test]
fn test_from_namespace_and_prefix_exact() {
assert_eq!(
NamespaceFilter::from_namespace_and_prefix(Some("my-namespace"), None),
NamespaceFilter::Exact("my-namespace".to_string())
);
}
#[test]
fn test_from_namespace_and_prefix_prefix_takes_precedence() {
assert_eq!(
NamespaceFilter::from_namespace_and_prefix(Some("exact"), Some("prefix")),
NamespaceFilter::Prefix("prefix".to_string())
);
}
#[test]
fn test_matches_global() {
let filter = NamespaceFilter::Global;
assert!(filter.matches("anything"));
assert!(filter.matches(""));
assert!(filter.matches("default"));
assert!(filter.matches("ns-abc123"));
}
#[test]
fn test_matches_exact() {
let filter = NamespaceFilter::Exact("my-namespace".to_string());
assert!(filter.matches("my-namespace"));
assert!(!filter.matches("my-namespace-abc123"));
assert!(!filter.matches("other"));
assert!(!filter.matches(""));
}
#[test]
fn test_matches_prefix() {
let filter = NamespaceFilter::Prefix("ns".to_string());
assert!(filter.matches("ns"));
assert!(filter.matches("ns-abc123"));
assert!(filter.matches("ns-def456"));
assert!(!filter.matches("other-ns"));
assert!(!filter.matches(""));
}
#[test]
fn test_is_global() {
assert!(NamespaceFilter::Global.is_global());
assert!(!NamespaceFilter::Exact("ns".to_string()).is_global());
assert!(!NamespaceFilter::Prefix("ns".to_string()).is_global());
}
}
......@@ -295,7 +295,7 @@ mod integration_tests {
use super::*;
use dynamo_llm::{
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder,
local_model::LocalModelBuilder, namespace::NamespaceFilter,
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::discovery::DiscoveryQuery;
......@@ -355,7 +355,9 @@ mod integration_tests {
// Spawn watcher task to discover models
let _watcher_task = tokio::spawn(async move {
model_watcher.watch(discovery_stream, None).await;
model_watcher
.watch(discovery_stream, NamespaceFilter::Global)
.await;
});
let EngineConfig::InProcessText { engine, model, .. } = engine_config else {
......@@ -553,13 +555,8 @@ mod integration_tests {
if let Some(key) = key {
// Remove from ModelManager first (this returns the ModelEntry)
if let Some(_removed_card) = manager.remove_model_card(&key) {
// Remove engines (following ModelWatcher::handle_delete pattern)
manager
.remove_chat_completions_model(&model_entry.name)
.ok();
manager.remove_completions_model(&model_entry.name).ok();
manager.remove_embeddings_model(&model_entry.name).ok();
manager.remove_tensor_model(&model_entry.name).ok();
// Remove entire model (following ModelWatcher::handle_delete pattern)
manager.remove_model(&model_entry.name);
// Then delete from etcd
etcd_client.kv_delete(key.as_str(), None).await.unwrap();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment