Unverified Commit cb55766c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 8dd6369e
...@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names ...@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.namespace import get_worker_namespace
from dynamo.common.utils.output_modalities import OutputModality from dynamo.common.utils.output_modalities import OutputModality
...@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase):
media_output_http_url: Optional[str] = None media_output_http_url: Optional[str] = None
def validate(self) -> None: def validate(self) -> None:
self.namespace = get_worker_namespace(self.namespace)
# TODO get a better way for spot fixes like this. # TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events self.enable_local_indexer = not self.durable_kv_events
self._validate_output_modalities() self._validate_output_modalities()
...@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup):
flag_name="--namespace", flag_name="--namespace",
env_var="DYN_NAMESPACE", env_var="DYN_NAMESPACE",
default="dynamo", default="dynamo",
help="Dynamo namespace", help="Dynamo namespace. If DYN_NAMESPACE_WORKER_SUFFIX is set, "
"'-{suffix}' is appended to support multiple worker pools",
) )
add_argument( add_argument(
g, g,
......
...@@ -17,6 +17,7 @@ Submodules: ...@@ -17,6 +17,7 @@ Submodules:
from dynamo.common.utils import ( from dynamo.common.utils import (
endpoint_types, endpoint_types,
engine_response, engine_response,
namespace,
otel_tracing, otel_tracing,
paths, paths,
prometheus, prometheus,
...@@ -26,6 +27,7 @@ from dynamo.common.utils import ( ...@@ -26,6 +27,7 @@ from dynamo.common.utils import (
__all__ = [ __all__ = [
"endpoint_types", "endpoint_types",
"engine_response", "engine_response",
"namespace",
"otel_tracing", "otel_tracing",
"paths", "paths",
"prometheus", "prometheus",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
def get_worker_namespace(namespace: Optional[str] = None) -> str:
"""Get the Dynamo namespace for a worker.
Uses the provided namespace, or falls back to the DYN_NAMESPACE environment
variable (defaulting to "dynamo"). If DYN_NAMESPACE_WORKER_SUFFIX is set,
it is appended as "{namespace}-{suffix}" to support multiple sets of workers
for the same model.
"""
if not namespace:
namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
suffix = os.environ.get("DYN_NAMESPACE_WORKER_SUFFIX")
if suffix:
namespace = f"{namespace}-{suffix}"
return namespace
...@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase): ...@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase):
router_max_tree_size: int router_max_tree_size: int
router_prune_target_ratio: float router_prune_target_ratio: float
namespace: Optional[str] = None namespace: Optional[str] = None
namespace_prefix: Optional[str] = None
router_replica_sync: bool router_replica_sync: bool
router_snapshot_threshold: int router_snapshot_threshold: int
router_reset_states: bool router_reset_states: bool
...@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup): ...@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_NAMESPACE", env_var="DYN_NAMESPACE",
default=None, default=None,
help=( help=(
"Dynamo namespace for model discovery scoping. If specified, models will " "Dynamo namespace for model discovery scoping. Use for exact namespace matching. "
"only be discovered from this namespace. If not specified, discovers models " "If --namespace-prefix is also specified, prefix takes precedence."
"from all namespaces (global discovery)."
), ),
) )
...@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup): ...@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup):
arg_type=float, arg_type=float,
) )
add_argument(
g,
flag_name="--namespace-prefix",
env_var="DYN_NAMESPACE_PREFIX",
default=None,
help=(
"Dynamo namespace prefix for model discovery scoping. Discovers models from "
"namespaces starting with this prefix (e.g., 'ns' matches 'ns', 'ns-abc123', "
"'ns-def456'). Takes precedence over --namespace if both are specified."
),
)
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--router-replica-sync", flag_name="--router-replica-sync",
......
...@@ -230,6 +230,8 @@ async def async_main(): ...@@ -230,6 +230,8 @@ async def async_main():
kwargs["tls_key_path"] = config.tls_key_path kwargs["tls_key_path"] = config.tls_key_path
if config.namespace: if config.namespace:
kwargs["namespace"] = config.namespace kwargs["namespace"] = config.namespace
if config.namespace_prefix:
kwargs["namespace_prefix"] = config.namespace_prefix
if config.kserve_grpc_server and config.grpc_metrics_port: if config.kserve_grpc_server and config.grpc_metrics_port:
kwargs["http_metrics_port"] = config.grpc_metrics_port kwargs["http_metrics_port"] = config.grpc_metrics_port
......
...@@ -8,6 +8,8 @@ import os ...@@ -8,6 +8,8 @@ import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from dynamo.common.utils.namespace import get_worker_namespace
from . import __version__ from . import __version__
from .utils.planner_profiler_perf_data_converter import ( from .utils.planner_profiler_perf_data_converter import (
convert_profile_results_to_npz, convert_profile_results_to_npz,
...@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import ( ...@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import (
is_profile_results_dir, is_profile_results_dir,
) )
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") DYN_NAMESPACE = get_worker_namespace()
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate" DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate" DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
......
...@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring ...@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers. video and image diffusion workers.
""" """
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") from dynamo.common.utils.namespace import get_worker_namespace
DYN_NAMESPACE = get_worker_namespace()
# Default model paths # Default model paths
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
......
...@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers( ...@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config), Some(prefill_config),
enforce_disagg, enforce_disagg,
model_name.clone(), model_name.clone(),
namespace_str.clone(),
) )
} }
None if enforce_disagg => { None if enforce_disagg => {
......
...@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs { ...@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs {
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>, chat_engine_factory: Option<PyEngineFactory>,
...@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs { ...@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -209,6 +210,7 @@ impl EntrypointArgs { ...@@ -209,6 +210,7 @@ impl EntrypointArgs {
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyObject>, chat_engine_factory: Option<PyObject>,
...@@ -254,6 +256,7 @@ impl EntrypointArgs { ...@@ -254,6 +256,7 @@ impl EntrypointArgs {
tls_key_path, tls_key_path,
extra_engine_args, extra_engine_args,
namespace, namespace,
namespace_prefix,
is_prefill, is_prefill,
migration_limit, migration_limit,
chat_engine_factory, chat_engine_factory,
...@@ -296,7 +299,8 @@ pub fn make_engine<'p>( ...@@ -296,7 +299,8 @@ pub fn make_engine<'p>(
.tls_key_path(args.tls_key_path.clone()) .tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker)) .is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone()) .extra_engine_args(args.extra_engine_args.clone())
.namespace(args.namespace.clone()); .namespace(args.namespace.clone())
.namespace_prefix(args.namespace_prefix.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
if let Some(model_path) = args.model_path.clone() { if let Some(model_path) = args.model_path.clone() {
let local_path = if model_path.exists() { let local_path = if model_path.exists() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
mod model;
pub use model::Model;
mod model_manager; mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError}; pub use model_manager::{ModelManager, ModelManagerError};
mod worker_set;
pub use worker_set::WorkerSet;
pub(crate) mod runtime_configs; pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch}; pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A Model represents a named model (e.g., "llama-3-70b") that may be served by
//! one or more WorkerSets. Each WorkerSet corresponds to a namespace.
//!
//! Requests are routed to a WorkerSet selected by weighted random (proportional to worker count).
use std::sync::{Arc, OnceLock};
use dashmap::DashMap;
use rand::Rng;
use super::worker_monitor::LoadThresholdConfig;
use super::worker_set::WorkerSet;
use super::{KvWorkerMonitor, ModelManagerError};
use crate::protocols::openai::ParsingOptions;
use crate::types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
images::OpenAIImagesStreamingEngine, videos::OpenAIVideosStreamingEngine,
},
};
/// A named model backed by one or more WorkerSets.
pub struct Model {
name: String,
worker_sets: DashMap<String, Arc<WorkerSet>>,
/// The canonical MDC checksum for this model. Set by the first WorkerSet registered;
/// all subsequent WorkerSets must match. Naturally cleared when the Model is dropped
/// (last WorkerSet removed), allowing a new version to register.
canonical_checksum: OnceLock<String>,
}
impl Model {
pub fn new(name: String) -> Self {
Self {
name,
worker_sets: DashMap::new(),
canonical_checksum: OnceLock::new(),
}
}
pub fn name(&self) -> &str {
&self.name
}
/// Add a WorkerSet to this model. Returns `Err` if the WorkerSet's checksum
/// doesn't match the model's canonical checksum (set by the first WorkerSet).
pub fn add_worker_set(
&self,
namespace: String,
worker_set: Arc<WorkerSet>,
) -> Result<(), ModelManagerError> {
self.set_canonical_checksum(worker_set.mdcsum())?;
tracing::info!(
model = %self.name,
namespace = %namespace,
"Adding worker set to model"
);
self.worker_sets.insert(namespace, worker_set);
Ok(())
}
pub fn remove_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
let removed = self.worker_sets.remove(namespace).map(|(_, ws)| ws);
if removed.is_some() {
tracing::info!(
model = %self.name,
namespace = %namespace,
remaining_sets = self.worker_sets.len(),
"Removed worker set from model"
);
}
removed
}
pub fn has_worker_set(&self, namespace: &str) -> bool {
self.worker_sets.contains_key(namespace)
}
pub fn get_worker_set(&self, namespace: &str) -> Option<Arc<WorkerSet>> {
self.worker_sets
.get(namespace)
.map(|entry| entry.value().clone())
}
pub fn is_empty(&self) -> bool {
self.worker_sets.is_empty()
}
pub fn worker_set_count(&self) -> usize {
self.worker_sets.len()
}
/// Check if this model has any decode engine (chat or completions) across any WorkerSet.
pub fn has_decode_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_decode_engine())
}
/// Check if this model tracks prefill (any WorkerSet is a prefill set).
pub fn has_prefill(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().is_prefill_set())
}
/// Check if any WorkerSet has a chat engine.
pub fn has_chat_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_chat_engine())
}
/// Check if any WorkerSet has a completions engine.
pub fn has_completions_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_completions_engine())
}
/// Check if any WorkerSet has an embeddings engine.
pub fn has_embeddings_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_embeddings_engine())
}
/// Check if any WorkerSet has a tensor engine.
pub fn has_tensor_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_tensor_engine())
}
/// Check if any WorkerSet has an images engine.
pub fn has_images_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_images_engine())
}
/// Check if any WorkerSet has a videos engine.
pub fn has_videos_engine(&self) -> bool {
self.worker_sets
.iter()
.any(|entry| entry.value().has_videos_engine())
}
/// Check if a candidate checksum is valid for this model.
/// Returns `Some(true)` if it matches the canonical checksum, `Some(false)` if it
/// doesn't match, or `None` if no canonical checksum has been set yet (no WorkerSets).
pub fn is_valid_checksum(&self, candidate: &str) -> Option<bool> {
let canonical = self.canonical_checksum.get()?;
Some(canonical == candidate)
}
/// Set the canonical checksum for this model. The first caller wins (OnceLock).
/// Returns `Err` if a different checksum was already set.
fn set_canonical_checksum(&self, checksum: &str) -> Result<(), ModelManagerError> {
// Try to set; if already set, verify it matches.
match self.canonical_checksum.set(checksum.to_string()) {
Ok(()) => Ok(()),
Err(_) => {
// OnceLock was already set — check if the value matches
let canonical = self.canonical_checksum.get().unwrap();
if canonical == checksum {
Ok(())
} else {
Err(ModelManagerError::ChecksumMismatch {
model: self.name.clone(),
expected: canonical.clone(),
got: checksum.to_string(),
})
}
}
}
}
// -- Engine accessors: select a WorkerSet, return its engine --
pub fn get_chat_engine(
&self,
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_completions_engine(
&self,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.completions_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_embeddings_engine(
&self,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.embeddings_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_images_engine(&self) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.images_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_videos_engine(&self) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.videos_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_tensor_engine(&self) -> Result<TensorStreamingEngine, ModelManagerError> {
self.select_worker_set_with(|ws| ws.tensor_engine.clone())
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
pub fn get_chat_engine_with_parsing(
&self,
) -> Result<(OpenAIChatCompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
self.select_worker_set_with(|ws| ws.chat_engine.clone().map(|e| (e, ws.parsing_options())))
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
pub fn get_completions_engine_with_parsing(
&self,
) -> Result<(OpenAICompletionsStreamingEngine, ParsingOptions), ModelManagerError> {
self.select_worker_set_with(|ws| {
ws.completions_engine
.clone()
.map(|e| (e, ws.parsing_options()))
})
.ok_or_else(|| ModelManagerError::ModelNotFound(self.name.clone()))
}
// -- Worker monitoring (aggregated across WorkerSets) --
/// Get load threshold config from the first WorkerSet that has a monitor.
/// When `config` is Some, updates ALL monitors (each WorkerSet has its own).
pub fn load_threshold_config(
&self,
config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> {
let mut result = None;
for entry in self.worker_sets.iter() {
if let Some(ref monitor) = entry.value().worker_monitor {
if let Some(cfg) = config {
monitor.set_load_threshold_config(cfg);
}
if result.is_none() {
result = Some(monitor.load_threshold_config());
}
}
}
result
}
/// Get the worker monitor for a specific namespace's WorkerSet.
pub fn get_worker_monitor_for_namespace(&self, namespace: &str) -> Option<KvWorkerMonitor> {
self.worker_sets
.get(namespace)
.and_then(|entry| entry.value().worker_monitor.clone())
}
/// Total worker count across all WorkerSets.
pub fn total_workers(&self) -> usize {
self.worker_sets
.iter()
.map(|entry| entry.value().worker_count())
.sum()
}
// -- Internal selection --
/// Select a WorkerSet and extract a value from it.
///
/// When there's only one set (steady state), returns from that set directly.
/// With multiple sets, uses weighted random selection proportional
/// to worker count, filtering to sets that have the requested engine.
///
/// The `extract` closure should return `Some(value)` if the WorkerSet has the
/// desired engine, or `None` if it doesn't.
fn select_worker_set_with<T, F>(&self, extract: F) -> Option<T>
where
F: Fn(&WorkerSet) -> Option<T>,
{
// Fast path: single set (same zero-worker filtering as the multi-set path below)
// TODO: When the single set has 0 workers, this returns None which maps to
// ModelNotFound (404). Ideally should be 503 "no available workers" — see follow-up.
if self.worker_sets.len() == 1 {
return self.worker_sets.iter().next().and_then(|entry| {
let ws = entry.value();
if ws.worker_count() == 0 {
return None;
}
extract(ws)
});
}
// Collect eligible sets with their worker counts, skipping sets with no workers.
// In-process models (no discovery watcher) return count=1, so they always participate.
// Discovery models with count=0 have no available workers and are skipped.
let eligible: Vec<(T, usize)> = self
.worker_sets
.iter()
.filter_map(|entry| {
let ws = entry.value();
let count = ws.worker_count();
if count == 0 {
return None;
}
extract(ws).map(|val| (val, count))
})
.collect();
if eligible.is_empty() {
return None;
}
if eligible.len() == 1 {
return eligible.into_iter().next().map(|(val, _)| val);
}
// Weighted random selection proportional to worker count
let total_weight: usize = eligible.iter().map(|(_, w)| w).sum();
let mut pick = rand::rng().random_range(0..total_weight);
for (val, weight) in eligible {
if pick < weight {
return Some(val);
}
pick -= weight;
}
// Should not reach here, but fallback to None
None
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model_card::ModelDeploymentCard;
use tokio::sync::watch;
fn make_worker_set(namespace: &str, mdcsum: &str) -> Arc<WorkerSet> {
Arc::new(WorkerSet::new(
namespace.to_string(),
mdcsum.to_string(),
ModelDeploymentCard::default(),
))
}
/// Create a WorkerSet backed by a watch channel so worker_count reflects the vec length.
fn make_worker_set_with_count(
namespace: &str,
mdcsum: &str,
worker_ids: Vec<u64>,
) -> (Arc<WorkerSet>, watch::Sender<Vec<u64>>) {
let (tx, rx) = watch::channel(worker_ids);
let mut ws = WorkerSet::new(
namespace.to_string(),
mdcsum.to_string(),
ModelDeploymentCard::default(),
);
ws.set_instance_watcher(rx);
(Arc::new(ws), tx)
}
#[test]
fn test_model_new() {
let model = Model::new("llama".to_string());
assert_eq!(model.name(), "llama");
assert!(model.is_empty());
assert_eq!(model.worker_set_count(), 0);
}
#[test]
fn test_add_remove_worker_set() {
let model = Model::new("llama".to_string());
let ws = make_worker_set("ns1", "abc");
model.add_worker_set("ns1".to_string(), ws).unwrap();
assert!(!model.is_empty());
assert_eq!(model.worker_set_count(), 1);
assert!(model.has_worker_set("ns1"));
assert!(!model.has_worker_set("ns2"));
let removed = model.remove_worker_set("ns1");
assert!(removed.is_some());
assert!(model.is_empty());
let removed_again = model.remove_worker_set("ns1");
assert!(removed_again.is_none());
}
#[test]
fn test_get_worker_set() {
let model = Model::new("llama".to_string());
let ws = make_worker_set("ns1", "abc");
model.add_worker_set("ns1".to_string(), ws).unwrap();
let retrieved = model.get_worker_set("ns1");
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().namespace(), "ns1");
assert!(model.get_worker_set("ns2").is_none());
}
#[test]
fn test_multiple_worker_sets_same_checksum() {
let model = Model::new("llama".to_string());
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
model
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert_eq!(model.worker_set_count(), 2);
assert!(model.has_worker_set("ns1"));
assert!(model.has_worker_set("ns2"));
model.remove_worker_set("ns1");
assert_eq!(model.worker_set_count(), 1);
assert!(!model.has_worker_set("ns1"));
assert!(model.has_worker_set("ns2"));
}
#[test]
fn test_add_worker_set_rejects_checksum_mismatch() {
let model = Model::new("llama".to_string());
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
// Different checksum from a different namespace should be rejected
let result = model.add_worker_set("ns2".to_string(), make_worker_set("ns2", "def"));
assert!(result.is_err());
assert_eq!(model.worker_set_count(), 1); // ns2 was not added
}
#[test]
fn test_is_valid_checksum() {
let model = Model::new("llama".to_string());
// No canonical set yet
assert_eq!(model.is_valid_checksum("abc123"), None);
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc123"))
.unwrap();
// Matches canonical
assert_eq!(model.is_valid_checksum("abc123"), Some(true));
// Does not match canonical
assert_eq!(model.is_valid_checksum("wrong"), Some(false));
}
#[test]
fn test_no_engines_means_prefill() {
let model = Model::new("llama".to_string());
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
// WorkerSets with no engines are treated as prefill sets
assert!(model.has_prefill());
assert!(!model.has_decode_engine());
assert!(!model.has_chat_engine());
assert!(!model.has_completions_engine());
assert!(!model.has_embeddings_engine());
assert!(!model.has_tensor_engine());
assert!(!model.has_images_engine());
}
#[test]
fn test_get_engine_returns_error_without_engines() {
let model = Model::new("llama".to_string());
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err());
assert!(model.get_completions_engine().is_err());
assert!(model.get_embeddings_engine().is_err());
assert!(model.get_images_engine().is_err());
assert!(model.get_tensor_engine().is_err());
}
#[test]
fn test_select_worker_set_with_extracts_namespace() {
// Test that select_worker_set_with works by going through the public API.
// Since we can't create real engines in tests, we verify that selection
// returns None/Err when no engines are configured, which exercises the
// filtering and selection code paths.
let model = Model::new("llama".to_string());
// Empty model
assert!(model.get_chat_engine().is_err());
// Single set (fast path)
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err()); // No engine → filtered out
// Multiple sets (weighted path)
model
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert!(model.get_chat_engine().is_err()); // Still no engines → all filtered out
}
#[test]
fn test_total_workers_no_watcher() {
// In-process WorkerSets (no watcher) default to worker_count=1
let model = Model::new("llama".to_string());
assert_eq!(model.total_workers(), 0); // empty model
model
.add_worker_set("ns1".to_string(), make_worker_set("ns1", "abc"))
.unwrap();
assert_eq!(model.total_workers(), 1);
model
.add_worker_set("ns2".to_string(), make_worker_set("ns2", "abc"))
.unwrap();
assert_eq!(model.total_workers(), 2);
}
#[test]
fn test_total_workers_with_watcher() {
let model = Model::new("llama".to_string());
let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2, 3]);
let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![10, 20]);
model.add_worker_set("ns1".to_string(), ws1).unwrap();
model.add_worker_set("ns2".to_string(), ws2).unwrap();
assert_eq!(model.total_workers(), 5); // 3 + 2
}
#[test]
fn test_total_workers_updates_dynamically() {
let model = Model::new("llama".to_string());
let (ws1, tx1) = make_worker_set_with_count("ns1", "abc", vec![1, 2]);
model.add_worker_set("ns1".to_string(), ws1).unwrap();
assert_eq!(model.total_workers(), 2);
// Workers leave
tx1.send(vec![1]).unwrap();
assert_eq!(model.total_workers(), 1);
// All workers gone
tx1.send(vec![]).unwrap();
assert_eq!(model.total_workers(), 0);
}
#[test]
fn test_zero_worker_single_set_filtered() {
// Single WorkerSet with 0 workers should be filtered by select_worker_set_with.
// We test via select_worker_set_with's internal behavior: even though the set
// exists and is_prefill_set() returns true, engine accessors should fail because
// the zero-worker filter runs before the extract closure.
let model = Model::new("llama".to_string());
let (ws, _tx) = make_worker_set_with_count("ns1", "abc", vec![]);
model.add_worker_set("ns1".to_string(), ws).unwrap();
// WorkerSet exists but has 0 workers → selection filtered out → Err
assert!(model.get_chat_engine().is_err());
assert!(model.get_completions_engine().is_err());
}
#[test]
fn test_zero_worker_multi_set_filtered() {
// With multiple sets, only those with workers > 0 participate in selection.
let model = Model::new("llama".to_string());
let (ws1, _tx1) = make_worker_set_with_count("ns1", "abc", vec![]);
let (ws2, _tx2) = make_worker_set_with_count("ns2", "abc", vec![]);
model.add_worker_set("ns1".to_string(), ws1).unwrap();
model.add_worker_set("ns2".to_string(), ws2).unwrap();
// Both have 0 workers → all filtered → Err
assert!(model.get_chat_engine().is_err());
}
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::{ use std::{collections::HashSet, sync::Arc};
collections::{HashMap, HashSet},
sync::Arc,
};
use dashmap::{DashMap, mapref::entry::Entry}; use dashmap::{DashMap, mapref::entry::Entry};
use parking_lot::RwLock;
use tokio::sync::oneshot; use tokio::sync::oneshot;
use super::worker_monitor::LoadThresholdConfig; use super::worker_monitor::LoadThresholdConfig;
use super::{KvWorkerMonitor, RuntimeConfigWatch, runtime_config_watch}; use super::{KvWorkerMonitor, Model, RuntimeConfigWatch, WorkerSet, runtime_config_watch};
use dynamo_runtime::{ use dynamo_runtime::{
component::{Client, Endpoint, build_transport_type}, component::{Endpoint, build_transport_type},
discovery::DiscoverySpec, discovery::DiscoverySpec,
prelude::DistributedRuntimeProvider, prelude::DistributedRuntimeProvider,
protocols::EndpointId, protocols::EndpointId,
...@@ -27,7 +23,6 @@ use crate::{ ...@@ -27,7 +23,6 @@ use crate::{
}, },
local_model::runtime_config::DisaggregatedEndpoint, local_model::runtime_config::DisaggregatedEndpoint,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
model_type::ModelType,
types::{ types::{
generic::tensor::TensorStreamingEngine, generic::tensor::TensorStreamingEngine,
openai::{ openai::{
...@@ -54,31 +49,34 @@ pub enum ModelManagerError { ...@@ -54,31 +49,34 @@ pub enum ModelManagerError {
#[error("Model already exists: {0}")] #[error("Model already exists: {0}")]
ModelAlreadyExists(String), ModelAlreadyExists(String),
#[error(
"Checksum mismatch for model {model}: expected {expected}, got {got}. All WorkerSets of a model must share the same checksum. Drain all old workers before deploying a new version."
)]
ChecksumMismatch {
model: String,
expected: String,
got: String,
},
} }
/// Central manager for model engines, routing, and configuration. /// Central manager for model engines, routing, and configuration.
/// ///
/// Manages model lifecycle including engines, KV routers, prefill coordination, /// Models are stored hierarchically: ModelManager → Model → WorkerSet.
/// and per-model busy thresholds for load-based request rejection. /// Each WorkerSet owns a complete pipeline built from its specific configuration.
/// ///
/// Note: Don't implement Clone for this, put it in an Arc instead. /// Note: Don't implement Clone for this, put it in an Arc instead.
pub struct ModelManager { pub struct ModelManager {
// We read a lot and write rarely, so these three are RwLock /// Model name → Model (which contains WorkerSets with engines)
completion_engines: RwLock<ModelEngines<OpenAICompletionsStreamingEngine>>, models: DashMap<String, Arc<Model>>,
chat_completion_engines: RwLock<ModelEngines<OpenAIChatCompletionsStreamingEngine>>,
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
images_engines: RwLock<ModelEngines<OpenAIImagesStreamingEngine>>,
videos_engines: RwLock<ModelEngines<OpenAIVideosStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// Prefill models don't have engines - they're only tracked for discovery/lifecycle
prefill_engines: RwLock<ModelEngines<()>>,
/// Per-instance model cards, keyed by instance path. Used for cleanup on worker removal.
cards: DashMap<String, ModelDeploymentCard>, cards: DashMap<String, ModelDeploymentCard>,
kv_choosers: DashMap<EndpointId, Arc<KvRouter>>,
/// Prefill router activation rendezvous, keyed by "model_name:namespace".
prefill_router_activators: DashMap<String, PrefillActivationState>, prefill_router_activators: DashMap<String, PrefillActivationState>,
// Per-model monitoring: worker_monitors for load-based rejection, runtime_configs for KvScheduler /// Per-endpoint runtime config watchers. Keyed by EndpointId (includes namespace).
worker_monitors: DashMap<String, KvWorkerMonitor>,
runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>, runtime_configs: DashMap<EndpointId, RuntimeConfigWatch>,
} }
...@@ -91,140 +89,324 @@ impl Default for ModelManager { ...@@ -91,140 +89,324 @@ impl Default for ModelManager {
impl ModelManager { impl ModelManager {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
completion_engines: RwLock::new(ModelEngines::default()), models: DashMap::new(),
chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()),
images_engines: RwLock::new(ModelEngines::default()),
videos_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()),
prefill_engines: RwLock::new(ModelEngines::default()),
cards: DashMap::new(), cards: DashMap::new(),
kv_choosers: DashMap::new(),
prefill_router_activators: DashMap::new(), prefill_router_activators: DashMap::new(),
worker_monitors: DashMap::new(),
runtime_configs: DashMap::new(), runtime_configs: DashMap::new(),
} }
} }
pub fn is_valid_checksum( // -- Model access --
/// Get or create a Model for the given name.
pub fn get_or_create_model(&self, model_name: &str) -> Arc<Model> {
self.models
.entry(model_name.to_string())
.or_insert_with(|| Arc::new(Model::new(model_name.to_string())))
.clone()
}
/// Get an existing Model, if it exists.
pub fn get_model(&self, model_name: &str) -> Option<Arc<Model>> {
self.models
.get(model_name)
.map(|entry| entry.value().clone())
}
/// Remove a Model if it has no remaining WorkerSets.
/// Uses atomic remove_if to avoid TOCTOU race between checking is_empty and removing.
pub fn remove_model_if_empty(&self, model_name: &str) {
if self
.models
.remove_if(model_name, |_, model| model.is_empty())
.is_some()
{
tracing::info!(model_name, "Removed empty model from manager");
}
}
/// Add a WorkerSet to a Model. Creates the Model if it doesn't exist.
/// Returns `Err` if the WorkerSet's checksum doesn't match the model's canonical checksum.
pub fn add_worker_set(
&self, &self,
model_type: ModelType,
model_name: &str, model_name: &str,
candidate_checksum: &str, namespace: &str,
) -> Option<bool> { worker_set: WorkerSet,
let mut results = vec![]; ) -> Result<(), ModelManagerError> {
for unit in model_type.units() { let model = self.get_or_create_model(model_name);
let maybe_valid_checksum = match unit { model.add_worker_set(namespace.to_string(), Arc::new(worker_set))
ModelType::Chat => self.chat_completion_engines.read().checksum(model_name), }
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name), /// Remove a WorkerSet from a Model. Removes the Model if it becomes empty.
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name), pub fn remove_worker_set(&self, model_name: &str, namespace: &str) -> Option<Arc<WorkerSet>> {
ModelType::Images => self.images_engines.read().checksum(model_name), let model = self.models.get(model_name)?;
ModelType::Videos => self.videos_engines.read().checksum(model_name), let removed = model.remove_worker_set(namespace);
ModelType::Prefill => self.prefill_engines.read().checksum(model_name), drop(model);
_ => { self.remove_model_if_empty(model_name);
continue; removed
}
};
if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
tracing::debug!(
model_name,
valid_checksum,
candidate_checksum,
"is_valid_checksum: check case"
);
valid_checksum == candidate_checksum
}) {
results.push(is_valid)
}
}
if results.is_empty() {
None
} else {
// The checksum is valid if it is correct for all the ModelType in the bitflag.
Some(results.into_iter().all(|x| x))
}
} }
// -- Checksum validation --
/// Check if a candidate checksum is valid for a model.
/// Returns `Some(true)` if it matches the model's canonical checksum, `Some(false)` if it
/// doesn't match, or `None` if the model doesn't exist or has no canonical checksum yet.
pub fn is_valid_checksum(&self, model_name: &str, candidate_checksum: &str) -> Option<bool> {
let model = self.models.get(model_name)?;
model.is_valid_checksum(candidate_checksum)
}
// -- Model cards --
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.iter().map(|r| r.value().clone()).collect() self.cards.iter().map(|r| r.value().clone()).collect()
} }
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.insert(key.to_string(), card);
Ok(())
}
/// Remove and return model card for this instance's key. We do this when the instance stops.
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.remove(key).map(|(_, v)| v)
}
// -- Engine accessors (delegate through Model → WorkerSet) --
/// Check if a decode model (chat or completions) is registered /// Check if a decode model (chat or completions) is registered
pub fn has_decode_model(&self, model: &str) -> bool { pub fn has_decode_model(&self, model: &str) -> bool {
self.chat_completion_engines.read().contains(model) self.models
|| self.completion_engines.read().contains(model) .get(model)
.is_some_and(|m| m.has_decode_engine())
} }
/// Check if a prefill model is registered /// Check if a prefill model is registered
pub fn has_prefill_model(&self, model: &str) -> bool { pub fn has_prefill_model(&self, model: &str) -> bool {
self.prefill_engines.read().contains(model) self.models.get(model).is_some_and(|m| m.has_prefill())
} }
/// Check if any model (decode or prefill) is registered. /// Check if any model (decode or prefill) is registered.
/// Note: For registration skip-checks, use has_decode_model() or has_prefill_model() instead.
pub fn has_model_any(&self, model: &str) -> bool { pub fn has_model_any(&self, model: &str) -> bool {
self.has_decode_model(model) || self.has_prefill_model(model) self.has_decode_model(model) || self.has_prefill_model(model)
} }
pub fn model_display_names(&self) -> HashSet<String> { pub fn model_display_names(&self) -> HashSet<String> {
self.list_chat_completions_models() let mut names = HashSet::new();
.into_iter() for entry in self.models.iter() {
.chain(self.list_completions_models()) let model = entry.value();
.chain(self.list_embeddings_models()) if model.has_chat_engine()
.chain(self.list_images_models()) || model.has_completions_engine()
.chain(self.list_videos_models()) || model.has_embeddings_engine()
.chain(self.list_tensor_models()) || model.has_images_engine()
.chain(self.list_prefill_models()) || model.has_tensor_engine()
.collect() || model.has_videos_engine()
|| model.has_prefill()
{
names.insert(entry.key().clone());
}
}
names
} }
pub fn list_chat_completions_models(&self) -> Vec<String> { pub fn list_chat_completions_models(&self) -> Vec<String> {
self.chat_completion_engines.read().list() self.models
.iter()
.filter(|entry| entry.value().has_chat_engine())
.map(|entry| entry.key().clone())
.collect()
} }
pub fn list_completions_models(&self) -> Vec<String> { pub fn list_completions_models(&self) -> Vec<String> {
self.completion_engines.read().list() self.models
.iter()
.filter(|entry| entry.value().has_completions_engine())
.map(|entry| entry.key().clone())
.collect()
} }
pub fn list_embeddings_models(&self) -> Vec<String> { pub fn list_embeddings_models(&self) -> Vec<String> {
self.embeddings_engines.read().list() self.models
.iter()
.filter(|entry| entry.value().has_embeddings_engine())
.map(|entry| entry.key().clone())
.collect()
} }
pub fn list_tensor_models(&self) -> Vec<String> { pub fn list_tensor_models(&self) -> Vec<String> {
self.tensor_engines.read().list() self.models
.iter()
.filter(|entry| entry.value().has_tensor_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_images_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_images_engine())
.map(|entry| entry.key().clone())
.collect()
}
pub fn list_videos_models(&self) -> Vec<String> {
self.models
.iter()
.filter(|entry| entry.value().has_videos_engine())
.map(|entry| entry.key().clone())
.collect()
} }
pub fn list_prefill_models(&self) -> Vec<String> { pub fn list_prefill_models(&self) -> Vec<String> {
self.prefill_engines.read().list() self.models
.iter()
.filter(|entry| entry.value().has_prefill())
.map(|entry| entry.key().clone())
.collect()
} }
pub fn list_images_models(&self) -> Vec<String> { pub fn get_embeddings_engine(
self.images_engines.read().list() &self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_embeddings_engine()
} }
pub fn list_videos_models(&self) -> Vec<String> { pub fn get_completions_engine(
self.videos_engines.read().list() &self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_completions_engine()
} }
pub fn add_completions_model( pub fn get_chat_completions_engine(
&self, &self,
model: &str, model: &str,
card_checksum: &str, ) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> {
engine: OpenAICompletionsStreamingEngine, self.models
) -> Result<(), ModelManagerError> { .get(model)
let mut clients = self.completion_engines.write(); .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
clients.add(model, card_checksum, engine) .get_chat_engine()
}
pub fn get_tensor_engine(
&self,
model: &str,
) -> Result<TensorStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_tensor_engine()
}
pub fn get_images_engine(
&self,
model: &str,
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_images_engine()
}
pub fn get_videos_engine(
&self,
model: &str,
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_videos_engine()
}
// -- Combined engine + parsing options (atomically from one WorkerSet) --
pub fn get_chat_completions_engine_with_parsing(
&self,
model: &str,
) -> Result<
(
OpenAIChatCompletionsStreamingEngine,
crate::protocols::openai::ParsingOptions,
),
ModelManagerError,
> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_chat_engine_with_parsing()
} }
pub fn get_completions_engine_with_parsing(
&self,
model: &str,
) -> Result<
(
OpenAICompletionsStreamingEngine,
crate::protocols::openai::ParsingOptions,
),
ModelManagerError,
> {
self.models
.get(model)
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))?
.get_completions_engine_with_parsing()
}
// -- Convenience methods for in-process models (http.rs, grpc.rs) --
// These create a WorkerSet with a default namespace for local models.
// TODO: These methods use ModelDeploymentCard::default() for the WorkerSet, which means
// parsing_options() returns defaults (no tool_call_parser/reasoning_parser). Pass the real
// MDC from callers so ParsingOptions reflect the model's actual configuration.
pub fn add_chat_completions_model( pub fn add_chat_completions_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str, card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, engine) if model_entry.has_chat_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_chat_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.chat_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
}
pub fn add_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let model_entry = self.get_or_create_model(model);
if model_entry.has_completions_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_completions_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.completions_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
} }
pub fn add_embeddings_model( pub fn add_embeddings_model(
...@@ -233,8 +415,19 @@ impl ModelManager { ...@@ -233,8 +415,19 @@ impl ModelManager {
card_checksum: &str, card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine, engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, engine) if model_entry.has_embeddings_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_embeddings_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.embeddings_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
} }
pub fn add_tensor_model( pub fn add_tensor_model(
...@@ -243,8 +436,19 @@ impl ModelManager { ...@@ -243,8 +436,19 @@ impl ModelManager {
card_checksum: &str, card_checksum: &str,
engine: TensorStreamingEngine, engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, engine) if model_entry.has_tensor_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_tensor_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.tensor_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
} }
pub fn add_images_model( pub fn add_images_model(
...@@ -253,8 +457,19 @@ impl ModelManager { ...@@ -253,8 +457,19 @@ impl ModelManager {
card_checksum: &str, card_checksum: &str,
engine: OpenAIImagesStreamingEngine, engine: OpenAIImagesStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.images_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, engine) if model_entry.has_images_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_images_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.images_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
} }
pub fn add_videos_model( pub fn add_videos_model(
...@@ -263,8 +478,19 @@ impl ModelManager { ...@@ -263,8 +478,19 @@ impl ModelManager {
card_checksum: &str, card_checksum: &str,
engine: OpenAIVideosStreamingEngine, engine: OpenAIVideosStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.videos_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, engine) if model_entry.has_videos_engine() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
let namespace = format!("__local_videos_{}", model);
let mut ws = WorkerSet::new(
namespace.clone(),
card_checksum.to_string(),
ModelDeploymentCard::default(),
);
ws.videos_engine = Some(engine);
model_entry.add_worker_set(namespace, Arc::new(ws))?;
Ok(())
} }
pub fn add_prefill_model( pub fn add_prefill_model(
...@@ -272,122 +498,74 @@ impl ModelManager { ...@@ -272,122 +498,74 @@ impl ModelManager {
model: &str, model: &str,
card_checksum: &str, card_checksum: &str,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write(); let model_entry = self.get_or_create_model(model);
clients.add(model, card_checksum, ()) if model_entry.has_prefill() {
} return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { let namespace = format!("__local_prefill_{}", model);
let mut clients = self.completion_engines.write(); let ws = WorkerSet::new(
clients.remove(model) namespace.clone(),
} card_checksum.to_string(),
ModelDeploymentCard::default(),
pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { );
let mut clients = self.chat_completion_engines.write(); model_entry.add_worker_set(namespace, Arc::new(ws))?;
clients.remove(model) Ok(())
}
pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.remove(model)
} }
pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> { // -- Model removal --
let mut clients = self.tensor_engines.write();
clients.remove(model)
}
pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> { /// Remove a model entirely (all its WorkerSets).
let mut clients = self.images_engines.write(); /// Returns the removed Model, or None if not found.
clients.remove(model) pub fn remove_model(&self, model: &str) -> Option<Arc<Model>> {
self.models.remove(model).map(|(_, m)| m)
} }
pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> { // Per-type remove methods for in-process models (used by Python bindings).
let mut clients = self.videos_engines.write(); // These remove the specific synthetic WorkerSet created by the corresponding add_*_model method.
clients.remove(model)
}
pub fn remove_prefill_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_chat_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
let mut clients = self.prefill_engines.write(); let namespace = format!("__local_chat_{}", model);
clients.remove(model) self.remove_worker_set(model, &namespace)
} .map(|_| ())
.ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
pub fn get_embeddings_engine(
&self,
model: &str,
) -> Result<OpenAIEmbeddingsStreamingEngine, ModelManagerError> {
self.embeddings_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
pub fn get_completions_engine(
&self,
model: &str,
) -> Result<OpenAICompletionsStreamingEngine, ModelManagerError> {
self.completion_engines
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_chat_completions_engine( pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
&self, let namespace = format!("__local_completions_{}", model);
model: &str, self.remove_worker_set(model, &namespace)
) -> Result<OpenAIChatCompletionsStreamingEngine, ModelManagerError> { .map(|_| ())
self.chat_completion_engines .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_tensor_engine( pub fn remove_tensor_model(&self, model: &str) -> Result<(), ModelManagerError> {
&self, let namespace = format!("__local_tensor_{}", model);
model: &str, self.remove_worker_set(model, &namespace)
) -> Result<TensorStreamingEngine, ModelManagerError> { .map(|_| ())
self.tensor_engines .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_images_engine( pub fn remove_embeddings_model(&self, model: &str) -> Result<(), ModelManagerError> {
&self, let namespace = format!("__local_embeddings_{}", model);
model: &str, self.remove_worker_set(model, &namespace)
) -> Result<OpenAIImagesStreamingEngine, ModelManagerError> { .map(|_| ())
self.images_engines .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
pub fn get_videos_engine( pub fn remove_images_model(&self, model: &str) -> Result<(), ModelManagerError> {
&self, let namespace = format!("__local_images_{}", model);
model: &str, self.remove_worker_set(model, &namespace)
) -> Result<OpenAIVideosStreamingEngine, ModelManagerError> { .map(|_| ())
self.videos_engines .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
.read()
.get(model)
.cloned()
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Save a ModelDeploymentCard from an instance's key so we can fetch it later when the key is pub fn remove_videos_model(&self, model: &str) -> Result<(), ModelManagerError> {
/// deleted. let namespace = format!("__local_videos_{}", model);
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> { self.remove_worker_set(model, &namespace)
self.cards.insert(key.to_string(), card); .map(|_| ())
Ok(()) .ok_or_else(|| ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Remove and return model card for this instance's key. We do this when the instance stops. // -- KV Router creation --
pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.cards.remove(key).map(|(_, v)| v)
}
pub async fn kv_chooser_for( pub async fn kv_chooser_for(
&self, &self,
...@@ -396,25 +574,9 @@ impl ModelManager { ...@@ -396,25 +574,9 @@ impl ModelManager {
kv_router_config: Option<KvRouterConfig>, kv_router_config: Option<KvRouterConfig>,
worker_type: &'static str, worker_type: &'static str,
) -> anyhow::Result<Arc<KvRouter>> { ) -> anyhow::Result<Arc<KvRouter>> {
let endpoint_id = endpoint.id();
if let Some(kv_chooser) = self.get_kv_chooser(&endpoint_id) {
// Check if the existing router has a different block size
if kv_chooser.block_size() != kv_cache_block_size {
tracing::warn!(
endpoint = %endpoint_id,
existing_block_size = %kv_chooser.block_size(),
requested_block_size = %kv_cache_block_size,
"KV Router block size mismatch! Endpoint is requesting a different kv_cache_block_size than the existing router. \
This will cause routing to fail silently. Consider using the same block size or restarting the router."
);
}
return Ok(kv_chooser);
}
let client = endpoint.client().await?; let client = endpoint.client().await?;
// Register router via discovery mechanism // Register router via discovery mechanism.
let discovery = endpoint.component().drt().discovery(); let discovery = endpoint.component().drt().discovery();
let instance_id = discovery.instance_id(); let instance_id = discovery.instance_id();
...@@ -433,7 +595,7 @@ impl ModelManager { ...@@ -433,7 +595,7 @@ impl ModelManager {
discovery.register(discovery_spec).await?; discovery.register(discovery_spec).await?;
// Get or create runtime config watcher for this endpoint // Get of create runtime config watcher for this endpoint
let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?; let workers_with_configs = self.get_or_create_runtime_config_watcher(endpoint).await?;
let selector = Box::new(DefaultWorkerSelector::new(kv_router_config)); let selector = Box::new(DefaultWorkerSelector::new(kv_router_config));
...@@ -447,28 +609,35 @@ impl ModelManager { ...@@ -447,28 +609,35 @@ impl ModelManager {
worker_type, worker_type,
) )
.await?; .await?;
let new_kv_chooser = Arc::new(chooser); Ok(Arc::new(chooser))
self.kv_choosers.insert(endpoint_id, new_kv_chooser.clone());
Ok(new_kv_chooser)
} }
fn get_kv_chooser(&self, id: &EndpointId) -> Option<Arc<KvRouter>> { // -- Prefill router coordination --
self.kv_choosers.get(id).map(|r| r.value().clone()) // Keyed by "model_name:namespace" so each namespace's decode WorkerSet gets its own
// prefill router activated by same-namespace prefill workers.
/// Build a key for a (model, namespace) pair. Used for prefill router activators
/// and registration guards.
pub(crate) fn model_namespace_key(model_name: &str, namespace: &str) -> String {
format!("{}:{}", model_name, namespace)
} }
/// Register a prefill router for a decode model. Returns a receiver that will be /// Register a prefill router for a decode WorkerSet. Returns a receiver that will be
/// activated when the corresponding prefill model is discovered. /// activated when the corresponding prefill model in the same namespace is discovered.
/// Returns None if the decode model was already registered. /// Returns None if a decode WorkerSet in this namespace was already registered.
pub fn register_prefill_router( pub fn register_prefill_router(
&self, &self,
model_name: String, model_name: &str,
namespace: &str,
) -> Option<oneshot::Receiver<Endpoint>> { ) -> Option<oneshot::Receiver<Endpoint>> {
match self.prefill_router_activators.remove(&model_name) { let key = Self::model_namespace_key(model_name, namespace);
match self.prefill_router_activators.remove(&key) {
Some((_, PrefillActivationState::PrefillReady(rx))) => { Some((_, PrefillActivationState::PrefillReady(rx))) => {
// Prefill endpoint already arrived - rx will immediately resolve // Prefill endpoint already arrived - rx will immediately resolve
tracing::debug!( tracing::debug!(
model_name = %model_name, model_name = %model_name,
"Prefill endpoint already available, returning receiver with endpoint" namespace = %namespace,
"Prefill endpoint already available for namespace, returning receiver"
); );
Some(rx) Some(rx)
} }
...@@ -476,7 +645,8 @@ impl ModelManager { ...@@ -476,7 +645,8 @@ impl ModelManager {
// Decode already registered - this shouldn't happen, restore state and return None // Decode already registered - this shouldn't happen, restore state and return None
tracing::error!( tracing::error!(
model_name = %model_name, model_name = %model_name,
"Decode model already registered for this prefill router" namespace = %namespace,
"Decode WorkerSet already registered for this prefill router"
); );
self.prefill_router_activators self.prefill_router_activators
.insert(key, PrefillActivationState::DecodeWaiting(tx)); .insert(key, PrefillActivationState::DecodeWaiting(tx));
...@@ -485,13 +655,12 @@ impl ModelManager { ...@@ -485,13 +655,12 @@ impl ModelManager {
None => { None => {
// New registration: create tx/rx pair, store sender and return receiver // New registration: create tx/rx pair, store sender and return receiver
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
self.prefill_router_activators.insert( self.prefill_router_activators
model_name.clone(), .insert(key, PrefillActivationState::DecodeWaiting(tx));
PrefillActivationState::DecodeWaiting(tx),
);
tracing::debug!( tracing::debug!(
model_name = %model_name, model_name = %model_name,
"No prefill endpoint available yet, storing sender for future activation" namespace = %namespace,
"No prefill endpoint for namespace yet, storing sender for future activation"
); );
Some(rx) Some(rx)
} }
...@@ -499,115 +668,107 @@ impl ModelManager { ...@@ -499,115 +668,107 @@ impl ModelManager {
} }
/// Activate a prefill router by sending the endpoint through the oneshot channel. /// Activate a prefill router by sending the endpoint through the oneshot channel.
/// If no decode model has registered yet, stores the endpoint for future retrieval. /// The namespace must match the decode WorkerSet's namespace.
pub fn activate_prefill_router( pub fn activate_prefill_router(
&self, &self,
model_name: &str, model_name: &str,
namespace: &str,
endpoint: Endpoint, endpoint: Endpoint,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
match self.prefill_router_activators.remove(model_name) { let key = Self::model_namespace_key(model_name, namespace);
match self.prefill_router_activators.remove(&key) {
Some((_, PrefillActivationState::DecodeWaiting(sender))) => { Some((_, PrefillActivationState::DecodeWaiting(sender))) => {
// Decode model already registered
sender.send(endpoint).map_err(|_| { sender.send(endpoint).map_err(|_| {
anyhow::anyhow!( anyhow::anyhow!(
"Failed to send endpoint to prefill router activator for model: {}", "Failed to send endpoint to prefill router activator for {}:{}",
model_name model_name,
namespace
) )
})?; })?;
tracing::info!( tracing::info!(
model_name = %model_name, model_name = %model_name,
"Activated prefill router for already-registered decode model" namespace = %namespace,
"Activated prefill router for decode WorkerSet"
); );
Ok(()) Ok(())
} }
Some((_, PrefillActivationState::PrefillReady(_))) => { Some((_, PrefillActivationState::PrefillReady(_))) => {
// Prefill already activated - this shouldn't happen anyhow::bail!(
anyhow::bail!("Prefill router for model {} already activated", model_name); "Prefill router for {}:{} already activated",
model_name,
namespace
);
} }
None => { None => {
// Decode model not registered yet - create pair and immediately send endpoint
let (tx, rx) = oneshot::channel(); let (tx, rx) = oneshot::channel();
tx.send(endpoint).map_err(|_| { tx.send(endpoint).map_err(|_| {
anyhow::anyhow!("Failed to send endpoint for prefill model: {}", model_name) anyhow::anyhow!(
"Failed to send endpoint for prefill model {}:{}",
model_name,
namespace
)
})?; })?;
self.prefill_router_activators
// Store the receiver for when decode model registers .insert(key, PrefillActivationState::PrefillReady(rx));
self.prefill_router_activators.insert(
model_name.to_string(),
PrefillActivationState::PrefillReady(rx),
);
tracing::info!( tracing::info!(
model_name = %model_name, model_name = %model_name,
"Stored prefill endpoint for future decode model registration" namespace = %namespace,
"Stored prefill endpoint for future decode WorkerSet registration"
); );
Ok(()) Ok(())
} }
} }
} }
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> { /// Remove the prefill router activator for a (model, namespace) pair.
self.cards /// Called when a WorkerSet is removed to prevent stale activators.
.iter() pub fn remove_prefill_activator(&self, model_name: &str, namespace: &str) {
.find(|r| r.value().display_name == model) let key = Self::model_namespace_key(model_name, namespace);
.and_then(|r| r.value().runtime_config.tool_call_parser.clone()) if self.prefill_router_activators.remove(&key).is_some() {
} tracing::debug!(
model_name = %model_name,
pub fn get_model_reasoning_parser(&self, model: &str) -> Option<String> { namespace = %namespace,
self.cards "Cleaned up prefill router activator for removed WorkerSet"
.iter() );
.find(|r| r.value().display_name == model) }
.and_then(|r| r.value().runtime_config.reasoning_parser.clone())
} }
/// Creates parsing options with tool call parser and reasoning parser for the specified model. // -- Worker monitoring --
pub fn get_parsing_options(&self, model: &str) -> crate::protocols::openai::ParsingOptions {
let tool_call_parser = self.get_model_tool_call_parser(model);
let reasoning_parser = self.get_model_reasoning_parser(model);
crate::protocols::openai::ParsingOptions::new(tool_call_parser, reasoning_parser)
}
/// Gets or sets the load threshold config for a model's worker monitor. /// Gets or sets the load threshold config for a model's worker monitor.
/// Pass `Some(config)` to update, `None` to get. Returns `None` if no monitor exists. /// Checks across all WorkerSets for the model.
pub fn load_threshold_config( pub fn load_threshold_config(
&self, &self,
model: &str, model: &str,
config: Option<&LoadThresholdConfig>, config: Option<&LoadThresholdConfig>,
) -> Option<LoadThresholdConfig> { ) -> Option<LoadThresholdConfig> {
let monitor = self.worker_monitors.get(model)?; let model_entry = self.models.get(model)?;
if let Some(cfg) = config { model_entry.load_threshold_config(config)
monitor.set_load_threshold_config(cfg);
}
Some(monitor.load_threshold_config())
} }
/// Gets an existing worker monitor for a model, if one exists. /// Gets an existing worker monitor for a specific namespace of a model.
pub fn get_worker_monitor(&self, model: &str) -> Option<KvWorkerMonitor> { pub fn get_worker_monitor_for_namespace(
self.worker_monitors.get(model).map(|m| m.clone())
}
/// Gets or creates a worker monitor for a model. Updates thresholds if monitor exists.
pub fn get_or_create_worker_monitor(
&self, &self,
model: &str, model: &str,
client: Client, namespace: &str,
config: LoadThresholdConfig, ) -> Option<KvWorkerMonitor> {
) -> KvWorkerMonitor { let model_entry = self.models.get(model)?;
if let Some(existing) = self.worker_monitors.get(model) { model_entry.get_worker_monitor_for_namespace(namespace)
existing.set_load_threshold_config(&config); }
return existing.clone();
/// Lists all models with worker monitors configured.
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> {
let mut result = Vec::new();
for entry in self.models.iter() {
if let Some(config) = entry.value().load_threshold_config(None) {
result.push((entry.key().clone(), config));
}
} }
let monitor = KvWorkerMonitor::new(client, config); result
self.worker_monitors
.insert(model.to_string(), monitor.clone());
monitor
} }
// -- Runtime configs --
/// Get or create a runtime config watcher for an endpoint. /// Get or create a runtime config watcher for an endpoint.
/// Spawns a background task that joins instance availability and config discovery. /// Spawns a background task that joins instance availability and config discovery.
/// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`. /// Returns a `watch::Receiver` with the latest `HashMap<WorkerId, ModelRuntimeConfig>`.
...@@ -617,7 +778,6 @@ impl ModelManager { ...@@ -617,7 +778,6 @@ impl ModelManager {
) -> anyhow::Result<RuntimeConfigWatch> { ) -> anyhow::Result<RuntimeConfigWatch> {
let endpoint_id = endpoint.id(); let endpoint_id = endpoint.id();
// Fast path: return existing if present
if let Some(existing) = self.runtime_configs.get(&endpoint_id) { if let Some(existing) = self.runtime_configs.get(&endpoint_id) {
return Ok(existing.clone()); return Ok(existing.clone());
} }
...@@ -638,7 +798,6 @@ impl ModelManager { ...@@ -638,7 +798,6 @@ impl ModelManager {
} }
/// Get disaggregated endpoint for a specific worker. /// Get disaggregated endpoint for a specific worker.
/// Used by PrefillRouter for bootstrap info - works for ANY routing mode.
pub fn get_disaggregated_endpoint( pub fn get_disaggregated_endpoint(
&self, &self,
endpoint_id: &EndpointId, endpoint_id: &EndpointId,
...@@ -648,79 +807,348 @@ impl ModelManager { ...@@ -648,79 +807,348 @@ impl ModelManager {
let configs = rx.borrow(); let configs = rx.borrow();
configs.get(&worker_id)?.disaggregated_endpoint.clone() configs.get(&worker_id)?.disaggregated_endpoint.clone()
} }
}
/// Lists all models with worker monitors configured. #[cfg(test)]
pub fn list_busy_thresholds(&self) -> Vec<(String, LoadThresholdConfig)> { mod tests {
self.worker_monitors use super::*;
.iter() use crate::model_card::ModelDeploymentCard;
.map(|entry| (entry.key().clone(), entry.value().load_threshold_config()))
.collect() fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
WorkerSet::new(
namespace.to_string(),
mdcsum.to_string(),
ModelDeploymentCard::default(),
)
} }
}
pub struct ModelEngines<E> { // -- CRUD delegation tests --
/// Optional default model name
default: Option<String>,
engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
}
impl<E> Default for ModelEngines<E> { #[test]
fn default() -> Self { fn test_add_and_get_worker_set() {
Self { let mm = ModelManager::new();
default: None, let ws = make_worker_set("ns1", "abc");
engines: HashMap::new(), mm.add_worker_set("llama", "ns1", ws).unwrap();
checksums: HashMap::new(),
} let model = mm.get_model("llama");
assert!(model.is_some());
let model = model.unwrap();
assert!(model.has_worker_set("ns1"));
assert_eq!(model.worker_set_count(), 1);
} }
}
impl<E> ModelEngines<E> { #[test]
#[allow(dead_code)] fn test_add_worker_set_creates_model() {
fn set_default(&mut self, model: &str) { let mm = ModelManager::new();
self.default = Some(model.to_string()); assert!(mm.get_model("llama").is_none());
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(mm.get_model("llama").is_some());
} }
#[allow(dead_code)] #[test]
fn clear_default(&mut self) { fn test_remove_worker_set_removes_empty_model() {
self.default = None; let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(mm.get_model("llama").is_some());
let removed = mm.remove_worker_set("llama", "ns1");
assert!(removed.is_some());
assert_eq!(removed.unwrap().namespace(), "ns1");
// Model should be auto-removed since it's now empty
assert!(mm.get_model("llama").is_none());
} }
fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> { #[test]
if self.engines.contains_key(model) { fn test_remove_worker_set_keeps_model_with_remaining() {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); let mm = ModelManager::new();
} mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
self.engines.insert(model.to_string(), engine); .unwrap();
self.checksums mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"))
.insert(model.to_string(), checksum.to_string()); .unwrap();
Ok(())
mm.remove_worker_set("llama", "ns1");
// Model should still exist with ns2
let model = mm.get_model("llama").unwrap();
assert!(!model.has_worker_set("ns1"));
assert!(model.has_worker_set("ns2"));
assert_eq!(model.worker_set_count(), 1);
} }
fn remove(&mut self, model: &str) -> Result<(), ModelManagerError> { #[test]
if self.engines.remove(model).is_none() { fn test_remove_worker_set_nonexistent_model() {
return Err(ModelManagerError::ModelNotFound(model.to_string())); let mm = ModelManager::new();
} assert!(mm.remove_worker_set("llama", "ns1").is_none());
let _ = self.checksums.remove(model); }
Ok(())
#[test]
fn test_remove_worker_set_nonexistent_namespace() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(mm.remove_worker_set("llama", "ns2").is_none());
// Model should still exist (ns1 still there)
assert!(mm.get_model("llama").is_some());
}
#[test]
fn test_remove_model_if_empty_noop_when_not_empty() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
mm.remove_model_if_empty("llama");
assert!(mm.get_model("llama").is_some()); // Still has ns1
} }
fn get(&self, model: &str) -> Option<&E> { #[test]
self.engines.get(model) fn test_remove_model_if_empty_noop_when_missing() {
let mm = ModelManager::new();
mm.remove_model_if_empty("nonexistent"); // Should not panic
} }
fn contains(&self, model: &str) -> bool { #[test]
self.engines.contains_key(model) fn test_remove_model() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
mm.add_worker_set("llama", "ns2", make_worker_set("ns2", "abc"))
.unwrap();
let removed = mm.remove_model("llama");
assert!(removed.is_some());
assert!(mm.get_model("llama").is_none());
}
#[test]
fn test_get_or_create_model_idempotent() {
let mm = ModelManager::new();
let m1 = mm.get_or_create_model("llama");
let m2 = mm.get_or_create_model("llama");
// Both should point to the same Model (same Arc)
assert!(Arc::ptr_eq(&m1, &m2));
}
// -- Checksum validation tests --
#[test]
fn test_is_valid_checksum_match() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
assert_eq!(mm.is_valid_checksum("llama", "abc123"), Some(true));
}
#[test]
fn test_is_valid_checksum_mismatch() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
assert_eq!(mm.is_valid_checksum("llama", "wrong"), Some(false));
}
#[test]
fn test_is_valid_checksum_no_canonical_yet() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc123"))
.unwrap();
// Canonical is set, so even for a "new namespace" scenario the checksum is checked
assert_eq!(mm.is_valid_checksum("llama", "abc123"), Some(true));
assert_eq!(mm.is_valid_checksum("llama", "xyz"), Some(false));
}
#[test]
fn test_is_valid_checksum_missing_model() {
let mm = ModelManager::new();
assert_eq!(mm.is_valid_checksum("nonexistent", "abc"), None);
}
#[test]
fn test_is_valid_checksum_cross_namespace_enforcement() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "checksum_a"))
.unwrap();
// A different namespace with a different checksum should be rejected at the model level
assert_eq!(mm.is_valid_checksum("llama", "checksum_b"), Some(false));
// Same checksum is accepted
assert_eq!(mm.is_valid_checksum("llama", "checksum_a"), Some(true));
}
// -- Model listing and filtering tests --
#[test]
fn test_has_decode_model() {
let mm = ModelManager::new();
// No model → false
assert!(!mm.has_decode_model("llama"));
// Prefill-only set (no engines) → false
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(!mm.has_decode_model("llama"));
}
#[test]
fn test_has_prefill_model() {
let mm = ModelManager::new();
// Prefill set = no engines
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(mm.has_prefill_model("llama"));
}
#[test]
fn test_has_model_any() {
let mm = ModelManager::new();
assert!(!mm.has_model_any("llama"));
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
assert!(mm.has_model_any("llama")); // has prefill
}
#[test]
fn test_model_display_names_includes_prefill() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
let names = mm.model_display_names();
assert!(names.contains("llama"));
}
#[test]
fn test_model_display_names_empty() {
let mm = ModelManager::new();
assert!(mm.model_display_names().is_empty());
}
#[test]
fn test_list_prefill_models() {
let mm = ModelManager::new();
mm.add_worker_set("llama", "ns1", make_worker_set("ns1", "abc"))
.unwrap();
mm.add_worker_set("gpt", "ns1", make_worker_set("ns1", "def"))
.unwrap();
let prefill = mm.list_prefill_models();
assert_eq!(prefill.len(), 2);
assert!(prefill.contains(&"llama".to_string()));
assert!(prefill.contains(&"gpt".to_string()));
}
// -- Model card tests --
#[test]
fn test_save_and_remove_model_card() {
let mm = ModelManager::new();
let card = ModelDeploymentCard::default();
mm.save_model_card("instance/key/1", card.clone()).unwrap();
let cards = mm.get_model_cards();
assert_eq!(cards.len(), 1);
let removed = mm.remove_model_card("instance/key/1");
assert!(removed.is_some());
assert!(mm.get_model_cards().is_empty());
}
#[test]
fn test_remove_model_card_nonexistent() {
let mm = ModelManager::new();
assert!(mm.remove_model_card("nonexistent").is_none());
}
// -- Prefill router rendezvous tests --
// Note: activate_prefill_router requires an Endpoint (needs DistributedRuntime),
// so we test the registration state machine and cleanup only.
#[test]
fn test_prefill_router_register_new() {
let mm = ModelManager::new();
// First registration for a (model, namespace) returns Some(rx)
let rx = mm.register_prefill_router("llama", "ns1");
assert!(rx.is_some());
}
#[test]
fn test_prefill_router_double_register_returns_none() {
let mm = ModelManager::new();
let rx1 = mm.register_prefill_router("llama", "ns1");
assert!(rx1.is_some());
// Second registration for the same (model, namespace) returns None
let rx2 = mm.register_prefill_router("llama", "ns1");
assert!(rx2.is_none());
}
#[test]
fn test_prefill_router_different_namespaces_independent() {
let mm = ModelManager::new();
// Different namespaces should be independent
let rx1 = mm.register_prefill_router("llama", "ns1");
let rx2 = mm.register_prefill_router("llama", "ns2");
assert!(rx1.is_some());
assert!(rx2.is_some());
}
#[test]
fn test_prefill_router_different_models_independent() {
let mm = ModelManager::new();
// Different models should be independent
let rx1 = mm.register_prefill_router("llama", "ns1");
let rx2 = mm.register_prefill_router("gpt", "ns1");
assert!(rx1.is_some());
assert!(rx2.is_some());
}
#[test]
fn test_prefill_router_remove_allows_reregister() {
let mm = ModelManager::new();
let rx = mm.register_prefill_router("llama", "ns1");
assert!(rx.is_some());
// Remove the activator
mm.remove_prefill_activator("llama", "ns1");
// Should be able to register again
let rx2 = mm.register_prefill_router("llama", "ns1");
assert!(rx2.is_some());
} }
pub fn list(&self) -> Vec<String> { #[test]
self.engines.keys().map(|k| k.to_owned()).collect() fn test_prefill_router_remove_nonexistent_noop() {
let mm = ModelManager::new();
// Should not panic
mm.remove_prefill_activator("llama", "ns1");
} }
/// Returns a newly allocated String for called convenience. All the places I use #[test]
/// this I need a String. fn test_model_namespace_key_format() {
pub fn checksum(&self, model: &str) -> Option<String> { assert_eq!(
self.checksums.get(model).map(|s| s.to_string()) ModelManager::model_namespace_key("llama", "ns1"),
"llama:ns1"
);
assert_eq!(
ModelManager::model_namespace_key("gpt-4", "default-abc"),
"gpt-4:default-abc"
);
} }
} }
...@@ -24,7 +24,7 @@ use dynamo_runtime::{ ...@@ -24,7 +24,7 @@ use dynamo_runtime::{
use crate::{ use crate::{
backend::Backend, backend::Backend,
discovery::WORKER_TYPE_DECODE, discovery::{KvWorkerMonitor, WORKER_TYPE_DECODE, WorkerSet},
entrypoint::{self, ChatEngineFactoryCallback, RouterConfig}, entrypoint::{self, ChatEngineFactoryCallback, RouterConfig},
http::service::metrics::Metrics, http::service::metrics::Metrics,
kv_router::PrefillRouter, kv_router::PrefillRouter,
...@@ -47,7 +47,17 @@ use crate::{ ...@@ -47,7 +47,17 @@ use crate::{
}; };
use super::ModelManager; use super::ModelManager;
use crate::namespace::is_global_namespace; use crate::namespace::NamespaceFilter;
/// Constructs the WorkerSet storage key. Prefill and decode workers in the same
/// namespace get different keys so they don't block each other's registration.
fn worker_set_key(namespace: &str, model_type: ModelType) -> String {
if model_type.supports_prefill() {
format!("{}:prefill", namespace)
} else {
namespace.to_string()
}
}
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub enum ModelUpdate { pub enum ModelUpdate {
...@@ -64,7 +74,8 @@ pub struct ModelWatcher { ...@@ -64,7 +74,8 @@ pub struct ModelWatcher {
model_update_tx: Option<Sender<ModelUpdate>>, model_update_tx: Option<Sender<ModelUpdate>>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
metrics: Arc<Metrics>, metrics: Arc<Metrics>,
registering_models: DashSet<String>, /// Guards against concurrent pipeline construction for the same (model, namespace).
registering_worker_sets: DashSet<String>,
} }
const ALL_MODEL_TYPES: &[ModelType] = &[ const ALL_MODEL_TYPES: &[ModelType] = &[
...@@ -78,6 +89,27 @@ const ALL_MODEL_TYPES: &[ModelType] = &[ ...@@ -78,6 +89,27 @@ const ALL_MODEL_TYPES: &[ModelType] = &[
ModelType::Prefill, ModelType::Prefill,
]; ];
/// Returns true if no models in the manager support the given model type.
fn is_model_type_list_empty(manager: &ModelManager, model_type: ModelType) -> bool {
if model_type == ModelType::Chat {
manager.list_chat_completions_models().is_empty()
} else if model_type == ModelType::Completions {
manager.list_completions_models().is_empty()
} else if model_type == ModelType::Embedding {
manager.list_embeddings_models().is_empty()
} else if model_type == ModelType::Images {
manager.list_images_models().is_empty()
} else if model_type == ModelType::Videos {
manager.list_videos_models().is_empty()
} else if model_type == ModelType::TensorBased {
manager.list_tensor_models().is_empty()
} else if model_type == ModelType::Prefill {
manager.list_prefill_models().is_empty()
} else {
true
}
}
impl ModelWatcher { impl ModelWatcher {
pub fn new( pub fn new(
runtime: DistributedRuntime, runtime: DistributedRuntime,
...@@ -96,7 +128,7 @@ impl ModelWatcher { ...@@ -96,7 +128,7 @@ impl ModelWatcher {
model_update_tx: None, model_update_tx: None,
chat_engine_factory, chat_engine_factory,
metrics, metrics,
registering_models: DashSet::new(), registering_worker_sets: DashSet::new(),
} }
} }
...@@ -119,10 +151,8 @@ impl ModelWatcher { ...@@ -119,10 +151,8 @@ impl ModelWatcher {
pub async fn watch( pub async fn watch(
&self, &self,
mut discovery_stream: DiscoveryStream, mut discovery_stream: DiscoveryStream,
target_namespace: Option<&str>, namespace_filter: NamespaceFilter,
) { ) {
let global_namespace = target_namespace.is_none_or(is_global_namespace);
while let Some(result) = discovery_stream.next().await { while let Some(result) = discovery_stream.next().await {
let event = match result { let event = match result {
Ok(event) => event, Ok(event) => event,
...@@ -168,28 +198,27 @@ impl ModelWatcher { ...@@ -168,28 +198,27 @@ impl ModelWatcher {
} }
}; };
// Filter by namespace if target_namespace is specified // Filter by namespace using the configured filter
if !global_namespace if !namespace_filter.matches(&mcid.namespace) {
&& let Some(target_ns) = target_namespace
&& mcid.namespace != target_ns
{
tracing::debug!( tracing::debug!(
model_namespace = mcid.namespace, model_namespace = mcid.namespace,
target_namespace = target_ns, namespace_filter = ?namespace_filter,
"Skipping model from different namespace" "Skipping model due to namespace filter"
); );
continue; continue;
} }
// If we already have a worker for this model, and the ModelDeploymentCard // If we already have a WorkerSet for this model and the checksums
// cards don't match, alert, and don't add the new instance // don't match, reject the new worker. All WorkerSets of a model
let can_add = // must share the same checksum.
self.manager let can_add = self.manager.is_valid_checksum(card.name(), card.mdcsum());
.is_valid_checksum(card.model_type, card.name(), card.mdcsum());
if can_add.is_some_and(|is_valid| !is_valid) { if can_add.is_some_and(|is_valid| !is_valid) {
tracing::error!( tracing::error!(
model_name = card.name(), model_name = card.name(),
"Checksum for new model does not match existing model." namespace = mcid.namespace,
"Checksum for new worker does not match model's canonical checksum. \
All WorkerSets must share the same checksum. \
Drain all old workers before deploying a new version."
); );
// TODO: mark that instance down in clients // TODO: mark that instance down in clients
...@@ -199,7 +228,6 @@ impl ModelWatcher { ...@@ -199,7 +228,6 @@ impl ModelWatcher {
// needs more testing). // needs more testing).
// The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside // The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
// interface `AsyncEngine` which only has a `generate` method. // interface `AsyncEngine` which only has a `generate` method.
continue; continue;
} }
...@@ -235,7 +263,7 @@ impl ModelWatcher { ...@@ -235,7 +263,7 @@ impl ModelWatcher {
}; };
match self match self
.handle_delete(model_card_instance_id, target_namespace, global_namespace) .handle_delete(model_card_instance_id, &namespace_filter)
.await .await
{ {
Ok(Some(model_name)) => { Ok(Some(model_name)) => {
...@@ -253,13 +281,12 @@ impl ModelWatcher { ...@@ -253,13 +281,12 @@ impl ModelWatcher {
} }
} }
/// If the last instance running this model has gone delete it. /// Handle a worker removal. Cleans up per-namespace WorkerSets and the Model itself
/// Returns the name of the model we just deleted, if any. /// when no instances remain. Returns the model name if the entire Model was removed.
async fn handle_delete( async fn handle_delete(
&self, &self,
mcid: &ModelCardInstanceId, mcid: &ModelCardInstanceId,
target_namespace: Option<&str>, namespace_filter: &NamespaceFilter,
is_global_namespace: bool,
) -> anyhow::Result<Option<String>> { ) -> anyhow::Result<Option<String>> {
let key = mcid.to_path(); let key = mcid.to_path();
let card = match self.manager.remove_model_card(&key) { let card = match self.manager.remove_model_card(&key) {
...@@ -269,89 +296,55 @@ impl ModelWatcher { ...@@ -269,89 +296,55 @@ impl ModelWatcher {
} }
}; };
let model_name = card.name().to_string(); let model_name = card.name().to_string();
let worker_namespace = &mcid.namespace;
let worker_component = &mcid.component;
let ws_key = worker_set_key(&mcid.namespace, card.model_type);
// Query discovery for all remaining instances of this model
let active_instances = self let active_instances = self
.cards_for_model(&model_name, target_namespace, is_global_namespace) .cards_for_model_with_endpoints(&model_name, namespace_filter)
.await .await
.with_context(|| model_name.clone())?; .with_context(|| model_name.clone())?;
// Check if instances of the SAME component remain in this namespace.
// In disaggregated deployments, prefill and decode are different components
// in the same namespace, so we must check at the component level to avoid
// removing one type's WorkerSet while the other still has workers.
let component_has_instances = active_instances.iter().any(|(eid, _)| {
eid.namespace == *worker_namespace && eid.component == *worker_component
});
if !component_has_instances {
// No more workers of this component in this namespace — remove its WorkerSet
if let Some(_removed_ws) = self.manager.remove_worker_set(&model_name, &ws_key) {
// remove_prefill_activator uses deployment namespace (not ws_key)
self.manager
.remove_prefill_activator(&model_name, worker_namespace);
tracing::info!(
model_name,
namespace = %worker_namespace,
"Removed WorkerSet (no remaining instances in namespace)"
);
}
}
// Check if the Model still has instances in any namespace
if !active_instances.is_empty() { if !active_instances.is_empty() {
tracing::debug!( tracing::debug!(
model_name, model_name,
target_namespace = ?target_namespace,
active_instance_count = active_instances.len(), active_instance_count = active_instances.len(),
"Model has other active instances, not removing" "Model has other active instances in other namespaces"
); );
return Ok(None); return Ok(None);
} }
// Ignore the errors because model could be either type // No instances remain anywhere — remove the entire Model
let chat_model_remove_err = self.manager.remove_chat_completions_model(&model_name); let _ = self.manager.remove_model(&model_name);
let completions_model_remove_err = self.manager.remove_completions_model(&model_name);
let embeddings_model_remove_err = self.manager.remove_embeddings_model(&model_name);
let images_model_remove_err = self.manager.remove_images_model(&model_name);
let videos_model_remove_err = self.manager.remove_videos_model(&model_name);
let tensor_model_remove_err = self.manager.remove_tensor_model(&model_name);
let prefill_model_remove_err = self.manager.remove_prefill_model(&model_name);
let mut chat_model_removed = false;
let mut completions_model_removed = false;
let mut embeddings_model_removed = false;
let mut images_model_removed = false;
let mut videos_model_removed = false;
let mut tensor_model_removed = false;
let mut prefill_model_removed = false;
if chat_model_remove_err.is_ok() && self.manager.list_chat_completions_models().is_empty() {
chat_model_removed = true;
}
if completions_model_remove_err.is_ok() && self.manager.list_completions_models().is_empty()
{
completions_model_removed = true;
}
if embeddings_model_remove_err.is_ok() && self.manager.list_embeddings_models().is_empty() {
embeddings_model_removed = true;
}
if images_model_remove_err.is_ok() && self.manager.list_images_models().is_empty() {
images_model_removed = true;
}
if videos_model_remove_err.is_ok() && self.manager.list_videos_models().is_empty() {
videos_model_removed = true;
}
if tensor_model_remove_err.is_ok() && self.manager.list_tensor_models().is_empty() {
tensor_model_removed = true;
}
if prefill_model_remove_err.is_ok() && self.manager.list_prefill_models().is_empty() {
prefill_model_removed = true;
}
if !chat_model_removed if let Some(tx) = &self.model_update_tx {
&& !completions_model_removed
&& !embeddings_model_removed
&& !images_model_removed
&& !videos_model_removed
&& !tensor_model_removed
&& !prefill_model_removed
{
tracing::debug!(
"No updates to send for model {}: chat_model_removed: {}, completions_model_removed: {}, embeddings_model_removed: {}, images_model_removed: {}, videos_model_removed: {}, tensor_model_removed: {}, prefill_model_removed: {}",
model_name,
chat_model_removed,
completions_model_removed,
embeddings_model_removed,
images_model_removed,
videos_model_removed,
tensor_model_removed,
prefill_model_removed
);
} else {
for model_type in ALL_MODEL_TYPES { for model_type in ALL_MODEL_TYPES {
if ((chat_model_removed && *model_type == ModelType::Chat) if card.model_type.intersects(*model_type)
|| (completions_model_removed && *model_type == ModelType::Completions) && is_model_type_list_empty(&self.manager, *model_type)
|| (embeddings_model_removed && *model_type == ModelType::Embedding)
|| (images_model_removed && *model_type == ModelType::Images)
|| (videos_model_removed && *model_type == ModelType::Videos)
|| (tensor_model_removed && *model_type == ModelType::TensorBased)
|| (prefill_model_removed && *model_type == ModelType::Prefill))
&& let Some(tx) = &self.model_update_tx
{ {
tx.send(ModelUpdate::Removed(card.clone())).await.ok(); tx.send(ModelUpdate::Removed(card.clone())).await.ok();
} }
...@@ -368,54 +361,52 @@ impl ModelWatcher { ...@@ -368,54 +361,52 @@ impl ModelWatcher {
mcid: &ModelCardInstanceId, mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard, card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Check if model is already registered before downloading config. // Check if this specific (model, namespace, type) WorkerSet already exists.
// This prevents duplicate HuggingFace API calls when multiple workers register // If so, this is just another worker joining an existing set — no pipeline build needed.
// the same model. let model_name = card.name().to_string();
// Prefill and decode models are tracked separately, so registering one let namespace = mcid.namespace.clone();
// doesn't block the other (they can arrive in any order). let ws_key = worker_set_key(&namespace, card.model_type);
let already_registered = if card.model_type.supports_prefill() {
self.manager.has_prefill_model(card.name())
} else {
self.manager.has_decode_model(card.name())
};
if already_registered { if let Some(model) = self.manager.get_model(&model_name)
&& model.has_worker_set(&ws_key)
{
self.manager self.manager
.save_model_card(&mcid.to_path(), card.clone())?; .save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!( tracing::debug!(
model_name = card.name(), model_name = card.name(),
namespace = mcid.namespace, namespace = namespace,
model_type = %card.model_type, "Worker joined existing WorkerSet, skipping pipeline build"
"Model already registered, skipping config download"
); );
return Ok(()); return Ok(());
} }
// Use registering_models set to prevent concurrent registrations. // Guard against concurrent pipeline construction for the same (model, namespace, type)
let model_key = card.name().to_string(); let registration_key = ModelManager::model_namespace_key(&model_name, &ws_key);
if !self.registering_models.insert(model_key.clone()) { if !self
.registering_worker_sets
.insert(registration_key.clone())
{
self.manager self.manager
.save_model_card(&mcid.to_path(), card.clone())?; .save_model_card(&mcid.to_path(), card.clone())?;
tracing::debug!( tracing::debug!(
model_name = card.name(), model_name = card.name(),
namespace = mcid.namespace, namespace = namespace,
"Model registration in progress by another worker, skipping" "WorkerSet registration in progress, skipping"
); );
return Ok(()); return Ok(());
} }
// We acquired the registration lock. Use a helper to ensure cleanup on all exit paths. let result = self.do_worker_set_registration(mcid, card).await;
let result = self.do_model_registration(mcid, card).await;
// Always remove from registering set, whether success or failure // Always remove from registering set
self.registering_models.remove(&model_key); self.registering_worker_sets.remove(&registration_key);
result result
} }
/// Inner function that performs the actual model registration. /// Build a complete WorkerSet with all engines for this (model, namespace)
/// Called by handle_put after acquiring the registration lock. /// and add it to the Model.
async fn do_model_registration( async fn do_worker_set_registration(
&self, &self,
mcid: &ModelCardInstanceId, mcid: &ModelCardInstanceId,
card: &mut ModelDeploymentCard, card: &mut ModelDeploymentCard,
...@@ -428,7 +419,12 @@ impl ModelWatcher { ...@@ -428,7 +419,12 @@ impl ModelWatcher {
.component(&mcid.component)?; .component(&mcid.component)?;
let endpoint = component.endpoint(&mcid.endpoint); let endpoint = component.endpoint(&mcid.endpoint);
let client = endpoint.client().await?; let client = endpoint.client().await?;
tracing::debug!(model_name = card.name(), "adding model"); let instance_watcher = client.instance_avail_watcher();
tracing::debug!(
model_name = card.name(),
namespace = mcid.namespace,
"building worker set pipeline"
);
self.manager self.manager
.save_model_card(&mcid.to_path(), card.clone())?; .save_model_card(&mcid.to_path(), card.clone())?;
...@@ -437,6 +433,12 @@ impl ModelWatcher { ...@@ -437,6 +433,12 @@ impl ModelWatcher {
} }
let checksum = card.mdcsum(); let checksum = card.mdcsum();
let namespace = mcid.namespace.clone();
let ws_key = worker_set_key(&namespace, card.model_type);
// Build the WorkerSet with all applicable engines
let mut worker_set = WorkerSet::new(namespace.clone(), checksum.to_string(), card.clone());
worker_set.set_instance_watcher(instance_watcher);
if card.model_input == ModelInput::Tokens if card.model_input == ModelInput::Tokens
&& (card.model_type.supports_chat() || card.model_type.supports_completions()) && (card.model_type.supports_chat() || card.model_type.supports_completions())
...@@ -477,7 +479,7 @@ impl ModelWatcher { ...@@ -477,7 +479,7 @@ impl ModelWatcher {
let model_name = card.name().to_string(); let model_name = card.name().to_string();
let prefill_chooser = self let prefill_chooser = self
.manager .manager
.register_prefill_router(model_name.clone()) .register_prefill_router(&model_name, &namespace)
.map(|rx| { .map(|rx| {
// Create prefill-specific config with track_active_blocks disabled // Create prefill-specific config with track_active_blocks disabled
let mut prefill_config = self.router_config.kv_router_config; let mut prefill_config = self.router_config.kv_router_config;
...@@ -490,20 +492,24 @@ impl ModelWatcher { ...@@ -490,20 +492,24 @@ impl ModelWatcher {
card.kv_cache_block_size, card.kv_cache_block_size,
Some(prefill_config), Some(prefill_config),
self.router_config.enforce_disagg, self.router_config.enforce_disagg,
model_name.clone(), // Pass model name for worker monitor lookup model_name.clone(),
namespace.clone(),
) )
}); });
// Get or create the worker monitor for this model. // Create a new worker monitor for this WorkerSet. Each WorkerSet gets its own
// Always create the monitor for Prometheus metrics (active_decode_blocks, active_prefill_tokens, // monitor (1-to-1) since each monitor is scoped to this WorkerSet's Client/namespace.
// The monitor tracks Prometheus metrics (active_decode_blocks, active_prefill_tokens,
// worker TTFT/ITL cleanup). The thresholds control busy detection behavior only. // worker TTFT/ITL cleanup). The thresholds control busy detection behavior only.
// LoadThresholdConfig allows dynamic threshold updates via the ModelManager. let worker_monitor = Some(KvWorkerMonitor::new(
let worker_monitor = Some(self.manager.get_or_create_worker_monitor(
card.name(),
client.clone(), client.clone(),
self.router_config.load_threshold_config.clone(), self.router_config.load_threshold_config.clone(),
)); ));
// Store KV router and worker monitor on the WorkerSet
worker_set.kv_router = kv_chooser.clone();
worker_set.worker_monitor = worker_monitor.clone();
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
let factory_engine = if let Some(ref factory) = self.chat_engine_factory { let factory_engine = if let Some(ref factory) = self.chat_engine_factory {
...@@ -537,9 +543,7 @@ impl ModelWatcher { ...@@ -537,9 +543,7 @@ impl ModelWatcher {
.await .await
.context("build_routed_pipeline")? .context("build_routed_pipeline")?
}; };
self.manager worker_set.chat_engine = Some(chat_engine);
.add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?;
tracing::info!("Chat completions is ready"); tracing::info!("Chat completions is ready");
} }
...@@ -572,9 +576,7 @@ impl ModelWatcher { ...@@ -572,9 +576,7 @@ impl ModelWatcher {
) )
.await .await
.context("build_routed_pipeline_with_preprocessor")?; .context("build_routed_pipeline_with_preprocessor")?;
self.manager worker_set.completions_engine = Some(completions_engine);
.add_completions_model(card.name(), checksum, completions_engine)
.context("add_completions_model")?;
tracing::info!("Completions is ready"); tracing::info!("Completions is ready");
} }
} else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() { } else if card.model_input == ModelInput::Text && card.model_type.supports_embedding() {
...@@ -586,9 +588,7 @@ impl ModelWatcher { ...@@ -586,9 +588,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None client, self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); worker_set.embeddings_engine = Some(Arc::new(push_router));
self.manager
.add_embeddings_model(card.name(), checksum, engine)?;
} }
// Case: Text + (Images, Audio, Videos) // Case: Text + (Images, Audio, Videos)
// Must come before the plain Text+Chat / Text+Completions branches because // Must come before the plain Text+Chat / Text+Completions branches because
...@@ -599,8 +599,7 @@ impl ModelWatcher { ...@@ -599,8 +599,7 @@ impl ModelWatcher {
|| card.model_type.supports_audios() || card.model_type.supports_audios()
|| card.model_type.supports_videos()) || card.model_type.supports_videos())
{ {
// Image Models can support chat completions (vllm omni way) // Image/Audio/Video models can also support chat completions (vLLM omni way)
// So register chat_completions model as well
if card.model_type.supports_chat() { if card.model_type.supports_chat() {
let chat_router = PushRouter::< let chat_router = PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
...@@ -612,14 +611,9 @@ impl ModelWatcher { ...@@ -612,14 +611,9 @@ impl ModelWatcher {
None, None,
) )
.await?; .await?;
self.manager.add_chat_completions_model( worker_set.chat_engine = Some(Arc::new(chat_router));
card.name(),
checksum,
Arc::new(chat_router),
)?;
} }
// This is ModelType::Images : registers /v1/images/* endpoints
if card.model_type.supports_images() { if card.model_type.supports_images() {
let images_router = PushRouter::< let images_router = PushRouter::<
NvCreateImageRequest, NvCreateImageRequest,
...@@ -628,11 +622,9 @@ impl ModelWatcher { ...@@ -628,11 +622,9 @@ impl ModelWatcher {
client.clone(), self.router_config.router_mode, None, None client.clone(), self.router_config.router_mode, None, None
) )
.await?; .await?;
self.manager worker_set.images_engine = Some(Arc::new(images_router));
.add_images_model(card.name(), checksum, Arc::new(images_router))?;
} }
// This is ModelType::Videos : registers /v1/videos/* endpoints
if card.model_type.supports_videos() { if card.model_type.supports_videos() {
let videos_router = PushRouter::< let videos_router = PushRouter::<
NvCreateVideoRequest, NvCreateVideoRequest,
...@@ -641,8 +633,7 @@ impl ModelWatcher { ...@@ -641,8 +633,7 @@ impl ModelWatcher {
client.clone(), self.router_config.router_mode, None, None client.clone(), self.router_config.router_mode, None, None
) )
.await?; .await?;
self.manager worker_set.videos_engine = Some(Arc::new(videos_router));
.add_videos_model(card.name(), checksum, Arc::new(videos_router))?;
} }
// TODO: add audio models support // TODO: add audio models support
...@@ -655,9 +646,7 @@ impl ModelWatcher { ...@@ -655,9 +646,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None client, self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); worker_set.chat_engine = Some(Arc::new(push_router));
self.manager
.add_chat_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() { } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case: Text + Completions // Case: Text + Completions
let push_router = PushRouter::< let push_router = PushRouter::<
...@@ -667,12 +656,9 @@ impl ModelWatcher { ...@@ -667,12 +656,9 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None client, self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); worker_set.completions_engine = Some(Arc::new(push_router));
self.manager
.add_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() { } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
// Case 4: Tokens + Embeddings // Case 4: Tokens + Embeddings
// Create preprocessing pipeline similar to Backend // Create preprocessing pipeline similar to Backend
let frontend = SegmentSource::< let frontend = SegmentSource::<
SingleIn<NvCreateEmbeddingRequest>, SingleIn<NvCreateEmbeddingRequest>,
...@@ -702,8 +688,7 @@ impl ModelWatcher { ...@@ -702,8 +688,7 @@ impl ModelWatcher {
.link(preprocessor.backward_edge())? .link(preprocessor.backward_edge())?
.link(frontend)?; .link(frontend)?;
self.manager worker_set.embeddings_engine = Some(embedding_engine);
.add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() { } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 6: Tensor + TensorBased (non-LLM) // Case 6: Tensor + TensorBased (non-LLM)
// No KV cache concepts - not an LLM model // No KV cache concepts - not an LLM model
...@@ -714,9 +699,7 @@ impl ModelWatcher { ...@@ -714,9 +699,7 @@ impl ModelWatcher {
client, self.router_config.router_mode, None, None client, self.router_config.router_mode, None, None
) )
.await?; .await?;
let engine = Arc::new(push_router); worker_set.tensor_engine = Some(Arc::new(push_router));
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else if card.model_type.supports_prefill() { } else if card.model_type.supports_prefill() {
// Case 6: Prefill // Case 6: Prefill
// Guardrail: Verify model_input is Tokens // Guardrail: Verify model_input is Tokens
...@@ -732,13 +715,18 @@ impl ModelWatcher { ...@@ -732,13 +715,18 @@ impl ModelWatcher {
"Prefill model detected, registering and activating prefill router" "Prefill model detected, registering and activating prefill router"
); );
// Register prefill model for tracking (no engine needed, just lifecycle) // Prefill sets have no engines — we add the WorkerSet first for tracking,
// then activate the prefill router.
self.manager self.manager
.add_prefill_model(card.name(), checksum) .add_worker_set(card.name(), &ws_key, worker_set)?;
.context("add_prefill_model")?;
// Activate the prefill router with the endpoint for this prefill model // Note: activate_prefill_router is keyed by deployment namespace (not ws_key)
let Ok(()) = self.manager.activate_prefill_router(card.name(), endpoint) else { // because it coordinates between decode and prefill WorkerSets that share
// the same deployment namespace but have different ws_keys ("ns" vs "ns:prefill").
let Ok(()) = self
.manager
.activate_prefill_router(card.name(), &namespace, endpoint)
else {
tracing::warn!( tracing::warn!(
model_name = card.name(), model_name = card.name(),
"Failed to activate prefill router - prefill model may already be activated" "Failed to activate prefill router - prefill model may already be activated"
...@@ -750,6 +738,8 @@ impl ModelWatcher { ...@@ -750,6 +738,8 @@ impl ModelWatcher {
model_name = card.name(), model_name = card.name(),
"Prefill model registered and router activated successfully" "Prefill model registered and router activated successfully"
); );
return Ok(());
} else { } else {
// Reject unsupported combinations // Reject unsupported combinations
anyhow::bail!( anyhow::bail!(
...@@ -760,6 +750,10 @@ impl ModelWatcher { ...@@ -760,6 +750,10 @@ impl ModelWatcher {
); );
} }
// Add the completed WorkerSet to the Model
self.manager
.add_worker_set(card.name(), &ws_key, worker_set)?;
Ok(()) Ok(())
} }
...@@ -772,7 +766,6 @@ impl ModelWatcher { ...@@ -772,7 +766,6 @@ impl ModelWatcher {
for instance in instances { for instance in instances {
match instance.deserialize_model::<ModelDeploymentCard>() { match instance.deserialize_model::<ModelDeploymentCard>() {
Ok(card) => { Ok(card) => {
// Extract EndpointId from the instance
let endpoint_id = match &instance { let endpoint_id = match &instance {
dynamo_runtime::discovery::DiscoveryInstance::Model { dynamo_runtime::discovery::DiscoveryInstance::Model {
namespace, namespace,
...@@ -805,19 +798,101 @@ impl ModelWatcher { ...@@ -805,19 +798,101 @@ impl ModelWatcher {
pub async fn cards_for_model( pub async fn cards_for_model(
&self, &self,
model_name: &str, model_name: &str,
target_namespace: Option<&str>, namespace_filter: &NamespaceFilter,
is_global_namespace: bool,
) -> anyhow::Result<Vec<ModelDeploymentCard>> { ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
Ok(self
.cards_for_model_with_endpoints(model_name, namespace_filter)
.await?
.into_iter()
.map(|(_, card)| card)
.collect())
}
/// Like `cards_for_model` but also returns the EndpointId for each card,
/// allowing callers to filter by namespace.
async fn cards_for_model_with_endpoints(
&self,
model_name: &str,
namespace_filter: &NamespaceFilter,
) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let mut all = self.all_cards().await?; let mut all = self.all_cards().await?;
all.retain(|(endpoint_id, card)| { all.retain(|(endpoint_id, card)| {
let matches_name = card.name() == model_name; let matches_name = card.name() == model_name;
let matches_namespace = match (is_global_namespace, target_namespace) { let matches_namespace = namespace_filter.matches(&endpoint_id.namespace);
(true, _) => true,
(false, None) => true,
(false, Some(target_ns)) => endpoint_id.namespace == target_ns,
};
matches_name && matches_namespace matches_name && matches_namespace
}); });
Ok(all.into_iter().map(|(_eid, card)| card).collect()) Ok(all)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::discovery::WorkerSet;
use crate::model_card::ModelDeploymentCard;
fn make_worker_set(namespace: &str) -> WorkerSet {
WorkerSet::new(
namespace.to_string(),
"test-checksum".to_string(),
ModelDeploymentCard::default(),
)
}
#[test]
fn test_is_model_type_list_empty_on_empty_manager() {
let mm = ModelManager::new();
assert!(is_model_type_list_empty(&mm, ModelType::Chat));
assert!(is_model_type_list_empty(&mm, ModelType::Completions));
assert!(is_model_type_list_empty(&mm, ModelType::Embedding));
assert!(is_model_type_list_empty(&mm, ModelType::Images));
assert!(is_model_type_list_empty(&mm, ModelType::Videos));
assert!(is_model_type_list_empty(&mm, ModelType::TensorBased));
assert!(is_model_type_list_empty(&mm, ModelType::Prefill));
}
#[test]
fn test_is_model_type_list_empty_prefill_present() {
let mm = ModelManager::new();
// A WorkerSet with no engines is treated as a prefill set
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"))
.unwrap();
assert!(!is_model_type_list_empty(&mm, ModelType::Prefill));
// Other types should still be empty since the WorkerSet has no engines
assert!(is_model_type_list_empty(&mm, ModelType::Chat));
assert!(is_model_type_list_empty(&mm, ModelType::Completions));
assert!(is_model_type_list_empty(&mm, ModelType::Embedding));
assert!(is_model_type_list_empty(&mm, ModelType::Images));
assert!(is_model_type_list_empty(&mm, ModelType::Videos));
assert!(is_model_type_list_empty(&mm, ModelType::TensorBased));
}
#[test]
fn test_is_model_type_list_empty_after_removal() {
let mm = ModelManager::new();
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"))
.unwrap();
assert!(!is_model_type_list_empty(&mm, ModelType::Prefill));
mm.remove_model("model-a");
assert!(is_model_type_list_empty(&mm, ModelType::Prefill));
}
#[test]
fn test_is_model_type_list_not_empty_when_other_model_remains() {
let mm = ModelManager::new();
mm.add_worker_set("model-a", "ns1", make_worker_set("ns1"))
.unwrap();
mm.add_worker_set("model-b", "ns1", make_worker_set("ns1"))
.unwrap();
// Remove one model — other still provides prefill
mm.remove_model("model-a");
assert!(!is_model_type_list_empty(&mm, ModelType::Prefill));
// Remove the last model — now empty
mm.remove_model("model-b");
assert!(is_model_type_list_empty(&mm, ModelType::Prefill));
} }
} }
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A WorkerSet represents a group of workers deployed from the same configuration,
//! identified by their shared namespace. Each WorkerSet owns a complete pipeline
//! (engines, KV router, prefill router) built from its specific ModelDeploymentCard.
use std::sync::Arc;
use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// A set of workers from the same namespace/configuration with their own pipeline.
pub struct WorkerSet {
/// Full namespace (e.g., "ns-abc12345")
namespace: String,
/// MDC checksum for this set's configuration
mdcsum: String,
/// The model deployment card used to build this set's pipeline
card: ModelDeploymentCard,
// Engines — each WorkerSet owns its own pipelines
pub(crate) chat_engine: Option<OpenAIChatCompletionsStreamingEngine>,
pub(crate) completions_engine: Option<OpenAICompletionsStreamingEngine>,
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// KV router for this set's workers (if KV mode)
pub(crate) kv_router: Option<Arc<KvRouter>>,
/// Worker monitor for load-based rejection
pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Watcher for available instance IDs (from the Client's discovery watch).
/// None for in-process models (http/grpc) which don't have a discovery client.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
}
impl WorkerSet {
pub fn new(namespace: String, mdcsum: String, card: ModelDeploymentCard) -> Self {
Self {
namespace,
mdcsum,
card,
chat_engine: None,
completions_engine: None,
embeddings_engine: None,
images_engine: None,
videos_engine: None,
tensor_engine: None,
kv_router: None,
worker_monitor: None,
instance_count_rx: None,
}
}
pub fn namespace(&self) -> &str {
&self.namespace
}
pub fn mdcsum(&self) -> &str {
&self.mdcsum
}
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
}
pub fn has_chat_engine(&self) -> bool {
self.chat_engine.is_some()
}
pub fn has_completions_engine(&self) -> bool {
self.completions_engine.is_some()
}
pub fn has_embeddings_engine(&self) -> bool {
self.embeddings_engine.is_some()
}
pub fn has_images_engine(&self) -> bool {
self.images_engine.is_some()
}
pub fn has_videos_engine(&self) -> bool {
self.videos_engine.is_some()
}
pub fn has_tensor_engine(&self) -> bool {
self.tensor_engine.is_some()
}
/// Whether this set has any decode engine (chat or completions)
pub fn has_decode_engine(&self) -> bool {
self.has_chat_engine() || self.has_completions_engine()
}
/// Whether this set tracks a prefill model (no engine, just lifecycle)
pub fn is_prefill_set(&self) -> bool {
!self.has_decode_engine()
&& !self.has_embeddings_engine()
&& !self.has_images_engine()
&& !self.has_videos_engine()
&& !self.has_tensor_engine()
}
/// Build ParsingOptions from this WorkerSet's card configuration.
pub fn parsing_options(&self) -> crate::protocols::openai::ParsingOptions {
crate::protocols::openai::ParsingOptions::new(
self.card.runtime_config.tool_call_parser.clone(),
self.card.runtime_config.reasoning_parser.clone(),
)
}
/// Number of active workers in this set, derived from the Client's discovery watcher.
/// Returns 1 for in-process models (no watcher) since they always have one local worker.
pub fn worker_count(&self) -> usize {
match &self.instance_count_rx {
Some(rx) => rx.borrow().len(),
None => 1,
}
}
/// Store the instance watcher from the Client's discovery system.
/// Must be called before the WorkerSet is wrapped in Arc.
pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
self.instance_count_rx = Some(rx);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model_card::ModelDeploymentCard;
fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
WorkerSet::new(
namespace.to_string(),
mdcsum.to_string(),
ModelDeploymentCard::default(),
)
}
#[test]
fn test_worker_set_basics() {
let ws = make_worker_set("ns1", "abc123");
assert_eq!(ws.namespace(), "ns1");
assert_eq!(ws.mdcsum(), "abc123");
}
#[test]
fn test_no_engines_by_default() {
let ws = make_worker_set("ns1", "abc123");
assert!(!ws.has_chat_engine());
assert!(!ws.has_completions_engine());
assert!(!ws.has_embeddings_engine());
assert!(!ws.has_images_engine());
assert!(!ws.has_tensor_engine());
assert!(!ws.has_decode_engine());
assert!(ws.is_prefill_set());
}
#[test]
fn test_worker_count_without_watcher() {
// In-process models have no discovery watcher; worker_count defaults to 1
let ws = make_worker_set("ns1", "abc");
assert_eq!(ws.worker_count(), 1);
}
#[test]
fn test_worker_count_with_watcher() {
let mut ws = make_worker_set("ns1", "abc");
// Simulate a discovery watcher with 3 workers
let (tx, rx) = watch::channel(vec![1, 2, 3]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 3);
// Workers leave → count drops
tx.send(vec![1]).unwrap();
assert_eq!(ws.worker_count(), 1);
// All workers gone → count is 0
tx.send(vec![]).unwrap();
assert_eq!(ws.worker_count(), 0);
}
#[test]
fn test_worker_count_with_empty_watcher() {
// Discovery watcher starts empty (no workers have joined yet)
let mut ws = make_worker_set("ns1", "abc");
let (_tx, rx) = watch::channel::<Vec<u64>>(vec![]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 0);
}
#[test]
fn test_worker_count_updates_on_join() {
let mut ws = make_worker_set("ns1", "abc");
let (tx, rx) = watch::channel::<Vec<u64>>(vec![]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 0);
// Workers join one by one
tx.send(vec![100]).unwrap();
assert_eq!(ws.worker_count(), 1);
tx.send(vec![100, 200]).unwrap();
assert_eq!(ws.worker_count(), 2);
tx.send(vec![100, 200, 300]).unwrap();
assert_eq!(ws.worker_count(), 3);
}
}
...@@ -12,6 +12,7 @@ use crate::{ ...@@ -12,6 +12,7 @@ use crate::{
kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter}, kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
namespace::NamespaceFilter,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
...@@ -82,8 +83,14 @@ pub async fn prepare_engine( ...@@ -82,8 +83,14 @@ pub async fn prepare_engine(
) )
.await?; .await?;
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
local_model.namespace(),
local_model.namespace_prefix(),
);
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(discovery_stream, None).await; inner_watch_obj
.watch(discovery_stream, namespace_filter)
.await;
}); });
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
entrypoint::{EngineConfig, RouterConfig, input::common}, entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
http::service::metrics::Metrics, http::service::metrics::Metrics,
namespace::is_global_namespace, namespace::NamespaceFilter,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -38,18 +38,16 @@ pub async fn run( ...@@ -38,18 +38,16 @@ pub async fn run(
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or(""); let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
let target_namespace = if is_global_namespace(namespace) { model.namespace(),
None model.namespace_prefix(),
} else { );
Some(namespace.to_string())
};
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
target_namespace, namespace_filter,
) )
.await?; .await?;
grpc_service grpc_service
...@@ -113,7 +111,7 @@ async fn run_watcher( ...@@ -113,7 +111,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
target_namespace: Option<String>, namespace_filter: NamespaceFilter,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode) // Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new()); let metrics = Arc::new(Metrics::new());
...@@ -140,9 +138,7 @@ async fn run_watcher( ...@@ -140,9 +138,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj watch_obj.watch(discovery_stream, namespace_filter).await;
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common}, entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
namespace::is_global_namespace, namespace::NamespaceFilter,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -66,20 +66,17 @@ pub async fn run( ...@@ -66,20 +66,17 @@ pub async fn run(
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace) // Create namespace filter from model configuration
// Get namespace from the model, fallback to endpoint_id namespace if not set let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
let namespace = model.namespace().unwrap_or(""); model.namespace(),
let target_namespace = if is_global_namespace(namespace) { model.namespace_prefix(),
None );
} else {
Some(namespace.to_string())
};
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
http_service.state().manager_clone(), http_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
target_namespace, namespace_filter,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
chat_engine_factory.clone(), chat_engine_factory.clone(),
...@@ -157,7 +154,7 @@ async fn run_watcher( ...@@ -157,7 +154,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
target_namespace: Option<String>, namespace_filter: NamespaceFilter,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
...@@ -193,9 +190,7 @@ async fn run_watcher( ...@@ -193,9 +190,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj watch_obj.watch(discovery_stream, namespace_filter).await;
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService { ...@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService {
} }
} }
let model = completion_request.inner.model.clone(); let (stream, parsing_options) =
let parsing_options = self.state.manager.get_parsing_options(&model); completion_response_stream(self.state_clone(), completion_request).await?;
let stream = completion_response_stream(self.state_clone(), completion_request).await?;
let completion_response = let completion_response =
NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options) NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
...@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService { ...@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService {
} }
} }
let model = completion_request.inner.model.clone();
let parsing_options = state.manager.get_parsing_options(&model);
let streaming = completion_request.inner.stream.unwrap_or(false); let streaming = completion_request.inner.stream.unwrap_or(false);
let stream = completion_response_stream(state.clone(), completion_request).await?; let (stream, parsing_options) = completion_response_stream(state.clone(), completion_request).await?;
if streaming { if streaming {
pin_mut!(stream); pin_mut!(stream);
......
...@@ -9,6 +9,7 @@ use dynamo_runtime::{ ...@@ -9,6 +9,7 @@ use dynamo_runtime::{
use futures::{Stream, StreamExt, stream}; use futures::{Stream, StreamExt, stream};
use std::sync::Arc; use std::sync::Arc;
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::completions::{ use crate::protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse, NvCreateCompletionRequest, NvCreateCompletionResponse,
}; };
...@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id"; ...@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";
pub async fn completion_response_stream( pub async fn completion_response_stream(
state: Arc<kserve::State>, state: Arc<kserve::State>,
request: NvCreateCompletionRequest, request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> { ) -> Result<
(
impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
ParsingOptions,
),
Status,
> {
// create the context for the request // create the context for the request
// [WIP] from request id. // [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref()); let request_id = get_or_create_request_id(request.inner.user.as_deref());
...@@ -66,9 +73,9 @@ pub async fn completion_response_stream( ...@@ -66,9 +73,9 @@ pub async fn completion_response_stream(
let model = &request.inner.model; let model = &request.inner.model;
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine(model) .get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?; .map_err(|_| Status::not_found("model not found"))?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
...@@ -130,7 +137,7 @@ pub async fn completion_response_stream( ...@@ -130,7 +137,7 @@ pub async fn completion_response_stream(
// without need to be cancelled. // without need to be cancelled.
connection_handle.disarm(); connection_handle.disarm();
Ok(stream) Ok((stream, parsing_options))
} }
/// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation. /// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation.
......
...@@ -194,9 +194,9 @@ async fn anthropic_messages( ...@@ -194,9 +194,9 @@ async fn anthropic_messages(
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|_| {
anthropic_error( anthropic_error(
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
...@@ -205,8 +205,6 @@ async fn anthropic_messages( ...@@ -205,8 +205,6 @@ async fn anthropic_messages(
) )
})?; })?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for Anthropic messages"); tracing::trace!("Issuing generate call for Anthropic messages");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment