Unverified Commit cb55766c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 8dd6369e
......@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.namespace import get_worker_namespace
from dynamo.common.utils.output_modalities import OutputModality
......@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase):
media_output_http_url: Optional[str] = None
def validate(self) -> None:
self.namespace = get_worker_namespace(self.namespace)
# TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events
self._validate_output_modalities()
......@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup):
flag_name="--namespace",
env_var="DYN_NAMESPACE",
default="dynamo",
help="Dynamo namespace",
help="Dynamo namespace. If DYN_NAMESPACE_WORKER_SUFFIX is set, "
"'-{suffix}' is appended to support multiple worker pools",
)
add_argument(
g,
......
......@@ -17,6 +17,7 @@ Submodules:
from dynamo.common.utils import (
endpoint_types,
engine_response,
namespace,
otel_tracing,
paths,
prometheus,
......@@ -26,6 +27,7 @@ from dynamo.common.utils import (
__all__ = [
"endpoint_types",
"engine_response",
"namespace",
"otel_tracing",
"paths",
"prometheus",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
def get_worker_namespace(namespace: Optional[str] = None) -> str:
    """Resolve the Dynamo namespace a worker should use.

    The base namespace is ``namespace`` when it is non-empty; otherwise it is
    read from the ``DYN_NAMESPACE`` environment variable, falling back to
    ``"dynamo"``. When ``DYN_NAMESPACE_WORKER_SUFFIX`` is set to a non-empty
    value the result is ``"{namespace}-{suffix}"``, which lets several sets of
    workers serve the same model.
    """
    base = namespace or os.environ.get("DYN_NAMESPACE", "dynamo")
    worker_suffix = os.environ.get("DYN_NAMESPACE_WORKER_SUFFIX")
    # An unset or empty suffix leaves the base namespace untouched.
    return f"{base}-{worker_suffix}" if worker_suffix else base
......@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase):
router_max_tree_size: int
router_prune_target_ratio: float
namespace: Optional[str] = None
namespace_prefix: Optional[str] = None
router_replica_sync: bool
router_snapshot_threshold: int
router_reset_states: bool
......@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_NAMESPACE",
default=None,
help=(
"Dynamo namespace for model discovery scoping. If specified, models will "
"only be discovered from this namespace. If not specified, discovers models "
"from all namespaces (global discovery)."
"Dynamo namespace for model discovery scoping. Use for exact namespace matching. "
"If --namespace-prefix is also specified, prefix takes precedence."
),
)
......@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup):
arg_type=float,
)
add_argument(
g,
flag_name="--namespace-prefix",
env_var="DYN_NAMESPACE_PREFIX",
default=None,
help=(
"Dynamo namespace prefix for model discovery scoping. Discovers models from "
"namespaces starting with this prefix (e.g., 'ns' matches 'ns', 'ns-abc123', "
"'ns-def456'). Takes precedence over --namespace if both are specified."
),
)
add_negatable_bool_argument(
g,
flag_name="--router-replica-sync",
......
......@@ -230,6 +230,8 @@ async def async_main():
kwargs["tls_key_path"] = config.tls_key_path
if config.namespace:
kwargs["namespace"] = config.namespace
if config.namespace_prefix:
kwargs["namespace_prefix"] = config.namespace_prefix
if config.kserve_grpc_server and config.grpc_metrics_port:
kwargs["http_metrics_port"] = config.grpc_metrics_port
......
......@@ -8,6 +8,8 @@ import os
import tempfile
from pathlib import Path
from dynamo.common.utils.namespace import get_worker_namespace
from . import __version__
from .utils.planner_profiler_perf_data_converter import (
convert_profile_results_to_npz,
......@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import (
is_profile_results_dir,
)
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DYN_NAMESPACE = get_worker_namespace()
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
......
......@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers.
"""
import os
from dataclasses import dataclass
from typing import Optional
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
from dynamo.common.utils.namespace import get_worker_namespace
DYN_NAMESPACE = get_worker_namespace()
# Default model paths
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
......
......@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config),
enforce_disagg,
model_name.clone(),
namespace_str.clone(),
)
}
None if enforce_disagg => {
......
......@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs {
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>,
......@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
......@@ -209,6 +210,7 @@ impl EntrypointArgs {
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyObject>,
......@@ -254,6 +256,7 @@ impl EntrypointArgs {
tls_key_path,
extra_engine_args,
namespace,
namespace_prefix,
is_prefill,
migration_limit,
chat_engine_factory,
......@@ -296,7 +299,8 @@ pub fn make_engine<'p>(
.tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone())
.namespace(args.namespace.clone());
.namespace(args.namespace.clone())
.namespace_prefix(args.namespace_prefix.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
if let Some(model_path) = args.model_path.clone() {
let local_path = if model_path.exists() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod model;
pub use model::Model;
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
mod worker_set;
pub use worker_set::WorkerSet;
pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A Model represents a named model (e.g., "llama-3-70b") that may be served by
//! one or more WorkerSets. Each WorkerSet corresponds to a namespace.
//!
//! Requests are routed to a WorkerSet selected by weighted random (proportional to worker count).
use std::sync::{Arc, OnceLock};
use dashmap::DashMap;
use rand::Rng;
use super::worker_monitor::LoadThresholdConfig;
use super::worker_set::WorkerSet;
use super::{KvWorkerMonitor, ModelManagerError};
use crate::protocols::openai::ParsingOptions;
use crate::types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine,
},
};
/// A named model backed by one or more WorkerSets.
///
/// WorkerSets are stored in a concurrent map keyed by the namespace string they
/// were registered under, so the map can be read and updated through `&self`.
pub struct Model {
/// Name of the model, as passed to the constructor.
name: String,
/// Registered WorkerSets, keyed by namespace.
worker_sets: DashMap<String, Arc<WorkerSet>>,
/// The canonical MDC checksum for this model. Set by the first WorkerSet registered;
/// all subsequent WorkerSets must match. Naturally cleared when the Model is dropped
/// (last WorkerSet removed), allowing a new version to register.
canonical_checksum: OnceLock<String>,
}
impl Model {
/// Create a model called `name` with no WorkerSets and no canonical checksum yet.
pub fn new(name: String) -> Self {
    let worker_sets = DashMap::new();
    let canonical_checksum = OnceLock::new();
    Self {
        name,
        worker_sets,
        canonical_checksum,
    }
}
/// The model's name.
pub fn name(&self) -> &str {
    self.name.as_str()
}
/// Register `worker_set` under `namespace`.
///
/// The set's checksum is first recorded as, or compared against, the model's
/// canonical checksum; on a mismatch the error is returned and nothing is
/// inserted. An existing entry for the same namespace is replaced.
pub fn add_worker_set(
    &self,
    namespace: String,
    worker_set: Arc<WorkerSet>,
) -> Result<(), ModelManagerError> {
    // Validate before touching the map so a rejected set is never visible.
    let checksum = worker_set.mdcsum();
    self.set_canonical_checksum(checksum)?;

    tracing::info!(
        model = %self.name,
        namespace = %namespace,
        "Adding worker set to model"
    );
    let _previous = self.worker_sets.insert(namespace, worker_set);
    Ok(())
}
/// Remove and return the WorkerSet registered under `namespace`.
///
/// Returns `None` (and logs nothing) when no such WorkerSet exists.
pub fn remove_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
    let (_key, worker_set) = self.worker_sets.remove(namespace)?;
    tracing::info!(
        model = %self.name,
        namespace = %namespace,
        remaining_sets = self.worker_sets.len(),
        "Removed worker set from model"
    );
    Some(worker_set)
}
/// Whether a WorkerSet is registered under `namespace`.
pub fn has_worker_set(&self, namespace: &str) -> bool {
    let registered = self.worker_sets.contains_key(namespace);
    registered
}
/// Look up the WorkerSet registered under `namespace`, returning a new `Arc` handle to it.
pub fn get_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
    let entry = self.worker_sets.get(namespace)?;
    Some(Arc::clone(entry.value()))
}
/// Whether this model currently has no WorkerSets registered.
pub fn is_empty(&self) -> bool {
    let no_sets = self.worker_sets.is_empty();
    no_sets
}
/// Number of WorkerSets currently registered on this model.
pub fn worker_set_count(&self) -> usize {
    let registered = self.worker_sets.len();
    registered
}
/// Whether at least one WorkerSet reports a decode engine (chat or completions).
pub fn has_decode_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_decode_engine() {
            return true;
        }
    }
    false
}
/// Whether this model tracks prefill, i.e. at least one WorkerSet is a prefill set.
pub fn has_prefill(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().is_prefill_set() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has a chat engine.
pub fn has_chat_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_chat_engine() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has a completions engine.
pub fn has_completions_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_completions_engine() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has an embeddings engine.
pub fn has_embeddings_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_embeddings_engine() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has a tensor engine.
pub fn has_tensor_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_tensor_engine() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has an images engine.
pub fn has_images_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_images_engine() {
            return true;
        }
    }
    false
}
/// Whether at least one WorkerSet has a videos engine.
pub fn has_videos_engine(&self) -> bool {
    for entry in self.worker_sets.iter() {
        if entry.value().has_videos_engine() {
            return true;
        }
    }
    false
}
/// Compare `candidate` against this model's canonical checksum.
///
/// Returns `None` while no canonical checksum has been set (no WorkerSet has been
/// added yet), otherwise `Some(true)` on a match and `Some(false)` on a mismatch.
pub fn is_valid_checksum(&self, candidate: &str) -> Option<bool> {
    self.canonical_checksum
        .get()
        .map(|canonical| canonical == candidate)
}
/// Record `checksum` as the model's canonical checksum, or verify it against the
/// one already recorded. The first caller wins (OnceLock).
///
/// # Errors
///
/// Returns [`ModelManagerError::ChecksumMismatch`] when a different checksum was
/// already set.
fn set_canonical_checksum(&self, checksum: &str) -> Result<(), ModelManagerError> {
    // `get_or_init` stores our value only if the cell is still empty and always
    // hands back whatever ended up stored, so no `unwrap` is needed and the
    // `String` is allocated only when this call actually initializes the cell.
    let canonical = self
        .canonical_checksum
        .get_or_init(|| checksum.to_owned());

    if canonical == checksum {
        return Ok(());
    }
    Err(ModelManagerError::ChecksumMismatch {
        model: self.name.clone(),
        expected: canonical.clone(),
        got: checksum.to_owned(),
    })
}
// -- Engine accessors: select a WorkerSet, return its engine --
/// Pick a WorkerSet that has a chat engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_chat_engine(
    &self,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.chat_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Pick a WorkerSet that has a completions engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_completions_engine(
    &self,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.completions_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Pick a WorkerSet that has an embeddings engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_embeddings_engine(
    &self,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.embeddings_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Pick a WorkerSet that has an images engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.images_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Pick a WorkerSet that has a videos engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.videos_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Pick a WorkerSet that has a tensor engine and return a clone of that engine.
///
/// Fails with `ModelNotFound` (carrying this model's name) when selection yields nothing.
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
    match self.select_worker_set_with(|ws| ws.tensor_engine.clone()) {
        Some(engine) => Ok(engine),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
/// Return a chat engine together with the parsing options of the same WorkerSet.
///
/// Both values are read inside a single selection, so they always come from one
/// WorkerSet. Fails with `ModelNotFound` when selection yields nothing.
pub fn get_chat_engine_with_parsing(
    &self,
) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
    let selected = self.select_worker_set_with(|ws| {
        let engine = ws.chat_engine.clone()?;
        Some((engine, ws.parsing_options()))
    });
    match selected {
        Some(pair) => Ok(pair),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
/// Return a completions engine together with the parsing options of the same WorkerSet.
///
/// Both values are read inside a single selection, so they always come from one
/// WorkerSet. Fails with `ModelNotFound` when selection yields nothing.
pub fn get_completions_engine_with_parsing(
    &self,
) -> Result<(OpenAICompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
    let selected = self.select_worker_set_with(|ws| {
        let engine = ws.completions_engine.clone()?;
        Some((engine, ws.parsing_options()))
    });
    match selected {
        Some(pair) => Ok(pair),
        None => Err(ModelManagerError::ModelNotFound(self.name.clone())),
    }
}
// -- Worker monitoring (aggregated across WorkerSets) --
/// Get load threshold config from the first WorkerSet that has a monitor.
/// When `config` is Some, updates ALL monitors (each WorkerSet has its own).
pub fn load_threshold_config(
&self,
config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> {
let mut result = None;
for entry in self.worker_sets.iter() {
if let Some(ref monitor) = entry.value().worker_monitor {
if let Some(cfg) = config {
monitor.set_load_threshold_config(cfg);
}
if result.is_none() {
result = Some(monitor.load_threshold_config());
}
}
}
result
}
/// Clone the worker monitor of the WorkerSet registered under `namespace`.
///
/// Returns `None` when the namespace is unknown or its WorkerSet has no monitor.
pub fn get_worker_monitor_for_namespace(&self, namespace: &str) -> Option<KvWorkerMonitor> {
    let entry = self.worker_sets.get(namespace)?;
    entry.value().worker_monitor.clone()
}
/// Sum of `worker_count()` over every registered WorkerSet.
pub fn total_workers(&self) -> usize {
    let mut total = 0usize;
    for entry in self.worker_sets.iter() {
        total += entry.value().worker_count();
    }
    total
}
// -- Internal selection --
/// Choose one WorkerSet and return the value `extract` produces for it.
///
/// `extract` returns `Some(value)` when a WorkerSet has the desired engine and
/// `None` when it does not. WorkerSets whose `worker_count()` is 0 never take part.
///
/// * Exactly one registered set (steady state): that set is used directly.
/// * Several sets: those with workers and a value are collected, then one is
///   drawn at random with probability proportional to its worker count.
///
/// Returns `None` when no WorkerSet qualifies.
fn select_worker_set_with<T, F>(&self, extract: F) -> Option<T>
where
    F: Fn(&WorkerSet) -> Option<T>,
{
    // Fast path: single set (same zero-worker filtering as the multi-set path below)
    // TODO: When the single set has 0 workers, this returns None which maps to
    // ModelNotFound (404). Ideally should be 503 "no available workers" — see follow-up.
    if self.worker_sets.len() == 1 {
        let entry = self.worker_sets.iter().next()?;
        let only_set = entry.value();
        return if only_set.worker_count() == 0 {
            None
        } else {
            extract(only_set)
        };
    }

    // Gather (value, weight) pairs for every set that has workers and the engine.
    // In-process models (no discovery watcher) return count=1, so they always participate.
    // Discovery models with count=0 have no available workers and are skipped.
    let mut candidates: Vec<(T, usize)> = Vec::new();
    for entry in self.worker_sets.iter() {
        let worker_set = entry.value();
        let weight = worker_set.worker_count();
        if weight == 0 {
            continue;
        }
        if let Some(value) = extract(worker_set) {
            candidates.push((value, weight));
        }
    }

    match candidates.len() {
        0 => None,
        1 => candidates.pop().map(|(value, _)| value),
        _ => {
            // Weighted random selection proportional to worker count: draw a point in
            // [0, total) and walk the candidates until the point falls inside one.
            let total_weight: usize = candidates.iter().map(|(_, weight)| *weight).sum();
            let mut remaining = rand::rng().random_range(0..total_weight);
            candidates.into_iter().find_map(|(value, weight)| {
                if remaining < weight {
                    Some(value)
                } else {
                    remaining -= weight;
                    None
                }
            })
            // Falling through every candidate should not happen; it yields None.
        }
    }
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model_card::ModelDeploymentCard;
use tokio::sync::watch;
/// Build a WorkerSet from a default model card, without attaching an instance watcher.
fn make_worker_set(namespace: &str, mdcsum: &str) -> Arc<WorkerSet> {
    let card = ModelDeploymentCard::default();
    let worker_set = WorkerSet::new(namespace.to_owned(), mdcsum.to_owned(), card);
    Arc::new(worker_set)
}
/// Build a WorkerSet whose instance watcher is the receiving half of a watch channel
/// seeded with `worker_ids`, so worker_count reflects the length of the vec last sent.
///
/// The sender is returned so a test can publish new worker lists.
fn make_worker_set_with_count(
    namespace: &str,
    mdcsum: &str,
    worker_ids: Vec<u64>,
) -> (Arc<WorkerSet>, watch::Sender<Vec<u64>>) {
    let (sender, receiver) = watch::channel(worker_ids);
    let card = ModelDeploymentCard::default();
    let mut worker_set = WorkerSet::new(namespace.to_owned(), mdcsum.to_owned(), card);
    worker_set.set_instance_watcher(receiver);
    (Arc::new(worker_set), sender)
}
#[test]
fn test_model_new() {
    // A freshly constructed model keeps its name and starts with no WorkerSets.
    let fresh = Model::new(String::from("llama"));
    assert_eq!(fresh.worker_set_count(), 0);
    assert!(fresh.is_empty());
    assert_eq!(fresh.name(), "llama");
}
#[test]
fn test_add_remove_worker_set() {
    let model = Model::new(String::from("llama"));

    // Adding one set makes it visible under its own namespace only.
    model
        .add_worker_set(String::from("ns1"), make_worker_set("ns1", "abc"))
        .unwrap();
    assert_eq!(model.worker_set_count(), 1);
    assert!(!model.is_empty());
    assert!(model.has_worker_set("ns1"));
    assert!(!model.has_worker_set("ns2"));

    // The first removal returns the set and empties the model.
    assert!(model.remove_worker_set("ns1").is_some());
    assert!(model.is_empty());

    // A second removal of the same namespace finds nothing.
    assert!(model.remove_worker_set("ns1").is_none());
}
#[test]
fn test_get_worker_set() {
    let model = Model::new(String::from("llama"));
    model
        .add_worker_set(String::from("ns1"), make_worker_set("ns1", "abc"))
        .unwrap();

    // A registered namespace resolves to the set that was added for it.
    let found = model.get_worker_set("ns1");
    assert!(found.is_some());
    let found = found.unwrap();
    assert_eq!(found.namespace(), "ns1");

    // An unknown namespace resolves to nothing.
    assert!(model.get_worker_set("ns2").is_none());
}
#[test]
fn test_multiple_worker_sets_same_checksum() {
    let model = Model::new(String::from("llama"));

    // Two namespaces sharing one checksum can coexist.
    for namespace in ["ns1", "ns2"] {
        model
            .add_worker_set(namespace.to_string(), make_worker_set(namespace, "abc"))
            .unwrap();
    }
    assert_eq!(model.worker_set_count(), 2);
    assert!(model.has_worker_set("ns1"));
    assert!(model.has_worker_set("ns2"));

    // Removing one namespace leaves the other untouched.
    model.remove_worker_set("ns1");
    assert_eq!(model.worker_set_count(), 1);
    assert!(!model.has_worker_set("ns1"));
    assert!(model.has_worker_set("ns2"));
}
#[test]
fn test_add_worker_set_rejects_checksum_mismatch() {
    let model = Model::new(String::from("llama"));
    model
        .add_worker_set(String::from("ns1"), make_worker_set("ns1", "abc"))
        .unwrap();

    // Different checksum from a different namespace should be rejected
    let mismatched = make_worker_set("ns2", "def");
    let outcome = model.add_worker_set(String::from("ns2"), mismatched);
    assert!(outcome.is_err());

    // ns2 was not added
    assert_eq!(model.worker_set_count(), 1);
}
#[test]
fn test_is_valid_checksum() {
    let model = Model::new(String::from("llama"));

    // Before any WorkerSet is added there is no canonical checksum to compare against.
    assert_eq!(model.is_valid_checksum("abc123"), None);

    model
        .add_worker_set(String::from("ns1"), make_worker_set("ns1", "abc123"))
        .unwrap();

    // Once set, the canonical checksum matches itself and rejects anything else.
    assert_eq!(model.is_valid_checksum("abc123"), Some(true));
    assert_eq!(model.is_valid_checksum("wrong"), Some(false));
}
#[test]
fn test_no_engines_means_prefill() {
    let model = Model::new("llama".to_string());
    model
        .add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
        .unwrap();

    // WorkerSets with no engines are treated as prefill sets
    assert!(model.has_prefill());

    // Every engine predicate must report false, including videos, which the
    // original test omitted even though `has_videos_engine` exists on Model.
    assert!(!model.has_decode_engine());
    assert!(!model.has_chat_engine());
    assert!(!model.has_completions_engine());
    assert!(!model.has_embeddings_engine());
    assert!(!model.has_tensor_engine());
    assert!(!model.has_images_engine());
    assert!(!model.has_videos_engine());
}
#[test]
fn test_get_engine_returns_error_without_engines() {
    let model = Model::new("llama".to_string());
    model
        .add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
        .unwrap();

    // With no engines configured, every accessor must fail. The videos accessor
    // was missing from the original list and is covered here as well.
    assert!(model.get_chat_engine().is_err());
    assert!(model.get_completions_engine().is_err());
    assert!(model.get_embeddings_engine().is_err());
    assert!(model.get_images_engine().is_err());
    assert!(model.get_videos_engine().is_err());
    assert!(model.get_tensor_engine().is_err());
}
#[test]
fn test_select_worker_set_with_extracts_namespace() {
    // select_worker_set_with is private, so it is exercised through the public
    // chat-engine accessor. Real engines cannot be built in tests; with none
    // configured, every path below must end in Err, which still walks the
    // filtering and selection code.
    let model = Model::new(String::from("llama"));

    // No sets at all.
    assert!(model.get_chat_engine().is_err());

    // One set: fast path, no engine to extract.
    model
        .add_worker_set(String::from("ns1"), make_worker_set("ns1", "abc"))
        .unwrap();
    assert!(model.get_chat_engine().is_err());

    // Two sets: multi-set path, still nothing to extract from either.
    model
        .add_worker_set(String::from("ns2"), make_worker_set("ns2", "abc"))
        .unwrap();
    assert!(model.get_chat_engine().is_err());
}
#[test]
fn test_total_workers_no_watcher() {
    // In-process WorkerSets (no watcher) default to worker_count=1, so the total
    // equals the number of sets added so far.
    let model = Model::new(String::from("llama"));
    assert_eq!(model.total_workers(), 0);

    for (expected_total, namespace) in [(1, "ns1"), (2, "ns2")] {
        model
            .add_worker_set(namespace.to_string(), make_worker_set(namespace, "abc"))
            .unwrap();
        assert_eq!(model.total_workers(), expected_total);
    }
}
#[test]
fn test_total_workers_with_watcher() {
    let model = Model::new(String::from("llama"));

    // Senders are kept alive for the duration of the test.
    let (three_workers, _sender_a) = make_worker_set_with_count("ns1", "abc", vec![1, 2, 3]);
    let (two_workers, _sender_b) = make_worker_set_with_count("ns2", "abc", vec![10, 20]);
    model
        .add_worker_set(String::from("ns1"), three_workers)
        .unwrap();
    model
        .add_worker_set(String::from("ns2"), two_workers)
        .unwrap();

    // 3 + 2
    assert_eq!(model.total_workers(), 5);
}
#[test]
fn test_total_workers_updates_dynamically() {
    let model = Model::new(String::from("llama"));
    let (worker_set, sender) = make_worker_set_with_count("ns1", "abc", vec![1, 2]);
    model
        .add_worker_set(String::from("ns1"), worker_set)
        .unwrap();
    assert_eq!(model.total_workers(), 2);

    // One worker leaves.
    sender.send(vec![1]).unwrap();
    assert_eq!(model.total_workers(), 1);

    // The last worker leaves.
    sender.send(vec![]).unwrap();
    assert_eq!(model.total_workers(), 0);
}
#[test]
fn test_zero_worker_single_set_filtered() {
    // A lone WorkerSet with 0 workers must be rejected by select_worker_set_with:
    // the zero-worker check runs before the extract closure, so engine accessors
    // fail even though the set is registered.
    let model = Model::new(String::from("llama"));
    let (empty_set, _sender) = make_worker_set_with_count("ns1", "abc", vec![]);
    model
        .add_worker_set(String::from("ns1"), empty_set)
        .unwrap();

    assert!(model.get_chat_engine().is_err());
    assert!(model.get_completions_engine().is_err());
}
#[test]
fn test_zero_worker_multi_set_filtered() {
    // In the multi-set path only sets with workers > 0 take part in selection.
    let model = Model::new(String::from("llama"));
    let (empty_a, _sender_a) = make_worker_set_with_count("ns1", "abc", vec![]);
    let (empty_b, _sender_b) = make_worker_set_with_count("ns2", "abc", vec![]);
    model.add_worker_set(String::from("ns1"), empty_a).unwrap();
    model.add_worker_set(String::from("ns2"), empty_b).unwrap();

    // Neither set has workers, so nothing is eligible and the accessor fails.
    assert!(model.get_chat_engine().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use std::{collections::HashSet, sync::Arc};
use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::RwLock;
use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, RuntimeConfigWatch, runtime_config_watch};
use super::{KvWorkerMonitor, Model, RuntimeConfigWatch, WorkerSet, runtime_config_watch};
use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type},
component::{Endpoint, build_transport_type},
discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider,
protocols::EndpointId,
......@@ -27,7 +23,6 @@ use crate::{
},
local_model::runtime_config::DisaggregatedEndpoint,
model_card::ModelDeploymentCard,
model_type::ModelType,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
......@@ -54,31 +49,34 @@ pub enum ModelManagerError {
#[error("Model already exists: {0}")]
ModelAlreadyExists(String),
#[error(
"Checksum mismatch for model {model}: expected {expected}, got {got}. All WorkerSets of a model must share the same checksum. Drain all old workers before deploying a new version."
)]
ChecksumMismatch {
model: String,
expected: String,
got: String,
},
}
/// Central manager for model engines, routing, and configuration.
///
/// Manages model lifecycle including engines, KV routers, prefill coordination,
/// and per-model busy thresholds for load-based request rejection.
/// Models are stored hierarchically: ModelManager → Model → WorkerSet.
/// Each WorkerSet owns a complete pipeline built from its specific configuration.
///
/// Note: Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager {
// We read a lot and write rarely, so these three are RwLock
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>,
videos_engines: RwLock<ModelEngines<OpenAIVideosStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
/// Model name → Model (which contains WorkerSets with engines)
models: DashMap<String, Arc<Model>>,
/// Per-instance model cards, keyed by instance path. Used for cleanup on worker removal.
cards: DashMap<String, ModelDeploymentCard>,
kv_choosers: DashMap<EndpointId, Arc<KvRouter>>,
/// Prefill router activation rendezvous, keyed by "model_name:namespace".
prefill_router_activators: DashMap<String, PrefillActivationState>,
// Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler
worker_monitors: DashMap<String, KvWorkerMonitor>,
/// Per-endpoint runtime config watchers. Keyed by EndpointId (includes namespace).
runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
}
......@@ -91,140 +89,324 @@ impl Default for ModelManager {
impl ModelManager {
pub fn new() -> Self {
Self {
completion_engines: RwLock::new(ModelEngines::default()),
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
images_engines: RwLock::new(ModelEngines::default()),
videos_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
models: DashMap::new(),
cards: DashMap::new(),
kv_choosers: DashMap::new(),
prefill_router_activators: DashMap::new(),
worker_monitors: DashMap::new(),
runtime_configs: DashMap::new(),
}
}
pub fn is_valid_checksum(
// -- Model access --
/// Return the Model registered under `model_name`, inserting a new empty one first
/// when none exists.
pub fn get_or_create_model(&self, model_name: &str) -> Arc<Model> {
    let slot = self
        .models
        .entry(model_name.to_string())
        .or_insert_with(|| Arc::new(Model::new(model_name.to_string())));
    Arc::clone(slot.value())
}
/// Return a handle to the Model registered under `model_name`, or `None` when there is none.
pub fn get_model(&self, model_name: &str) -> Option<Arc<Model>> {
    let entry = self.models.get(model_name)?;
    Some(Arc::clone(entry.value()))
}
/// Drop the Model registered under `model_name` when it has no WorkerSets left.
///
/// The emptiness check and the removal happen in one `remove_if` call, which avoids
/// a TOCTOU race between checking is_empty and removing.
pub fn remove_model_if_empty(&self, model_name: &str) {
    let removed = self
        .models
        .remove_if(model_name, |_, model| model.is_empty());
    if removed.is_some() {
        tracing::info!(model_name, "Removed empty model from manager");
    }
}
/// Add a WorkerSet to a Model. Creates the Model if it doesn't exist.
/// Returns `Err` if the WorkerSet's checksum doesn't match the model's canonical checksum.
pub fn add_worker_set(
&self,
model_type: ModelType,
model_name: &str,
candidate_checksum: &str,
) -> Option<bool> {
let mut results = vec![];
for unit in model_type.units() {
let maybe_valid_checksum = match unit {
ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
ModelType::Images => self.images_engines.read().checksum(model_name),
ModelType::Videos => self.videos_engines.read().checksum(model_name),
ModelType::Prefill => self.prefill_engines.read().checksum(model_name),
_ => {
continue;
}
};
if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
tracing::debug!(
model_name,
valid_checksum,
candidate_checksum,
"is_valid_checksum: check case"
);
valid_checksum == candidate_checksum
}) {
results.push(is_valid)
}
}
if results.is_empty() {
None
} else {
// The checksum is valid if it is correct for all the ModelType in the bitflag.
Some(results.into_iter().all(|x| x))
}
namespace: &str,
worker_set: WorkerSet,
) -> Result<(), ModelManagerError> {
let model = self.get_or_create_model(model_name);
model.add_worker_set(namespace.to_string(), Arc::new(worker_set))
}
/// Remove the WorkerSet registered under `namespace` from the Model `model_name`,
/// then remove the Model itself if that left it empty.
///
/// Returns `None` when the model is unknown or has no WorkerSet for `namespace`.
pub fn remove_worker_set(&self, model_name: &str, namespace: &str) -> Option<Arc<WorkerSet>> {
    let removed = {
        // The map reference is confined to this scope so it is released before
        // `remove_model_if_empty` touches the same map.
        let model = self.models.get(model_name)?;
        model.remove_worker_set(namespace)
    };
    self.remove_model_if_empty(model_name);
    removed
}
// -- Checksum validation --
/// Compare `candidate_checksum` against the canonical checksum of `model_name`.
///
/// Returns `None` when the model does not exist or has no canonical checksum yet,
/// otherwise `Some(true)` on a match and `Some(false)` on a mismatch.
pub fn is_valid_checksum(&self, model_name: &str, candidate_checksum: &str) -> Option<bool> {
    self.models
        .get(model_name)
        .and_then(|model| model.is_valid_checksum(candidate_checksum))
}
// -- Model cards --
/// Return a cloned snapshot of every saved ModelDeploymentCard.
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
    let mut snapshot = Vec::new();
    for entry in self.cards.iter() {
        snapshot.push(entry.value().clone());
    }
    snapshot
}
/// Store `card` under the instance's `key` so it can still be fetched after the key
/// has been deleted.
///
/// Always returns `Ok(())`.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
    let owned_key = key.to_owned();
    self.cards.insert(owned_key, card);
    Ok(())
}
/// Take the model card saved under this instance's `key` out of the registry.
/// Used when the instance stops. Returns `None` if no card was saved for `key`.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
    let (_, card) = self.cards.remove(key)?;
    Some(card)
}
// -- Engine accessors (delegate through Model → WorkerSet) --
/// Check if a decode model (chat or completions) is registered
pub fn has_decode_model(&self, model: &str) -> bool {
self.chat_completion_engines.read().contains(model)
|| self.completion_engines.read().contains(model)
self.models
.get(model)
.is_some_and(|m| m.has_decode_engine())
}
/// Check if a prefill model is registered
pub fn has_prefill_model(&self, model: &str) -> bool {
self.prefill_engines.read().contains(model)
self.models.get(model).is_some_and(|m| m.has_prefill())
}
/// Report whether `model` is registered in any role (decode or prefill).
/// Note: registration skip-checks should call has_decode_model() or has_prefill_model()
/// directly instead of this combined check.
pub fn has_model_any(&self, model: &str) -> bool {
    if self.has_decode_model(model) {
        return true;
    }
    self.has_prefill_model(model)
}
pub fn model_display_names(&self) -> HashSet<String> {
self.list_chat_completions_models()
.into_iter()
.chain(self.list_completions_models())
.chain(self.list_embeddings_models())
.chain(self.list_images_models())
.chain(self.list_videos_models())
.chain(self.list_tensor_models())
.chain(self.list_prefill_models())
.collect()
let mut names = HashSet::new();
for entry in self.models.iter() {
let model = entry.value();
if model.has_chat_engine()
|| model.has_completions_engine()
|| model.has_embeddings_engine()
|| model.has_images_engine()
|| model.has_tensor_engine()
|| model.has_videos_engine()
|| model.has_prefill()
{
names.insert(entry.key().clone());
}
}
names
}
pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().list()
self.models
.iter()
.filter(|entry| entry.value().has_chat_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().list()
self.models
.iter()
.filter(|entry| entry.value().has_completions_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().list()
self.models
.iter()
.filter(|entry| entry.value().has_embeddings_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_tensor_models(&self) -> Vec<String> {
self.tensor_engines.read().list()
self.models
.iter()
.filter(|entry| entry.value().has_tensor_engine())
.map(|entry| entry.key().clone())
.collect()
}
/// Names of every registered model that reports an images engine.
pub fn list_images_models(&self) -> Vec<String> {
    let mut names = Vec::new();
    for entry in self.models.iter() {
        if entry.value().has_images_engine() {
            names.push(entry.key().clone());
        }
    }
    names
}
/// Names of every registered model that reports a videos engine.
pub fn list_videos_models(&self) -> Vec<String> {
    let mut names = Vec::new();
    for entry in self.models.iter() {
        if entry.value().has_videos_engine() {
            names.push(entry.key().clone());
        }
    }
    names
}
pub fn list_prefill_models(&self) -> Vec<String> {
self.prefill_engines.read().list()
self.models
.iter()
.filter(|entry| entry.value().has_prefill())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_images_models(&self) -> Vec<String> {
self.images_engines.read().list()
pub fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_embeddings_engine()
}
pub fn list_videos_models(&self) -> Vec<String> {
self.videos_engines.read().list()
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_completions_engine()
}
pub fn add_completions_model(
pub fn get_chat_completions_engine(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.add(model, card_checksum, engine)
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_chat_engine()
}
/// Fetch the tensor engine of `model`, or `ModelNotFound` if the model is not registered.
pub fn get_tensor_engine(
    &self,
    model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
    match self.models.get(model) {
        Some(entry) => entry.get_tensor_engine(),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
/// Fetch the images engine of `model`, or `ModelNotFound` if the model is not registered.
pub fn get_images_engine(
    &self,
    model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
    match self.models.get(model) {
        Some(entry) => entry.get_images_engine(),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
/// Fetch the videos engine of `model`, or `ModelNotFound` if the model is not registered.
pub fn get_videos_engine(
    &self,
    model: &str,
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
    match self.models.get(model) {
        Some(entry) => entry.get_videos_engine(),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
/// Fetch the chat engine of `model` together with its parsing options in one call to
/// the model's `get_chat_engine_with_parsing()`.
///
/// Returns `ModelNotFound` if `model` is not registered.
pub fn get_chat_completions_engine_with_parsing(
    &self,
    model: &str,
) -> Result<
    (
        OpenAIChatCompletionsStreamingEngine,
        crate::protocols::openai::ParsingOptions,
    ),
    ModelManagerError,
> {
    match self.models.get(model) {
        Some(entry) => entry.get_chat_engine_with_parsing(),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
/// Fetch the completions engine of `model` together with its parsing options in one call
/// to the model's `get_completions_engine_with_parsing()`.
///
/// Returns `ModelNotFound` if `model` is not registered.
pub fn get_completions_engine_with_parsing(
    &self,
    model: &str,
) -> Result<
    (
        OpenAICompletionsStreamingEngine,
        crate::protocols::openai::ParsingOptions,
    ),
    ModelManagerError,
> {
    match self.models.get(model) {
        Some(entry) => entry.get_completions_engine_with_parsing(),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
// -- Convenience methods for in-process models (http.rs, grpc.rs) --
// These create a WorkerSet with a default namespace for local models.
// TODO: These methods use ModelDeploymentCard::default() for the WorkerSet, which means
// parsing_options() returns defaults (no tool_call_parser/reasoning_parser). Pass the real
// MDC from callers so ParsingOptions reflect the model's actual configuration.
/// Register an in-process chat completions engine for `model`.
///
/// Wraps `engine` in a synthetic WorkerSet whose namespace is `__local_chat_<model>`,
/// carrying `card_checksum` and a default ModelDeploymentCard.
///
/// # Errors
/// Returns `ModelManagerError::ModelAlreadyExists` when the model already reports a chat
/// engine, and propagates any error from attaching the WorkerSet to the model.
pub fn add_chat_completions_model(
    &self,
    model: &str,
    card_checksum: &str,
    engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
    let model_entry = self.get_or_create_model(model);
    if model_entry.has_chat_engine() {
        return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
    }
    // The matching remove_chat_completions_model rebuilds this exact namespace string.
    let namespace = format!("__local_chat_{}", model);
    let mut ws = WorkerSet::new(
        namespace.clone(),
        card_checksum.to_string(),
        ModelDeploymentCard::default(),
    );
    ws.chat_engine = Some(engine);
    model_entry.add_worker_set(namespace, Arc::new(ws))?;
    Ok(())
}
/// Register an in-process completions engine for `model`.
///
/// The engine is placed in a synthetic WorkerSet named `__local_completions_<model>`,
/// built with `card_checksum` and a default ModelDeploymentCard. Fails with
/// `ModelAlreadyExists` when the model already reports a completions engine.
pub fn add_completions_model(
    &self,
    model: &str,
    card_checksum: &str,
    engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
    let entry = self.get_or_create_model(model);
    if entry.has_completions_engine() {
        return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
    }

    let namespace = format!("__local_completions_{}", model);
    let checksum = card_checksum.to_string();
    let mut worker_set = WorkerSet::new(namespace.clone(), checksum, ModelDeploymentCard::default());
    worker_set.completions_engine = Some(engine);

    entry.add_worker_set(namespace, Arc::new(worker_set))
}
pub fn add_embeddings_model(
......@@ -233,8 +415,19 @@ impl ModelManager {
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.add(model, card_checksum, engine)
let model_entry = self.get_or_create_model(model);
if model_entry.has_embeddings_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_embeddings_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.embeddings_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn add_tensor_model(
......@@ -243,8 +436,19 @@ impl ModelManager {
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, card_checksum, engine)
let model_entry = self.get_or_create_model(model);
if model_entry.has_tensor_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_tensor_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.tensor_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn add_images_model(
......@@ -253,8 +457,19 @@ impl ModelManager {
card_checksum: &str,
engine: OpenAIImagesStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.add(model, card_checksum, engine)
let model_entry = self.get_or_create_model(model);
if model_entry.has_images_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_images_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.images_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn add_videos_model(
......@@ -263,8 +478,19 @@ impl ModelManager {
card_checksum: &str,
engine: OpenAIVideosStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.videos_engines.write();
clients.add(model, card_checksum, engine)
let model_entry = self.get_or_create_model(model);
if model_entry.has_videos_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_videos_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.videos_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn add_prefill_model(
......@@ -272,122 +498,74 @@ impl ModelManager {
model: &str,
card_checksum: &str,
) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.add(model, card_checksum, ())
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.remove(model)
}
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write();
clients.remove(model)
}
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.remove(model)
let model_entry = self.get_or_create_model(model);
if model_entry.has_prefill() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_prefill_{}", model);
let ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.remove(model)
}
// -- Model removal --
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write();
clients.remove(model)
/// Remove a model entirely (all its WorkerSets).
/// Returns the removed Model, or None if not found.
pub fn remove_model(&self, model: &str) -> Option<Arc<Model>> {
self.models.remove(model).map(|(_, m)| m)
}
pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.videos_engines.write();
clients.remove(model)
}
// Per-type remove methods for in-process models (used by Python bindings).
// These remove the specific synthetic WorkerSet created by the corresponding add_*_model method.
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write();
clients.remove(model)
}
pub fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
/// Remove the synthetic `__local_chat_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_chat_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
pub fn get_chat_completions_engine(
&self,
model: &str,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.chat_completion_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
/// Remove the synthetic `__local_completions_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_completions_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
pub fn get_tensor_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.tensor_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
/// Remove the synthetic `__local_tensor_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_tensor_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
pub fn get_images_engine(
&self,
model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.images_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
/// Remove the synthetic `__local_embeddings_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_embeddings_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
pub fn get_videos_engine(
&self,
model: &str,
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.videos_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
/// Remove the synthetic `__local_images_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_images_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.insert(key.to_string(), card);
Ok(())
/// Remove the synthetic `__local_videos_<model>` WorkerSet of `model`.
/// Fails with `ModelNotFound` when nothing was removed.
pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
    let namespace = format!("__local_videos_{}", model);
    match self.remove_worker_set(model, &namespace) {
        Some(_) => Ok(()),
        None => Err(ModelManagerError::ModelNotFound(model.to_string())),
    }
}
/// Remove and return model card for this instance's key. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.remove(key).map(|(_, v)| v)
}
// -- KV Router creation --
pub async fn kv_chooser_for(
&self,
......@@ -396,25 +574,9 @@ impl ModelManager {
kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str,
) -> anyhow::Result<Arc<KvRouter>> {
let endpoint_id = endpoint.id();
if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
endpoint = %endpoint_id,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
return Ok(kv_chooser);
}
let client = endpoint.client().await?;
// Register router via discovery mechanism
// Register router via discovery mechanism.
let discovery = endpoint.component().drt().discovery();
let instance_id = discovery.instance_id();
......@@ -433,7 +595,7 @@ impl ModelManager {
discovery.register(discovery_spec).await?;
// Get or create runtime config watcher for this endpoint
// Get or create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
......@@ -447,28 +609,35 @@ impl ModelManager {
worker_type,
)
.await?;
let new_kv_chooser = Arc::new(chooser);
self.kv_choosers.insert(endpoint_id, new_kv_chooser.clone());
Ok(new_kv_chooser)
Ok(Arc::new(chooser))
}
fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> {
self.kv_choosers.get(id).map(|r| r.value().clone())
// -- Prefill router coordination --
// Keyed by "model_name:namespace" so each namespace's decode WorkerSet gets its own
// prefill router activated by same-namespace prefill workers.
/// Compose the lookup key for a (model, namespace) pair as `<model_name>:<namespace>`.
/// Shared by the prefill router activators and the registration guards.
pub(crate) fn model_namespace_key(model_name: &str, namespace: &str) -> String {
    [model_name, namespace].join(":")
}
/// Register a prefill router for a decode model. Returns a receiver that will be
/// activated when the corresponding prefill model is discovered.
/// Returns None if the decode model was already registered.
/// Register a prefill router for a decode WorkerSet. Returns a receiver that will be
/// activated when the corresponding prefill model in the same namespace is discovered.
/// Returns None if a decode WorkerSet in this namespace was already registered.
pub fn register_prefill_router(
&self,
model_name: String,
model_name: &str,
namespace: &str,
) -> Option<oneshot::Receiver<Endpoint>> {
match self.prefill_router_activators.remove(&model_name) {
let key = Self::model_namespace_key(model_name, namespace);
match self.prefill_router_activators.remove(&key) {
Some((_, PrefillActivationState::PrefillReady(rx))) => {
// Prefill endpoint already arrived - rx will immediately resolve
tracing::debug!(
model_name = %model_name,
"Prefill endpoint already available, returning receiver with endpoint"
namespace = %namespace,
"Prefill endpoint already available for namespace, returning receiver"
);
Some(rx)
}
......@@ -476,7 +645,8 @@ impl ModelManager {
// Decode already registered - this shouldn't happen, restore state and return None
tracing::error!(
model_name = %model_name,
"Decode model already registered for this prefill router"
namespace = %namespace,
"Decode WorkerSet already registered for this prefill router"
);
self.prefill_router_activators
.insert(key, PrefillActivationState::DecodeWaiting(tx));
......@@ -485,13 +655,12 @@ impl ModelManager {
None => {
// New registration: create tx/rx pair, store sender and return receiver
let (tx, rx) = oneshot::channel();
self.prefill_router_activators.insert(
model_name.clone(),
PrefillActivationState::DecodeWaiting(tx),
);
self.prefill_router_activators
.insert(key, PrefillActivationState::DecodeWaiting(tx));
tracing::debug!(
model_name = %model_name,
"No prefill endpoint available yet, storing sender for future activation"
namespace = %namespace,
"No prefill endpoint for namespace yet, storing sender for future activation"
);
Some(rx)
}
......@@ -499,115 +668,107 @@ impl ModelManager {
}
/// Activate a prefill router by sending the endpoint through the oneshot channel.
/// If no decode model has registered yet, stores the endpoint for future retrieval.
/// The namespace must match the decode WorkerSet's namespace.
pub fn activate_prefill_router(
&self,
model_name: &str,
namespace: &str,
endpoint: Endpoint,
) -> anyhow::Result<()> {
match self.prefill_router_activators.remove(model_name) {
let key = Self::model_namespace_key(model_name, namespace);
match self.prefill_router_activators.remove(&key) {
Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
// Decode model already registered
sender.send(endpoint).map_err(|_| {
anyhow::anyhow!(
"Failed to send endpoint to prefill router activator for model: {}",
model_name
"Failed to send endpoint to prefill router activator for {}:{}",
model_name,
namespace
)
})?;
tracing::info!(
model_name = %model_name,
"Activated prefill router for already-registered decode model"
namespace = %namespace,
"Activated prefill router for decode WorkerSet"
);
Ok(())
}
Some((_, PrefillActivationState::PrefillReady(_))) => {
// Prefill already activated - this shouldn't happen
anyhow::bail!("Prefill router for model {} already activated", model_name);
anyhow::bail!(
"Prefill router for {}:{} already activated",
model_name,
namespace
);
}
None => {
// Decode model not registered yet - create pair and immediately send endpoint
let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| {
anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name)
anyhow::anyhow!(
"Failed to send endpoint for prefill model {}:{}",
model_name,
namespace
)
})?;
// Store the receiver for when decode model registers
self.prefill_router_activators.insert(
model_name.to_string(),
PrefillActivationState::PrefillReady(rx),
);
self.prefill_router_activators
.insert(key, PrefillActivationState::PrefillReady(rx));
tracing::info!(
model_name = %model_name,
"Stored prefill endpoint for future decode model registration"
namespace = %namespace,
"Stored prefill endpoint for future decode WorkerSet registration"
);
Ok(())
}
}
}
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
self.cards
.iter()
.find(|r| r.value().display_name == model)
.and_then(|r| r.value().runtime_config.tool_call_parser.clone())
}
pub fn get_model_reasoning_parser(&self, model: &str) -> Option<String> {
self.cards
.iter()
.find(|r| r.value().display_name == model)
.and_then(|r| r.value().runtime_config.reasoning_parser.clone())
/// Drop the prefill router activator stored for a (model, namespace) pair, if any.
/// Intended for WorkerSet removal so that no stale activator is left behind.
pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
    let key = Self::model_namespace_key(model_name, namespace);
    let was_present = self.prefill_router_activators.remove(&key).is_some();
    if !was_present {
        return;
    }
    tracing::debug!(
        model_name = %model_name,
        namespace = %namespace,
        "Cleaned up prefill router activator for removed WorkerSet"
    );
}
/// Creates parsing options with tool call parser and reasoning parser for the specified model.
pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
let tool_call_parser = self.get_model_tool_call_parser(model);
let reasoning_parser = self.get_model_reasoning_parser(model);
crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
}
// -- Worker monitoring --
/// Gets or sets the load threshold config for a model's worker monitor.
/// Pass `Some(config)` to update, `None` to get. Returns `None` if no monitor exists.
/// Checks across all WorkerSets for the model.
///
/// Also returns `None` when `model` itself is not registered; otherwise the call is
/// delegated to the model's own `load_threshold_config`.
pub fn load_threshold_config(
    &self,
    model: &str,
    config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> {
    let model_entry = self.models.get(model)?;
    model_entry.load_threshold_config(config)
}
/// Gets an existing worker monitor for a model, if one exists.
pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> {
self.worker_monitors.get(model).map(|m| m.clone())
}
/// Gets or creates a worker monitor for a model. Updates thresholds if monitor exists.
pub fn get_or_create_worker_monitor(
/// Gets an existing worker monitor for a specific namespace of a model.
pub fn get_worker_monitor_for_namespace(
&self,
model: &str,
client: Client,
config: LoadThresholdConfig,
) -> KvWorkerMonitor {
if let Some(existing) = self.worker_monitors.get(model) {
existing.set_load_threshold_config(&config);
return existing.clone();
namespace: &str,
) -> Option<KvWorkerMonitor> {
let model_entry = self.models.get(model)?;
model_entry.get_worker_monitor_for_namespace(namespace)
}
/// Lists all models with worker monitors configured.
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
let mut result = Vec::new();
for entry in self.models.iter() {
if let Some(config) = entry.value().load_threshold_config(None) {
result.push((entry.key().clone(), config));
}
}
let monitor = KvWorkerMonitor::new(client, config);
self.worker_monitors
.insert(model.to_string(), monitor.clone());
monitor
result
}
// -- Runtime configs --
/// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task that joins instance availability and config discovery.
/// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`.
......@@ -617,7 +778,6 @@ impl ModelManager {
) -> anyhow::Result<RuntimeConfigWatch> {
let endpoint_id = endpoint.id();
// Fast path: return existing if present
if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
return Ok(existing.clone());
}
......@@ -638,7 +798,6 @@ impl ModelManager {
}
/// Get disaggregated endpoint for a specific worker.
/// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
pub fn get_disaggregated_endpoint(
&self,
endpoint_id: &EndpointId,
......@@ -648,79 +807,348 @@ impl ModelManager {
let configs = rx.borrow();
configs.get(&worker_id)?.disaggregated_endpoint.clone()
}
}
/// Lists all models with worker monitors configured.
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
self.worker_monitors
.iter()
.map(|entry| (entry.key().clone(), entry.value().load_threshold_config()))
.collect()
#[cfg(test)]
mod tests {
use super::*;
use crate::model_card::ModelDeploymentCard;
/// Test helper: build a WorkerSet for `namespace` with checksum `mdcsum` and a default
/// ModelDeploymentCard. No engine fields are assigned here.
fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
    let card = ModelDeploymentCard::default();
    WorkerSet::new(namespace.to_string(), mdcsum.to_string(), card)
}
}
pub struct ModelEngines<E> {
/// Optional default model name
default: Option<String>,
engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
}
// -- CRUD delegation tests --
impl<E> Default for ModelEngines<E> {
fn default() -> Self {
Self {
default: None,
engines: HashMap::new(),
checksums: HashMap::new(),
}
/// Adding a WorkerSet makes the model retrievable with exactly that one set attached.
#[test]
fn test_add_and_get_worker_set() {
    let manager = ModelManager::new();
    manager
        .add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
        .unwrap();

    let registered = manager.get_model("llama");
    assert!(registered.is_some());
    let registered = registered.unwrap();
    assert!(registered.has_worker_set("ns1"));
    assert_eq!(registered.worker_set_count(), 1);
}
}
impl<E> ModelEngines<E> {
#[allow(dead_code)]
fn set_default(&mut self, model: &str) {
self.default = Some(model.to_string());
/// The model entry does not exist up front and is created by the first add_worker_set.
#[test]
fn test_add_worker_set_creates_model() {
    let manager = ModelManager::new();
    assert!(manager.get_model("llama").is_none());

    let worker_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();
    assert!(manager.get_model("llama").is_some());
}
#[allow(dead_code)]
fn clear_default(&mut self) {
self.default = None;
/// Removing the only WorkerSet returns it and also drops the now-empty model.
#[test]
fn test_remove_worker_set_removes_empty_model() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();
    assert!(manager.get_model("llama").is_some());

    let detached = manager.remove_worker_set("llama", "ns1");
    assert!(detached.is_some());
    assert_eq!(detached.unwrap().namespace(), "ns1");

    // With no WorkerSets left the model entry is cleaned up automatically.
    assert!(manager.get_model("llama").is_none());
}
fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
if self.engines.contains_key(model) {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
self.engines.insert(model.to_string(), engine);
self.checksums
.insert(model.to_string(), checksum.to_string());
Ok(())
/// Removing one of two WorkerSets leaves the model registered with the other one.
#[test]
fn test_remove_worker_set_keeps_model_with_remaining() {
    let manager = ModelManager::new();
    for namespace in ["ns1", "ns2"] {
        manager
            .add_worker_set("llama", namespace, make_worker_set(namespace, "abc"))
            .unwrap();
    }

    manager.remove_worker_set("llama", "ns1");

    // ns2 keeps the model alive.
    let remaining = manager.get_model("llama").unwrap();
    assert!(!remaining.has_worker_set("ns1"));
    assert!(remaining.has_worker_set("ns2"));
    assert_eq!(remaining.worker_set_count(), 1);
}
fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> {
if self.engines.remove(model).is_none() {
return Err(ModelManagerError::ModelNotFound(model.to_string()));
}
let _ = self.checksums.remove(model);
Ok(())
/// Removing from a model that was never registered yields None.
#[test]
fn test_remove_worker_set_nonexistent_model() {
    let manager = ModelManager::new();
    let detached = manager.remove_worker_set("llama", "ns1");
    assert!(detached.is_none());
}
/// Removing an unknown namespace yields None and leaves the existing set untouched.
#[test]
fn test_remove_worker_set_nonexistent_namespace() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    let detached = manager.remove_worker_set("llama", "ns2");
    assert!(detached.is_none());
    // ns1 is still attached, so the model must survive.
    assert!(manager.get_model("llama").is_some());
}
/// remove_model_if_empty must not touch a model that still has a WorkerSet.
#[test]
fn test_remove_model_if_empty_noop_when_not_empty() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    manager.remove_model_if_empty("llama");
    // ns1 is still attached.
    assert!(manager.get_model("llama").is_some());
}
fn get(&self, model: &str) -> Option<&E> {
self.engines.get(model)
/// Calling remove_model_if_empty for an unknown model is a harmless no-op.
#[test]
fn test_remove_model_if_empty_noop_when_missing() {
    let manager = ModelManager::new();
    // Passing means this call returned without panicking.
    manager.remove_model_if_empty("nonexistent");
}
fn contains(&self, model: &str) -> bool {
self.engines.contains_key(model)
/// remove_model drops the whole model even when several WorkerSets are attached.
#[test]
fn test_remove_model() {
    let manager = ModelManager::new();
    for namespace in ["ns1", "ns2"] {
        manager
            .add_worker_set("llama", namespace, make_worker_set(namespace, "abc"))
            .unwrap();
    }

    let dropped = manager.remove_model("llama");
    assert!(dropped.is_some());
    assert!(manager.get_model("llama").is_none());
}
/// Two get_or_create_model calls for the same name hand back the same Arc.
#[test]
fn test_get_or_create_model_idempotent() {
    let manager = ModelManager::new();
    let first = manager.get_or_create_model("llama");
    let second = manager.get_or_create_model("llama");
    // Pointer equality proves no second Model was allocated.
    assert!(Arc::ptr_eq(&first, &second));
}
// -- Checksum validation tests --
/// A candidate equal to the registered checksum is reported as valid.
#[test]
fn test_is_valid_checksum_match() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc123");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    let verdict = manager.is_valid_checksum("llama", "abc123");
    assert_eq!(verdict, Some(true));
}
/// A candidate differing from the registered checksum is reported as invalid.
#[test]
fn test_is_valid_checksum_mismatch() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc123");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    let verdict = manager.is_valid_checksum("llama", "wrong");
    assert_eq!(verdict, Some(false));
}
/// Once a WorkerSet has been added, candidates are checked against its checksum.
// NOTE(review): despite the test name, a canonical checksum IS set by the time the
// assertions run; the "no canonical yet" state is not exercised here.
#[test]
fn test_is_valid_checksum_no_canonical_yet() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "abc123");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    // Canonical is set, so even for a "new namespace" scenario the checksum is checked
    let matching = manager.is_valid_checksum("llama", "abc123");
    assert_eq!(matching, Some(true));
    let differing = manager.is_valid_checksum("llama", "xyz");
    assert_eq!(differing, Some(false));
}
/// An unknown model has no verdict at all.
#[test]
fn test_is_valid_checksum_missing_model() {
    let manager = ModelManager::new();
    let verdict = manager.is_valid_checksum("nonexistent", "abc");
    assert_eq!(verdict, None);
}
#[test]
fn test_is_valid_checksum_cross_namespace_enforcement() {
    let manager = ModelManager::new();
    let worker_set = make_worker_set("ns1", "checksum_a");
    manager.add_worker_set("llama", "ns1", worker_set).unwrap();

    // The check takes only (model, checksum), so a worker from any other namespace
    // that brings a different checksum is rejected at the model level...
    let foreign = manager.is_valid_checksum("llama", "checksum_b");
    assert_eq!(foreign, Some(false));

    // ...while one that brings the same checksum is accepted.
    let matching = manager.is_valid_checksum("llama", "checksum_a");
    assert_eq!(matching, Some(true));
}
// -- Model listing and filtering tests --
#[test]
fn test_has_decode_model() {
    let manager = ModelManager::new();

    // Unknown model: there is nothing to decode with.
    let before_registration = manager.has_decode_model("llama");
    assert!(!before_registration);

    // Registering an engine-less (prefill-only) WorkerSet must not flip the answer.
    let prefill_only = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", prefill_only).unwrap();
    let after_prefill_only = manager.has_decode_model("llama");
    assert!(!after_prefill_only);
}
#[test]
fn test_has_prefill_model() {
    let manager = ModelManager::new();

    // A WorkerSet without any engine counts as a prefill set.
    let engine_less = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", engine_less).unwrap();

    let has_prefill = manager.has_prefill_model("llama");
    assert!(has_prefill);
}
#[test]
fn test_has_model_any() {
    let manager = ModelManager::new();

    // Nothing registered yet.
    let known_before = manager.has_model_any("llama");
    assert!(!known_before);

    // A prefill-only registration is enough for the model to count as present.
    let engine_less = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", engine_less).unwrap();
    let known_after = manager.has_model_any("llama");
    assert!(known_after);
}
#[test]
fn test_model_display_names_includes_prefill() {
    // A model that only has an engine-less (prefill) WorkerSet must still be listed.
    let manager = ModelManager::new();
    let engine_less = make_worker_set("ns1", "abc");
    manager.add_worker_set("llama", "ns1", engine_less).unwrap();

    let display_names = manager.model_display_names();
    assert!(display_names.contains("llama"));
}
#[test]
fn test_model_display_names_empty() {
    // A freshly created manager has no names to report.
    let manager = ModelManager::new();
    let display_names = manager.model_display_names();
    assert!(display_names.is_empty());
}
#[test]
fn test_list_prefill_models() {
    let manager = ModelManager::new();

    // Two different models, each with one engine-less WorkerSet in the same namespace.
    for (model, checksum) in [("llama", "abc"), ("gpt", "def")] {
        manager
            .add_worker_set(model, "ns1", make_worker_set("ns1", checksum))
            .unwrap();
    }

    // Both must show up in the prefill listing, and nothing else.
    let prefill = manager.list_prefill_models();
    assert_eq!(prefill.len(), 2);
    for expected in ["llama", "gpt"] {
        assert!(prefill.contains(&expected.to_string()));
    }
}
// -- Model card tests --
#[test]
fn test_save_and_remove_model_card() {
    let manager = ModelManager::new();
    let key = "instance/key/1";
    let card = ModelDeploymentCard::default();

    // Saving makes the card visible through the listing...
    manager.save_model_card(key, card.clone()).unwrap();
    let cards = manager.get_model_cards();
    assert_eq!(cards.len(), 1);

    // ...and removing it by the same key hands it back and empties the listing.
    let removed = manager.remove_model_card(key);
    assert!(removed.is_some());
    assert!(manager.get_model_cards().is_empty());
}
#[test]
fn test_remove_model_card_nonexistent() {
    // Removing a key that was never saved yields `None`.
    let manager = ModelManager::new();
    let removed = manager.remove_model_card("nonexistent");
    assert!(removed.is_none());
}
// -- Prefill router rendezvous tests --
// Note: activate_prefill_router requires an Endpoint (needs DistributedRuntime),
// so we test the registration state machine and cleanup only.
#[test]
fn test_prefill_router_register_new() {
    let manager = ModelManager::new();

    // The very first registration of a (model, namespace) pair yields a receiver.
    let receiver = manager.register_prefill_router("llama", "ns1");
    assert!(receiver.is_some());
}
#[test]
fn test_prefill_router_double_register_returns_none() {
    let manager = ModelManager::new();

    let first_attempt = manager.register_prefill_router("llama", "ns1");
    assert!(first_attempt.is_some());

    // Registering the same (model, namespace) pair again must be refused.
    let second_attempt = manager.register_prefill_router("llama", "ns1");
    assert!(second_attempt.is_none());
}
#[test]
fn test_prefill_router_different_namespaces_independent() {
    let manager = ModelManager::new();

    // Same model, two namespaces: neither registration may block the other.
    let in_ns1 = manager.register_prefill_router("llama", "ns1");
    let in_ns2 = manager.register_prefill_router("llama", "ns2");

    assert!(in_ns1.is_some());
    assert!(in_ns2.is_some());
}
#[test]
fn test_prefill_router_different_models_independent() {
    let manager = ModelManager::new();

    // Same namespace, two models: neither registration may block the other.
    let for_llama = manager.register_prefill_router("llama", "ns1");
    let for_gpt = manager.register_prefill_router("gpt", "ns1");

    assert!(for_llama.is_some());
    assert!(for_gpt.is_some());
}
#[test]
fn test_prefill_router_remove_allows_reregister() {
    let manager = ModelManager::new();

    let initial = manager.register_prefill_router("llama", "ns1");
    assert!(initial.is_some());

    // Dropping the activator frees the (model, namespace) slot...
    manager.remove_prefill_activator("llama", "ns1");

    // ...so a fresh registration for the same pair succeeds again.
    let after_removal = manager.register_prefill_router("llama", "ns1");
    assert!(after_removal.is_some());
}
pub fn list(&self) -> Vec<String> {
self.engines.keys().map(|k| k.to_owned()).collect()
#[test]
fn test_prefill_router_remove_nonexistent_noop() {
    // Removing an activator that was never registered must return without panicking.
    ModelManager::new().remove_prefill_activator("llama", "ns1");
}
/// Returns a newly allocated String for caller convenience: every current call
/// site needs an owned String.
pub fn checksum(&self, model: &str) -> Option<String> {
self.checksums.get(model).map(|s| s.to_string())
#[test]
fn test_model_namespace_key_format() {
    // (model, namespace, expected key): the key is the two parts joined by a colon.
    let cases = [
        ("llama", "ns1", "llama:ns1"),
        ("gpt-4", "default-abc", "gpt-4:default-abc"),
    ];

    for (model, namespace, expected) in cases {
        assert_eq!(ModelManager::model_namespace_key(model, namespace), expected);
    }
}
}
......@@ -24,7 +24,7 @@ use dynamo_runtime::{
use crate::{
backend::Backend,
discovery::WORKER_TYPE_DECODE,
discovery::{KvWorkerMonitor, WORKER_TYPE_DECODE, WorkerSet},
entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
http::service::metrics::Metrics,
kv_router::PrefillRouter,
......@@ -47,7 +47,17 @@ use crate::{
};
use super::ModelManager;
use crate::namespace::is_global_namespace;
use crate::namespace::NamespaceFilter;
/// Builds the key under which a WorkerSet is stored for a model.
///
/// Model types that support prefill get `"<namespace>:prefill"`; every other model
/// type uses the bare namespace. Keeping the two keys distinct means prefill and
/// decode workers in the same namespace do not block each other's registration.
fn worker_set_key(namespace: &str, model_type: ModelType) -> String {
    if !model_type.supports_prefill() {
        return namespace.to_string();
    }
    format!("{}:prefill", namespace)
}
#[derive(Debug, Clone)]
pub enum ModelUpdate {
......@@ -64,7 +74,8 @@ pub struct ModelWatcher {
model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
metrics: Arc<Metrics>,
registering_models: DashSet<String>,
/// Guards against concurrent pipeline construction for the same (model, namespace, type).
registering_worker_sets: DashSet<String>,
}
const ALL_MODEL_TYPES: &[ModelType] = &[
......@@ -78,6 +89,27 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Prefill,
];
/// Reports whether the manager currently lists no model for `model_type`.
///
/// `model_type` is compared by exact equality against each single known type and
/// the matching `list_*_models()` listing is consulted. Any value that equals none
/// of them (for example a type without a listing of its own) is reported as empty.
fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bool {
    if model_type == ModelType::Chat {
        return manager.list_chat_completions_models().is_empty();
    }
    if model_type == ModelType::Completions {
        return manager.list_completions_models().is_empty();
    }
    if model_type == ModelType::Embedding {
        return manager.list_embeddings_models().is_empty();
    }
    if model_type == ModelType::Images {
        return manager.list_images_models().is_empty();
    }
    if model_type == ModelType::Videos {
        return manager.list_videos_models().is_empty();
    }
    if model_type == ModelType::TensorBased {
        return manager.list_tensor_models().is_empty();
    }
    if model_type == ModelType::Prefill {
        return manager.list_prefill_models().is_empty();
    }
    // No listing exists for this type, so there is nothing that could be non-empty.
    true
}
impl ModelWatcher {
pub fn new(
runtime: DistributedRuntime,
......@@ -96,7 +128,7 @@ impl ModelWatcher {
model_update_tx: None,
chat_engine_factory,
metrics,
registering_models: DashSet::new(),
registering_worker_sets: DashSet::new(),
}
}
......@@ -119,10 +151,8 @@ impl ModelWatcher {
pub async fn watch(
&self,
mut discovery_stream: DiscoveryStream,
target_namespace: Option<&str>,
namespace_filter: NamespaceFilter,
) {
let global_namespace = target_namespace.is_none_or(is_global_namespace);
while let Some(result) = discovery_stream.next().await {
let event = match result {
Ok(event) => event,
......@@ -168,28 +198,27 @@ impl ModelWatcher {
}
};
// Filter by namespace if target_namespace is specified
if !global_namespace
&& let Some(target_ns) = target_namespace
&& mcid.namespace != target_ns
{
// Filter by namespace using the configured filter
if !namespace_filter.matches(&mcid.namespace) {
tracing::debug!(
model_namespace = mcid.namespace,
target_namespace = target_ns,
"Skipping model from different namespace"
namespace_filter = ?namespace_filter,
"Skipping model due to namespace filter"
);
continue;
}
// If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance
let can_add =
self.manager
.is_valid_checksum(card.model_type, card.name(), card.mdcsum());
// If we already have a WorkerSet for this model and the checksums
// don't match, reject the new worker. All WorkerSets of a model
// must share the same checksum.
let can_add = self.manager.is_valid_checksum(card.name(), card.mdcsum());
if can_add.is_some_and(|is_valid| !is_valid) {
tracing::error!(
model_name = card.name(),
"Checksum for new model does not match existing model."
namespace = mcid.namespace,
"Checksum for new worker does not match model's canonical checksum. \
All WorkerSets must share the same checksum. \
Drain all old workers before deploying a new version."
);
// TODO: mark that instance down in clients
......@@ -199,7 +228,6 @@ impl ModelWatcher {
// needs more testing).
// The `PushRouter` is in `ModelManager` (`self.manager` here), but inside
// interface `AsyncEngine` which only has a `generate` method.
continue;
}
......@@ -235,7 +263,7 @@ impl ModelWatcher {
};
match self
.handle_delete(model_card_instance_id, target_namespace, global_namespace)
.handle_delete(model_card_instance_id, &namespace_filter)
.await
{
Ok(Some(model_name)) => {
......@@ -253,13 +281,12 @@ impl ModelWatcher {
}
}
/// If the last instance running this model has gone delete it.
/// Returns the name of the model we just deleted, if any.
/// Handle a worker removal. Cleans up per-namespace WorkerSets and the Model itself
/// when no instances remain. Returns the model name if the entire Model was removed.
async fn handle_delete(
&self,
mcid: &ModelCardInstanceId,
target_namespace: Option<&str>,
is_global_namespace: bool,
namespace_filter: &NamespaceFilter,
) -> anyhow::Result<Option<String>> {
let key = mcid.to_path();
let card = match self.manager.remove_model_card(&key) {
......@@ -269,89 +296,55 @@ impl ModelWatcher {
}
};
let model_name = card.name().to_string();
let worker_namespace = &mcid.namespace;
let worker_component = &mcid.component;
let ws_key = worker_set_key(&mcid.namespace, card.model_type);
// Query discovery for all remaining instances of this model
let active_instances = self
.cards_for_model(&model_name, target_namespace, is_global_namespace)
.cards_for_model_with_endpoints(&model_name, namespace_filter)
.await
.with_context(|| model_name.clone())?;
// Check if instances of the SAME component remain in this namespace.
// In disaggregated deployments, prefill and decode are different components
// in the same namespace, so we must check at the component level to avoid
// removing one type's WorkerSet while the other still has workers.
let component_has_instances = active_instances.iter().any(|(eid, _)| {
eid.namespace == *worker_namespace && eid.component == *worker_component
});
if !component_has_instances {
// No more workers of this component in this namespace — remove its WorkerSet
if let Some(_removed_ws) = self.manager.remove_worker_set(&model_name, &ws_key) {
// remove_prefill_activator uses deployment namespace (not ws_key)
self.manager
.remove_prefill_activator(&model_name, worker_namespace);
tracing::info!(
model_name,
namespace = %worker_namespace,
"Removed WorkerSet (no remaining instances in namespace)"
);
}
}
// Check if the Model still has instances in any namespace
if !active_instances.is_empty() {
tracing::debug!(
model_name,
target_namespace = ?target_namespace,
active_instance_count = active_instances.len(),
"Model has other active instances, not removing"
"Model has other active instances in other namespaces"
);
return Ok(None);
}
// Ignore the errors because model could be either type
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let images_model_remove_err = self.manager.remove_images_model(&model_name);
let videos_model_remove_err = self.manager.remove_videos_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
let mut images_model_removed = false;
let mut videos_model_removed = false;
let mut tensor_model_removed = false;
let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
}
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
{
completions_model_removed = true;
}
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}
if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
images_model_removed = true;
}
if videos_model_remove_err.is_ok() && self.manager.list_videos_models().is_empty() {
videos_model_removed = true;
}
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true;
}
// No instances remain anywhere — remove the entire Model
let _ = self.manager.remove_model(&model_name);
if !chat_model_removed
&& !completions_model_removed
&& !embeddings_model_removed
&& !images_model_removed
&& !videos_model_removed
&& !tensor_model_removed
&& !prefill_model_removed
{
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, images_model_removed: {}, videos_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed,
images_model_removed,
videos_model_removed,
tensor_model_removed,
prefill_model_removed
);
} else {
if let Some(tx) = &self.model_update_tx {
for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat)
|| (completions_model_removed && *model_type == ModelType::Completions)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (images_model_removed && *model_type == ModelType::Images)
|| (videos_model_removed && *model_type == ModelType::Videos)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx
if card.model_type.intersects(*model_type)
&& is_model_type_list_empty(&self.manager, *model_type)
{
tx.send(ModelUpdate::Removed(card.clone())).await.ok();
}
......@@ -368,54 +361,52 @@ impl ModelWatcher {
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
// Check if model is already registered before downloading config.
// This prevents duplicate HuggingFace API calls when multiple workers register
// the same model.
// Prefill and decode models are tracked separately, so registering one
// doesn't block the other (they can arrive in any order).
let already_registered = if card.model_type.supports_prefill() {
self.manager.has_prefill_model(card.name())
} else {
self.manager.has_decode_model(card.name())
};
// Check if this specific (model, namespace, type) WorkerSet already exists.
// If so, this is just another worker joining an existing set — no pipeline build needed.
let model_name = card.name().to_string();
let namespace = mcid.namespace.clone();
let ws_key = worker_set_key(&namespace, card.model_type);
if already_registered {
if let Some(model) = self.manager.get_model(&model_name)
&& model.has_worker_set(&ws_key)
{
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
model_type = %card.model_type,
"Model already registered, skipping config download"
namespace = namespace,
"Worker joined existing WorkerSet, skipping pipeline build"
);
return Ok(());
}
// Use registering_models set to prevent concurrent registrations.
let model_key = card.name().to_string();
if !self.registering_models.insert(model_key.clone()) {
// Guard against concurrent pipeline construction for the same (model, namespace, type)
let registration_key = ModelManager::model_namespace_key(&model_name, &ws_key);
if !self
.registering_worker_sets
.insert(registration_key.clone())
{
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
"Model registration in progress by another worker, skipping"
namespace = namespace,
"WorkerSet registration in progress, skipping"
);
return Ok(());
}
// We acquired the registration lock. Use a helper to ensure cleanup on all exit paths.
let result = self.do_model_registration(mcid, card).await;
let result = self.do_worker_set_registration(mcid, card).await;
// Always remove from registering set, whether success or failure
self.registering_models.remove(&model_key);
// Always remove from registering set
self.registering_worker_sets.remove(&registration_key);
result
}
/// Inner function that performs the actual model registration.
/// Called by handle_put after acquiring the registration lock.
async fn do_model_registration(
/// Build a complete WorkerSet with all engines for this (model, namespace)
/// and add it to the Model.
async fn do_worker_set_registration(
&self,
mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard,
......@@ -428,7 +419,12 @@ impl ModelWatcher {
.component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model");
let instance_watcher = client.instance_avail_watcher();
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
"building worker set pipeline"
);
self.manager
.save_model_card(&mcid.to_path(), card.clone())?;
......@@ -437,6 +433,12 @@ impl ModelWatcher {
}
let checksum = card.mdcsum();
let namespace = mcid.namespace.clone();
let ws_key = worker_set_key(&namespace, card.model_type);
// Build the WorkerSet with all applicable engines
let mut worker_set = WorkerSet::new(namespace.clone(), checksum.to_string(), card.clone());
worker_set.set_instance_watcher(instance_watcher);
if card.model_input == ModelInput::Tokens
&& (card.model_type.supports_chat() || card.model_type.supports_completions())
......@@ -477,7 +479,7 @@ impl ModelWatcher {
let model_name = card.name().to_string();
let prefill_chooser = self
.manager
.register_prefill_router(model_name.clone())
.register_prefill_router(&model_name, &namespace)
.map(|rx| {
// Create prefill-specific config with track_active_blocks disabled
let mut prefill_config = self.router_config.kv_router_config;
......@@ -490,20 +492,24 @@ impl ModelWatcher {
card.kv_cache_block_size,
Some(prefill_config),
self.router_config.enforce_disagg,
model_name.clone(), // Pass model name for worker monitor lookup
model_name.clone(),
namespace.clone(),
)
});
// Get or create the worker monitor for this model.
// Always create the monitor for Prometheus metrics (active_decode_blocks, active_prefill_tokens,
// Create a new worker monitor for this WorkerSet. Each WorkerSet gets its own
// monitor (1-to-1) since each monitor is scoped to this WorkerSet's Client/namespace.
// The monitor tracks Prometheus metrics (active_decode_blocks, active_prefill_tokens,
// worker TTFT/ITL cleanup). The thresholds control busy detection behavior only.
// LoadThresholdConfig allows dynamic threshold updates via the ModelManager.
let worker_monitor = Some(self.manager.get_or_create_worker_monitor(
card.name(),
let worker_monitor = Some(KvWorkerMonitor::new(
client.clone(),
self.router_config.load_threshold_config.clone(),
));
// Store KV router and worker monitor on the WorkerSet
worker_set.kv_router = kv_chooser.clone();
worker_set.worker_monitor = worker_monitor.clone();
// Add chat engine only if the model supports chat
if card.model_type.supports_chat() {
let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
......@@ -537,9 +543,7 @@ impl ModelWatcher {
.await
.context("build_routed_pipeline")?
};
self.manager
.add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?;
worker_set.chat_engine = Some(chat_engine);
tracing::info!("Chat completions is ready");
}
......@@ -572,9 +576,7 @@ impl ModelWatcher {
)
.await
.context("build_routed_pipeline_with_preprocessor")?;
self.manager
.add_completions_model(card.name(), checksum, completions_engine)
.context("add_completions_model")?;
worker_set.completions_engine = Some(completions_engine);
tracing::info!("Completions is ready");
}
} else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
......@@ -586,9 +588,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_embeddings_model(card.name(), checksum, engine)?;
worker_set.embeddings_engine = Some(Arc::new(push_router));
}
// Case: Text + (Images, Audio, Videos)
// Must come before the plain Text+Chat / Text+Completions branches because
......@@ -599,8 +599,7 @@ impl ModelWatcher {
|| card.model_type.supports_audios()
|| card.model_type.supports_videos())
{
// Image Models can support chat completions (vllm omni way)
// So register chat_completions model as well
// Image/Audio/Video models can also support chat completions (vLLM omni way)
if card.model_type.supports_chat() {
let chat_router = PushRouter::<
NvCreateChatCompletionRequest,
......@@ -612,14 +611,9 @@ impl ModelWatcher {
None,
)
.await?;
self.manager.add_chat_completions_model(
card.name(),
checksum,
Arc::new(chat_router),
)?;
worker_set.chat_engine = Some(Arc::new(chat_router));
}
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() {
let images_router = PushRouter::<
NvCreateImageRequest,
......@@ -628,11 +622,9 @@ impl ModelWatcher {
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
worker_set.images_engine = Some(Arc::new(images_router));
}
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() {
let videos_router = PushRouter::<
NvCreateVideoRequest,
......@@ -641,8 +633,7 @@ impl ModelWatcher {
client.clone(), self.router_config.router_mode, None, None
)
.await?;
self.manager
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
worker_set.videos_engine = Some(Arc::new(videos_router));
}
// TODO: add audio models support
......@@ -655,9 +646,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_chat_completions_model(card.name(), checksum, engine)?;
worker_set.chat_engine = Some(Arc::new(push_router));
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case: Text + Completions
let push_router = PushRouter::<
......@@ -667,12 +656,9 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_completions_model(card.name(), checksum, engine)?;
worker_set.completions_engine = Some(Arc::new(push_router));
} else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
// Case 4: Tokens + Embeddings
// Create preprocessing pipeline similar to Backend
let frontend = SegmentSource::<
SingleIn<NvCreateEmbeddingRequest>,
......@@ -702,8 +688,7 @@ impl ModelWatcher {
.link(preprocessor.backward_edge())?
.link(frontend)?;
self.manager
.add_embeddings_model(card.name(), checksum, embedding_engine)?;
worker_set.embeddings_engine = Some(embedding_engine);
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 6: Tensor + TensorBased (non-LLM)
// No KV cache concepts - not an LLM model
......@@ -714,9 +699,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None
)
.await?;
let engine = Arc::new(push_router);
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
worker_set.tensor_engine = Some(Arc::new(push_router));
} else if card.model_type.supports_prefill() {
// Case 7: Prefill
// Guardrail: Verify model_input is Tokens
......@@ -732,13 +715,18 @@ impl ModelWatcher {
"Prefill model detected, registering and activating prefill router"
);
// Register prefill model for tracking (no engine needed, just lifecycle)
// Prefill sets have no engines — we add the WorkerSet first for tracking,
// then activate the prefill router.
self.manager
.add_prefill_model(card.name(), checksum)
.context("add_prefill_model")?;
.add_worker_set(card.name(), &ws_key, worker_set)?;
// Activate the prefill router with the endpoint for this prefill model
let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else {
// Note: activate_prefill_router is keyed by deployment namespace (not ws_key)
// because it coordinates between decode and prefill WorkerSets that share
// the same deployment namespace but have different ws_keys ("ns" vs "ns:prefill").
let Ok(()) = self
.manager
.activate_prefill_router(card.name(), &namespace, endpoint)
else {
tracing::warn!(
model_name = card.name(),
"Failed to activate prefill router - prefill model may already be activated"
......@@ -750,6 +738,8 @@ impl ModelWatcher {
model_name = card.name(),
"Prefill model registered and router activated successfully"
);
return Ok(());
} else {
// Reject unsupported combinations
anyhow::bail!(
......@@ -760,6 +750,10 @@ impl ModelWatcher {
);
}
// Add the completed WorkerSet to the Model
self.manager
.add_worker_set(card.name(), &ws_key, worker_set)?;
Ok(())
}
......@@ -772,7 +766,6 @@ impl ModelWatcher {
for instance in instances {
match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => {
// Extract EndpointId from the instance
let endpoint_id = match &instance {
dynamo_runtime::discovery::DiscoveryInstance::Model {
namespace,
......@@ -805,19 +798,101 @@ impl ModelWatcher {
pub async fn cards_for_model(
&self,
model_name: &str,
target_namespace: Option<&str>,
is_global_namespace: bool,
namespace_filter: &NamespaceFilter,
) -> anyhow::Result<Vec<ModelDeploymentCard>> {
Ok(self
.cards_for_model_with_endpoints(model_name, namespace_filter)
.await?
.into_iter()
.map(|(_, card)| card)
.collect())
}
/// Like `cards_for_model` but also returns the EndpointId for each card,
/// allowing callers to filter by namespace.
async fn cards_for_model_with_endpoints(
&self,
model_name: &str,
namespace_filter: &NamespaceFilter,
) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let mut all = self.all_cards().await?;
all.retain(|(endpoint_id, card)| {
let matches_name = card.name() == model_name;
let matches_namespace = match (is_global_namespace, target_namespace) {
(true, _) => true,
(false, None) => true,
(false, Some(target_ns)) => endpoint_id.namespace == target_ns,
};
let matches_namespace = namespace_filter.matches(&endpoint_id.namespace);
matches_name && matches_namespace
});
Ok(all.into_iter().map(|(_eid, card)| card).collect())
Ok(all)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::discovery::WorkerSet;
    use crate::model_card::ModelDeploymentCard;

    /// Every model type checked here that is backed by an engine, i.e. all but `Prefill`.
    const ENGINE_BACKED_TYPES: [ModelType; 6] = [
        ModelType::Chat,
        ModelType::Completions,
        ModelType::Embedding,
        ModelType::Images,
        ModelType::Videos,
        ModelType::TensorBased,
    ];

    /// Builds an engine-less WorkerSet for `namespace` with a fixed checksum and a
    /// default deployment card.
    fn make_worker_set(namespace: &str) -> WorkerSet {
        let card = ModelDeploymentCard::default();
        WorkerSet::new(namespace.to_string(), "test-checksum".to_string(), card)
    }

    /// Registers an engine-less WorkerSet for `model` under namespace "ns1".
    fn register(manager: &ModelManager, model: &str) {
        manager
            .add_worker_set(model, "ns1", make_worker_set("ns1"))
            .unwrap();
    }

    /// Asserts that the manager lists no model for any of `model_types`.
    fn assert_lists_empty(manager: &ModelManager, model_types: &[ModelType]) {
        for model_type in model_types {
            assert!(
                is_model_type_list_empty(manager, *model_type),
                "expected no models of type {model_type}"
            );
        }
    }

    #[test]
    fn test_is_model_type_list_empty_on_empty_manager() {
        let manager = ModelManager::new();
        assert_lists_empty(&manager, &ENGINE_BACKED_TYPES);
        assert_lists_empty(&manager, &[ModelType::Prefill]);
    }

    #[test]
    fn test_is_model_type_list_empty_prefill_present() {
        let manager = ModelManager::new();
        // A WorkerSet with no engines is treated as a prefill set
        register(&manager, "model-a");
        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));
        // No engine was attached, so every engine-backed listing stays empty.
        assert_lists_empty(&manager, &ENGINE_BACKED_TYPES);
    }

    #[test]
    fn test_is_model_type_list_empty_after_removal() {
        let manager = ModelManager::new();
        register(&manager, "model-a");
        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));

        manager.remove_model("model-a");
        assert!(is_model_type_list_empty(&manager, ModelType::Prefill));
    }

    #[test]
    fn test_is_model_type_list_not_empty_when_other_model_remains() {
        let manager = ModelManager::new();
        for model in ["model-a", "model-b"] {
            register(&manager, model);
        }

        // With one of the two gone, the survivor still provides prefill.
        manager.remove_model("model-a");
        assert!(!is_model_type_list_empty(&manager, ModelType::Prefill));

        // Once the last one is gone too, the listing is empty.
        manager.remove_model("model-b");
        assert!(is_model_type_list_empty(&manager, ModelType::Prefill));
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A WorkerSet represents a group of workers deployed from the same configuration,
//! identified by their shared namespace. Each WorkerSet owns a complete pipeline
//! (engines, KV router, prefill router) built from its specific ModelDeploymentCard.
use std::sync::Arc;
use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// A set of workers from the same namespace/configuration with their own pipeline.
///
/// The identifying fields (`namespace`, `mdcsum`, `card`) are private and are
/// supplied to [`WorkerSet::new`]. The engine, router and monitor slots are
/// crate-visible so they can be assigned after construction; `new` leaves every
/// optional slot as `None`.
pub struct WorkerSet {
/// Full namespace (e.g., "ns-abc12345")
namespace: String,
/// MDC checksum for this set's configuration
mdcsum: String,
/// The model deployment card used to build this set's pipeline
card: ModelDeploymentCard,
// Engines — each WorkerSet owns its own pipelines
/// Chat completions engine, if one has been attached.
pub(crate) chat_engine: Option<OpenAIChatCompletionsStreamingEngine>,
/// Completions engine, if one has been attached.
pub(crate) completions_engine: Option<OpenAICompletionsStreamingEngine>,
/// Embeddings engine, if one has been attached.
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
/// Images engine, if one has been attached.
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
/// Videos engine, if one has been attached.
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
/// Tensor engine, if one has been attached.
pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// KV router for this set's workers (if KV mode)
pub(crate) kv_router: Option<Arc<KvRouter>>,
/// Worker monitor for load-based rejection
pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Watcher for available instance IDs (from the Client's discovery watch).
/// None for in-process models (http/grpc) which don't have a discovery client.
/// `worker_count` reads the length of the watched list.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
}
impl WorkerSet {
    /// Creates a `WorkerSet` identified by `namespace` and `mdcsum`, backed by `card`.
    ///
    /// The new set starts with no engines, no KV router, no worker monitor and
    /// no instance watcher.
    pub fn new(namespace: String, mdcsum: String, card: ModelDeploymentCard) -> Self {
        Self {
            // Identity of the set.
            namespace,
            mdcsum,
            card,
            // Routing / monitoring hooks, attached later if needed.
            kv_router: None,
            worker_monitor: None,
            instance_count_rx: None,
            // Pipelines, attached later per supported endpoint type.
            chat_engine: None,
            completions_engine: None,
            embeddings_engine: None,
            images_engine: None,
            videos_engine: None,
            tensor_engine: None,
        }
    }

    /// Full namespace this set belongs to.
    pub fn namespace(&self) -> &str {
        self.namespace.as_str()
    }

    /// MDC checksum of this set's configuration.
    pub fn mdcsum(&self) -> &str {
        self.mdcsum.as_str()
    }

    /// Model deployment card this set was created with.
    pub fn card(&self) -> &ModelDeploymentCard {
        &self.card
    }

    /// `true` when a chat completions engine is attached.
    pub fn has_chat_engine(&self) -> bool {
        self.chat_engine.is_some()
    }

    /// `true` when a completions engine is attached.
    pub fn has_completions_engine(&self) -> bool {
        self.completions_engine.is_some()
    }

    /// `true` when an embeddings engine is attached.
    pub fn has_embeddings_engine(&self) -> bool {
        self.embeddings_engine.is_some()
    }

    /// `true` when an images engine is attached.
    pub fn has_images_engine(&self) -> bool {
        self.images_engine.is_some()
    }

    /// `true` when a videos engine is attached.
    pub fn has_videos_engine(&self) -> bool {
        self.videos_engine.is_some()
    }

    /// `true` when a tensor engine is attached.
    pub fn has_tensor_engine(&self) -> bool {
        self.tensor_engine.is_some()
    }

    /// `true` when at least one decode engine (chat or completions) is attached.
    pub fn has_decode_engine(&self) -> bool {
        self.chat_engine.is_some() || self.completions_engine.is_some()
    }

    /// `true` when no engine of any kind is attached, i.e. the set only tracks
    /// a prefill model's lifecycle.
    pub fn is_prefill_set(&self) -> bool {
        let has_any_engine = self.has_decode_engine()
            || self.has_embeddings_engine()
            || self.has_images_engine()
            || self.has_videos_engine()
            || self.has_tensor_engine();
        !has_any_engine
    }

    /// Builds `ParsingOptions` from the tool-call and reasoning parser settings
    /// in this set's card runtime configuration.
    pub fn parsing_options(&self) -> crate::protocols::openai::ParsingOptions {
        let runtime_config = &self.card.runtime_config;
        let tool_call_parser = runtime_config.tool_call_parser.clone();
        let reasoning_parser = runtime_config.reasoning_parser.clone();
        crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
    }

    /// Number of active workers in this set.
    ///
    /// With an instance watcher installed this is the length of the most
    /// recently published instance-ID list. Without one (in-process models)
    /// it is always 1, since such models have a single local worker.
    pub fn worker_count(&self) -> usize {
        self.instance_count_rx
            .as_ref()
            .map_or(1, |rx| rx.borrow().len())
    }

    /// Installs the instance watcher obtained from the Client's discovery system.
    ///
    /// Requires `&mut self`, so it must be called before the `WorkerSet` is
    /// wrapped in an `Arc`.
    pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
        self.instance_count_rx = Some(rx);
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `WorkerSet` defaults and worker counting.

    use super::*;
    use crate::model_card::ModelDeploymentCard;

    /// Builds a `WorkerSet` with a default card and no engines attached.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
        WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        )
    }

    #[test]
    fn test_worker_set_basics() {
        let ws = make_worker_set("ns1", "abc123");
        assert_eq!(ws.namespace(), "ns1");
        assert_eq!(ws.mdcsum(), "abc123");
    }

    #[test]
    fn test_no_engines_by_default() {
        let ws = make_worker_set("ns1", "abc123");
        assert!(!ws.has_chat_engine());
        assert!(!ws.has_completions_engine());
        assert!(!ws.has_embeddings_engine());
        assert!(!ws.has_images_engine());
        // Videos is part of `is_prefill_set`, so its default must be covered too.
        assert!(!ws.has_videos_engine());
        assert!(!ws.has_tensor_engine());
        assert!(!ws.has_decode_engine());
        assert!(ws.is_prefill_set());
    }

    #[test]
    fn test_worker_count_without_watcher() {
        // In-process models have no discovery watcher; worker_count defaults to 1
        let ws = make_worker_set("ns1", "abc");
        assert_eq!(ws.worker_count(), 1);
    }

    #[test]
    fn test_worker_count_with_watcher() {
        let mut ws = make_worker_set("ns1", "abc");
        // Simulate a discovery watcher with 3 workers
        let (tx, rx) = watch::channel(vec![1, 2, 3]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 3);
        // Workers leave → count drops
        tx.send(vec![1]).unwrap();
        assert_eq!(ws.worker_count(), 1);
        // All workers gone → count is 0
        tx.send(vec![]).unwrap();
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_with_empty_watcher() {
        // Discovery watcher starts empty (no workers have joined yet)
        let mut ws = make_worker_set("ns1", "abc");
        let (_tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_updates_on_join() {
        let mut ws = make_worker_set("ns1", "abc");
        let (tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);
        // Workers join one by one
        tx.send(vec![100]).unwrap();
        assert_eq!(ws.worker_count(), 1);
        tx.send(vec![100, 200]).unwrap();
        assert_eq!(ws.worker_count(), 2);
        tx.send(vec![100, 200, 300]).unwrap();
        assert_eq!(ws.worker_count(), 3);
    }
}
......@@ -12,6 +12,7 @@ use crate::{
kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::ModelDeploymentCard,
namespace::NamespaceFilter,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
......@@ -82,8 +83,14 @@ pub async fn prepare_engine(
)
.await?;
let inner_watch_obj = watch_obj.clone();
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
local_model.namespace(),
local_model.namespace_prefix(),
);
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(discovery_stream, None).await;
inner_watch_obj
.watch(discovery_stream, namespace_filter)
.await;
});
tracing::info!("Waiting for remote model..");
......
......@@ -9,7 +9,7 @@ use crate::{
entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve,
http::service::metrics::Metrics,
namespace::is_global_namespace,
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -38,18 +38,16 @@ pub async fn run(
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(),
model.namespace_prefix(),
);
run_watcher(
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
namespace_filter,
)
.await?;
grpc_service
......@@ -113,7 +111,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
namespace_filter: NamespaceFilter,
) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new());
......@@ -140,9 +138,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
watch_obj.watch(discovery_stream, namespace_filter).await;
});
Ok(())
......
......@@ -9,7 +9,7 @@ use crate::{
engines::StreamingEngineAdapter,
entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService},
namespace::is_global_namespace,
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -66,20 +66,17 @@ pub async fn run(
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
// Create namespace filter from model configuration
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(),
model.namespace_prefix(),
);
run_watcher(
distributed_runtime.clone(),
http_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
namespace_filter,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
chat_engine_factory.clone(),
......@@ -157,7 +154,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
namespace_filter: NamespaceFilter,
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
......@@ -193,9 +190,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
watch_obj.watch(discovery_stream, namespace_filter).await;
});
Ok(())
......
......@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService {
}
}
let model = completion_request.inner.model.clone();
let parsing_options = self.state.manager.get_parsing_options(&model);
let stream = completion_response_stream(self.state_clone(), completion_request).await?;
let (stream, parsing_options) =
completion_response_stream(self.state_clone(), completion_request).await?;
let completion_response =
NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
......@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService {
}
}
let model = completion_request.inner.model.clone();
let parsing_options = state.manager.get_parsing_options(&model);
let streaming = completion_request.inner.stream.unwrap_or(false);
let stream = completion_response_stream(state.clone(), completion_request).await?;
let (stream, parsing_options) = completion_response_stream(state.clone(), completion_request).await?;
if streaming {
pin_mut!(stream);
......
......@@ -9,6 +9,7 @@ use dynamo_runtime::{
use futures::{Stream, StreamExt, stream};
use std::sync::Arc;
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse,
};
......@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";
pub async fn completion_response_stream(
state: Arc<kserve::State>,
request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> {
) -> Result<
(
impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
ParsingOptions,
),
Status,
> {
// create the context for the request
// [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref());
......@@ -66,9 +73,9 @@ pub async fn completion_response_stream(
let model = &request.inner.model;
// todo - error handling should be more robust
let engine = state
let (engine, parsing_options) = state
.manager()
.get_completions_engine(model)
.get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......@@ -130,7 +137,7 @@ pub async fn completion_response_stream(
// without need to be cancelled.
connection_handle.disarm();
Ok(stream)
Ok((stream, parsing_options))
}
/// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation.
......
......@@ -194,9 +194,9 @@ async fn anthropic_messages(
tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine(&model)
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
anthropic_error(
StatusCode::NOT_FOUND,
......@@ -205,8 +205,6 @@ async fn anthropic_messages(
)
})?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for Anthropic messages");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment