Unverified Commit cb55766c authored by Thomas Montfort's avatar Thomas Montfort Committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: default avatartmontfort <tmontfort@nvidia.com>
parent 8dd6369e
......@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.namespace import get_worker_namespace
from dynamo.common.utils.output_modalities import OutputModality
......@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase):
media_output_http_url: Optional[str] = None
def validate(self) -> None:
self.namespace = get_worker_namespace(self.namespace)
# TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events
self._validate_output_modalities()
......@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup):
flag_name="--namespace",
env_var="DYN_NAMESPACE",
default="dynamo",
help="Dynamo namespace",
help="Dynamo namespace. If DYN_NAMESPACE_WORKER_SUFFIX is set, "
"'-{suffix}' is appended to support multiple worker pools",
)
add_argument(
g,
......
......@@ -17,6 +17,7 @@ Submodules:
from dynamo.common.utils import (
endpoint_types,
engine_response,
namespace,
otel_tracing,
paths,
prometheus,
......@@ -26,6 +27,7 @@ from dynamo.common.utils import (
__all__ = [
"endpoint_types",
"engine_response",
"namespace",
"otel_tracing",
"paths",
"prometheus",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
def get_worker_namespace(namespace: Optional[str] = None) -> str:
"""Get the Dynamo namespace for a worker.
Uses the provided namespace, or falls back to the DYN_NAMESPACE environment
variable (defaulting to "dynamo"). If DYN_NAMESPACE_WORKER_SUFFIX is set,
it is appended as "{namespace}-{suffix}" to support multiple sets of workers
for the same model.
"""
if not namespace:
namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
suffix = os.environ.get("DYN_NAMESPACE_WORKER_SUFFIX")
if suffix:
namespace = f"{namespace}-{suffix}"
return namespace
......@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase):
router_max_tree_size: int
router_prune_target_ratio: float
namespace: Optional[str] = None
namespace_prefix: Optional[str] = None
router_replica_sync: bool
router_snapshot_threshold: int
router_reset_states: bool
......@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_NAMESPACE",
default=None,
help=(
"Dynamo namespace for model discovery scoping. If specified, models will "
"only be discovered from this namespace. If not specified, discovers models "
"from all namespaces (global discovery)."
"Dynamo namespace for model discovery scoping. Use for exact namespace matching. "
"If --namespace-prefix is also specified, prefix takes precedence."
),
)
......@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup):
arg_type=float,
)
add_argument(
g,
flag_name="--namespace-prefix",
env_var="DYN_NAMESPACE_PREFIX",
default=None,
help=(
"Dynamo namespace prefix for model discovery scoping. Discovers models from "
"namespaces starting with this prefix (e.g., 'ns' matches 'ns', 'ns-abc123', "
"'ns-def456'). Takes precedence over --namespace if both are specified."
),
)
add_negatable_bool_argument(
g,
flag_name="--router-replica-sync",
......
......@@ -230,6 +230,8 @@ async def async_main():
kwargs["tls_key_path"] = config.tls_key_path
if config.namespace:
kwargs["namespace"] = config.namespace
if config.namespace_prefix:
kwargs["namespace_prefix"] = config.namespace_prefix
if config.kserve_grpc_server and config.grpc_metrics_port:
kwargs["http_metrics_port"] = config.grpc_metrics_port
......
......@@ -8,6 +8,8 @@ import os
import tempfile
from pathlib import Path
from dynamo.common.utils.namespace import get_worker_namespace
from . import __version__
from .utils.planner_profiler_perf_data_converter import (
convert_profile_results_to_npz,
......@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import (
is_profile_results_dir,
)
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DYN_NAMESPACE = get_worker_namespace()
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
......
......@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers.
"""
import os
from dataclasses import dataclass
from typing import Optional
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
from dynamo.common.utils.namespace import get_worker_namespace
DYN_NAMESPACE = get_worker_namespace()
# Default model paths
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
......
......@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config),
enforce_disagg,
model_name.clone(),
namespace_str.clone(),
)
}
None if enforce_disagg => {
......
......@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs {
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>,
......@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new(
py: Python<'_>,
engine_type: EngineType,
......@@ -209,6 +210,7 @@ impl EntrypointArgs {
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool,
migration_limit: u32,
chat_engine_factory: Option<PyObject>,
......@@ -254,6 +256,7 @@ impl EntrypointArgs {
tls_key_path,
extra_engine_args,
namespace,
namespace_prefix,
is_prefill,
migration_limit,
chat_engine_factory,
......@@ -296,7 +299,8 @@ pub fn make_engine<'p>(
.tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone())
.namespace(args.namespace.clone());
.namespace(args.namespace.clone())
.namespace_prefix(args.namespace_prefix.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
if let Some(model_path) = args.model_path.clone() {
let local_path = if model_path.exists() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
mod model;
pub use model::Model;
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
mod worker_set;
pub use worker_set::WorkerSet;
pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A WorkerSet represents a group of workers deployed from the same configuration,
//! identified by their shared namespace. Each WorkerSet owns a complete pipeline
//! (engines, KV router, prefill router) built from its specific ModelDeploymentCard.
use std::sync::Arc;
use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// A set of workers from the same namespace/configuration with their own pipeline.
pub struct WorkerSet {
/// Full namespace (e.g., "ns-abc12345")
namespace: String,
/// MDC checksum for this set's configuration
mdcsum: String,
/// The model deployment card used to build this set's pipeline
card: ModelDeploymentCard,
// Engines — each WorkerSet owns its own pipelines
pub(crate) chat_engine: Option<OpenAIChatCompletionsStreamingEngine>,
pub(crate) completions_engine: Option<OpenAICompletionsStreamingEngine>,
pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
pub(crate) tensor_engine: Option<TensorStreamingEngine>,
/// KV router for this set's workers (if KV mode)
pub(crate) kv_router: Option<Arc<KvRouter>>,
/// Worker monitor for load-based rejection
pub(crate) worker_monitor: Option<KvWorkerMonitor>,
/// Watcher for available instance IDs (from the Client's discovery watch).
/// None for in-process models (http/grpc) which don't have a discovery client.
instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
}
impl WorkerSet {
pub fn new(namespace: String, mdcsum: String, card: ModelDeploymentCard) -> Self {
Self {
namespace,
mdcsum,
card,
chat_engine: None,
completions_engine: None,
embeddings_engine: None,
images_engine: None,
videos_engine: None,
tensor_engine: None,
kv_router: None,
worker_monitor: None,
instance_count_rx: None,
}
}
pub fn namespace(&self) -> &str {
&self.namespace
}
pub fn mdcsum(&self) -> &str {
&self.mdcsum
}
pub fn card(&self) -> &ModelDeploymentCard {
&self.card
}
pub fn has_chat_engine(&self) -> bool {
self.chat_engine.is_some()
}
pub fn has_completions_engine(&self) -> bool {
self.completions_engine.is_some()
}
pub fn has_embeddings_engine(&self) -> bool {
self.embeddings_engine.is_some()
}
pub fn has_images_engine(&self) -> bool {
self.images_engine.is_some()
}
pub fn has_videos_engine(&self) -> bool {
self.videos_engine.is_some()
}
pub fn has_tensor_engine(&self) -> bool {
self.tensor_engine.is_some()
}
/// Whether this set has any decode engine (chat or completions)
pub fn has_decode_engine(&self) -> bool {
self.has_chat_engine() || self.has_completions_engine()
}
/// Whether this set tracks a prefill model (no engine, just lifecycle)
pub fn is_prefill_set(&self) -> bool {
!self.has_decode_engine()
&& !self.has_embeddings_engine()
&& !self.has_images_engine()
&& !self.has_videos_engine()
&& !self.has_tensor_engine()
}
/// Build ParsingOptions from this WorkerSet's card configuration.
pub fn parsing_options(&self) -> crate::protocols::openai::ParsingOptions {
crate::protocols::openai::ParsingOptions::new(
self.card.runtime_config.tool_call_parser.clone(),
self.card.runtime_config.reasoning_parser.clone(),
)
}
/// Number of active workers in this set, derived from the Client's discovery watcher.
/// Returns 1 for in-process models (no watcher) since they always have one local worker.
pub fn worker_count(&self) -> usize {
match &self.instance_count_rx {
Some(rx) => rx.borrow().len(),
None => 1,
}
}
/// Store the instance watcher from the Client's discovery system.
/// Must be called before the WorkerSet is wrapped in Arc.
pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
self.instance_count_rx = Some(rx);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::model_card::ModelDeploymentCard;
fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
WorkerSet::new(
namespace.to_string(),
mdcsum.to_string(),
ModelDeploymentCard::default(),
)
}
#[test]
fn test_worker_set_basics() {
let ws = make_worker_set("ns1", "abc123");
assert_eq!(ws.namespace(), "ns1");
assert_eq!(ws.mdcsum(), "abc123");
}
#[test]
fn test_no_engines_by_default() {
let ws = make_worker_set("ns1", "abc123");
assert!(!ws.has_chat_engine());
assert!(!ws.has_completions_engine());
assert!(!ws.has_embeddings_engine());
assert!(!ws.has_images_engine());
assert!(!ws.has_tensor_engine());
assert!(!ws.has_decode_engine());
assert!(ws.is_prefill_set());
}
#[test]
fn test_worker_count_without_watcher() {
// In-process models have no discovery watcher; worker_count defaults to 1
let ws = make_worker_set("ns1", "abc");
assert_eq!(ws.worker_count(), 1);
}
#[test]
fn test_worker_count_with_watcher() {
let mut ws = make_worker_set("ns1", "abc");
// Simulate a discovery watcher with 3 workers
let (tx, rx) = watch::channel(vec![1, 2, 3]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 3);
// Workers leave → count drops
tx.send(vec![1]).unwrap();
assert_eq!(ws.worker_count(), 1);
// All workers gone → count is 0
tx.send(vec![]).unwrap();
assert_eq!(ws.worker_count(), 0);
}
#[test]
fn test_worker_count_with_empty_watcher() {
// Discovery watcher starts empty (no workers have joined yet)
let mut ws = make_worker_set("ns1", "abc");
let (_tx, rx) = watch::channel::<Vec<u64>>(vec![]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 0);
}
#[test]
fn test_worker_count_updates_on_join() {
let mut ws = make_worker_set("ns1", "abc");
let (tx, rx) = watch::channel::<Vec<u64>>(vec![]);
ws.set_instance_watcher(rx);
assert_eq!(ws.worker_count(), 0);
// Workers join one by one
tx.send(vec![100]).unwrap();
assert_eq!(ws.worker_count(), 1);
tx.send(vec![100, 200]).unwrap();
assert_eq!(ws.worker_count(), 2);
tx.send(vec![100, 200, 300]).unwrap();
assert_eq!(ws.worker_count(), 3);
}
}
......@@ -12,6 +12,7 @@ use crate::{
kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
migration::Migration,
model_card::ModelDeploymentCard,
namespace::NamespaceFilter,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
......@@ -82,8 +83,14 @@ pub async fn prepare_engine(
)
.await?;
let inner_watch_obj = watch_obj.clone();
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
local_model.namespace(),
local_model.namespace_prefix(),
);
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(discovery_stream, None).await;
inner_watch_obj
.watch(discovery_stream, namespace_filter)
.await;
});
tracing::info!("Waiting for remote model..");
......
......@@ -9,7 +9,7 @@ use crate::{
entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve,
http::service::metrics::Metrics,
namespace::is_global_namespace,
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -38,18 +38,16 @@ pub async fn run(
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(),
model.namespace_prefix(),
);
run_watcher(
distributed_runtime.clone(),
grpc_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
namespace_filter,
)
.await?;
grpc_service
......@@ -113,7 +111,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
namespace_filter: NamespaceFilter,
) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new());
......@@ -140,9 +138,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
watch_obj.watch(discovery_stream, namespace_filter).await;
});
Ok(())
......
......@@ -9,7 +9,7 @@ use crate::{
engines::StreamingEngineAdapter,
entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService},
namespace::is_global_namespace,
namespace::NamespaceFilter,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -66,20 +66,17 @@ pub async fn run(
let router_config = model.router_config();
let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = model.namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
// Create namespace filter from model configuration
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
model.namespace(),
model.namespace_prefix(),
);
run_watcher(
distributed_runtime.clone(),
http_service.state().manager_clone(),
router_config.clone(),
migration_limit,
target_namespace,
namespace_filter,
Arc::new(http_service.clone()),
http_service.state().metrics_clone(),
chat_engine_factory.clone(),
......@@ -157,7 +154,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>,
router_config: RouterConfig,
migration_limit: u32,
target_namespace: Option<String>,
namespace_filter: NamespaceFilter,
http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>,
......@@ -193,9 +190,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj
.watch(discovery_stream, target_namespace.as_deref())
.await;
watch_obj.watch(discovery_stream, namespace_filter).await;
});
Ok(())
......
......@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService {
}
}
let model = completion_request.inner.model.clone();
let parsing_options = self.state.manager.get_parsing_options(&model);
let stream = completion_response_stream(self.state_clone(), completion_request).await?;
let (stream, parsing_options) =
completion_response_stream(self.state_clone(), completion_request).await?;
let completion_response =
NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
......@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService {
}
}
let model = completion_request.inner.model.clone();
let parsing_options = state.manager.get_parsing_options(&model);
let streaming = completion_request.inner.stream.unwrap_or(false);
let stream = completion_response_stream(state.clone(), completion_request).await?;
let (stream, parsing_options) = completion_response_stream(state.clone(), completion_request).await?;
if streaming {
pin_mut!(stream);
......
......@@ -9,6 +9,7 @@ use dynamo_runtime::{
use futures::{Stream, StreamExt, stream};
use std::sync::Arc;
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse,
};
......@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";
pub async fn completion_response_stream(
state: Arc<kserve::State>,
request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> {
) -> Result<
(
impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
ParsingOptions,
),
Status,
> {
// create the context for the request
// [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref());
......@@ -66,9 +73,9 @@ pub async fn completion_response_stream(
let model = &request.inner.model;
// todo - error handling should be more robust
let engine = state
let (engine, parsing_options) = state
.manager()
.get_completions_engine(model)
.get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
......@@ -130,7 +137,7 @@ pub async fn completion_response_stream(
// without need to be cancelled.
connection_handle.disarm();
Ok(stream)
Ok((stream, parsing_options))
}
/// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation.
......
......@@ -194,9 +194,9 @@ async fn anthropic_messages(
tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state
let (engine, parsing_options) = state
.manager()
.get_chat_completions_engine(&model)
.get_chat_completions_engine_with_parsing(&model)
.map_err(|_| {
anthropic_error(
StatusCode::NOT_FOUND,
......@@ -205,8 +205,6 @@ async fn anthropic_messages(
)
})?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for Anthropic messages");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment