Unverified commit cb55766c, authored by Thomas Montfort and committed by GitHub
Browse files

feat(runtime): add hierarchical Model/WorkerSet architecture for multi-namespace support (#6054)


Signed-off-by: tmontfort <tmontfort@nvidia.com>
parent 8dd6369e
...@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names ...@@ -9,6 +9,7 @@ from dynamo._core import get_reasoning_parser_names, get_tool_parser_names
from dynamo.common.configuration.arg_group import ArgGroup from dynamo.common.configuration.arg_group import ArgGroup
from dynamo.common.configuration.config_base import ConfigBase from dynamo.common.configuration.config_base import ConfigBase
from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument from dynamo.common.configuration.utils import add_argument, add_negatable_bool_argument
from dynamo.common.utils.namespace import get_worker_namespace
from dynamo.common.utils.output_modalities import OutputModality from dynamo.common.utils.output_modalities import OutputModality
...@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase): ...@@ -35,6 +36,8 @@ class DynamoRuntimeConfig(ConfigBase):
media_output_http_url: Optional[str] = None media_output_http_url: Optional[str] = None
def validate(self) -> None: def validate(self) -> None:
self.namespace = get_worker_namespace(self.namespace)
# TODO get a better way for spot fixes like this. # TODO get a better way for spot fixes like this.
self.enable_local_indexer = not self.durable_kv_events self.enable_local_indexer = not self.durable_kv_events
self._validate_output_modalities() self._validate_output_modalities()
...@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup): ...@@ -69,7 +72,8 @@ class DynamoRuntimeArgGroup(ArgGroup):
flag_name="--namespace", flag_name="--namespace",
env_var="DYN_NAMESPACE", env_var="DYN_NAMESPACE",
default="dynamo", default="dynamo",
help="Dynamo namespace", help="Dynamo namespace. If DYN_NAMESPACE_WORKER_SUFFIX is set, "
"'-{suffix}' is appended to support multiple worker pools",
) )
add_argument( add_argument(
g, g,
......
...@@ -17,6 +17,7 @@ Submodules: ...@@ -17,6 +17,7 @@ Submodules:
from dynamo.common.utils import ( from dynamo.common.utils import (
endpoint_types, endpoint_types,
engine_response, engine_response,
namespace,
otel_tracing, otel_tracing,
paths, paths,
prometheus, prometheus,
...@@ -26,6 +27,7 @@ from dynamo.common.utils import ( ...@@ -26,6 +27,7 @@ from dynamo.common.utils import (
__all__ = [ __all__ = [
"endpoint_types", "endpoint_types",
"engine_response", "engine_response",
"namespace",
"otel_tracing", "otel_tracing",
"paths", "paths",
"prometheus", "prometheus",
......
# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
import os
from typing import Optional
def get_worker_namespace(namespace: Optional[str] = None) -> str:
    """Resolve the Dynamo namespace a worker should use.

    The base namespace is chosen in this order:

    1. ``namespace``, when it is a non-empty string.
    2. The ``DYN_NAMESPACE`` environment variable, when set and non-empty.
    3. The literal ``"dynamo"``.

    If ``DYN_NAMESPACE_WORKER_SUFFIX`` is set and non-empty, the result is
    ``"{namespace}-{suffix}"`` so that several sets of workers can serve the
    same model from distinct namespaces.

    Note:
        The suffix is appended unconditionally. Passing an already-suffixed
        namespace back into this function appends the suffix a second time.

    Args:
        namespace: Explicit base namespace. ``None`` or ``""`` means "fall
            back to the environment".

    Returns:
        The resolved namespace, with the worker suffix applied when configured.
    """
    # Chain with `or` instead of using a `.get()` default: an exported-but-empty
    # DYN_NAMESPACE must still fall back to "dynamo" rather than yield "".
    base = namespace or os.environ.get("DYN_NAMESPACE") or "dynamo"
    suffix = os.environ.get("DYN_NAMESPACE_WORKER_SUFFIX")
    return f"{base}-{suffix}" if suffix else base
...@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase): ...@@ -54,6 +54,7 @@ class FrontendConfig(ConfigBase):
router_max_tree_size: int router_max_tree_size: int
router_prune_target_ratio: float router_prune_target_ratio: float
namespace: Optional[str] = None namespace: Optional[str] = None
namespace_prefix: Optional[str] = None
router_replica_sync: bool router_replica_sync: bool
router_snapshot_threshold: int router_snapshot_threshold: int
router_reset_states: bool router_reset_states: bool
...@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup): ...@@ -128,9 +129,8 @@ class FrontendArgGroup(ArgGroup):
env_var="DYN_NAMESPACE", env_var="DYN_NAMESPACE",
default=None, default=None,
help=( help=(
"Dynamo namespace for model discovery scoping. If specified, models will " "Dynamo namespace for model discovery scoping. Use for exact namespace matching. "
"only be discovered from this namespace. If not specified, discovers models " "If --namespace-prefix is also specified, prefix takes precedence."
"from all namespaces (global discovery)."
), ),
) )
...@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup): ...@@ -256,6 +256,18 @@ class FrontendArgGroup(ArgGroup):
arg_type=float, arg_type=float,
) )
add_argument(
g,
flag_name="--namespace-prefix",
env_var="DYN_NAMESPACE_PREFIX",
default=None,
help=(
"Dynamo namespace prefix for model discovery scoping. Discovers models from "
"namespaces starting with this prefix (e.g., 'ns' matches 'ns', 'ns-abc123', "
"'ns-def456'). Takes precedence over --namespace if both are specified."
),
)
add_negatable_bool_argument( add_negatable_bool_argument(
g, g,
flag_name="--router-replica-sync", flag_name="--router-replica-sync",
......
...@@ -230,6 +230,8 @@ async def async_main(): ...@@ -230,6 +230,8 @@ async def async_main():
kwargs["tls_key_path"] = config.tls_key_path kwargs["tls_key_path"] = config.tls_key_path
if config.namespace: if config.namespace:
kwargs["namespace"] = config.namespace kwargs["namespace"] = config.namespace
if config.namespace_prefix:
kwargs["namespace_prefix"] = config.namespace_prefix
if config.kserve_grpc_server and config.grpc_metrics_port: if config.kserve_grpc_server and config.grpc_metrics_port:
kwargs["http_metrics_port"] = config.grpc_metrics_port kwargs["http_metrics_port"] = config.grpc_metrics_port
......
...@@ -8,6 +8,8 @@ import os ...@@ -8,6 +8,8 @@ import os
import tempfile import tempfile
from pathlib import Path from pathlib import Path
from dynamo.common.utils.namespace import get_worker_namespace
from . import __version__ from . import __version__
from .utils.planner_profiler_perf_data_converter import ( from .utils.planner_profiler_perf_data_converter import (
convert_profile_results_to_npz, convert_profile_results_to_npz,
...@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import ( ...@@ -15,7 +17,7 @@ from .utils.planner_profiler_perf_data_converter import (
is_profile_results_dir, is_profile_results_dir,
) )
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") DYN_NAMESPACE = get_worker_namespace()
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate" DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate" DEFAULT_PREFILL_ENDPOINT = f"dyn://{DYN_NAMESPACE}.prefill.generate"
......
...@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring ...@@ -7,11 +7,12 @@ This module defines the DiffusionConfig dataclass used for configuring
video and image diffusion workers. video and image diffusion workers.
""" """
import os
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional from typing import Optional
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo") from dynamo.common.utils.namespace import get_worker_namespace
DYN_NAMESPACE = get_worker_namespace()
# Default model paths # Default model paths
DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers" DEFAULT_VIDEO_MODEL_PATH = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
......
...@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers( ...@@ -704,6 +704,7 @@ pub unsafe extern "C" fn create_routers(
Some(prefill_config), Some(prefill_config),
enforce_disagg, enforce_disagg,
model_name.clone(), model_name.clone(),
namespace_str.clone(),
) )
} }
None if enforce_disagg => { None if enforce_disagg => {
......
...@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs { ...@@ -182,6 +182,7 @@ pub(crate) struct EntrypointArgs {
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyEngineFactory>, chat_engine_factory: Option<PyEngineFactory>,
...@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs { ...@@ -191,7 +192,7 @@ pub(crate) struct EntrypointArgs {
impl EntrypointArgs { impl EntrypointArgs {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[new] #[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))] #[pyo3(signature = (engine_type, model_path=None, model_name=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, http_metrics_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None, namespace_prefix=None, is_prefill=false, migration_limit=0, chat_engine_factory=None))]
pub fn new( pub fn new(
py: Python<'_>, py: Python<'_>,
engine_type: EngineType, engine_type: EngineType,
...@@ -209,6 +210,7 @@ impl EntrypointArgs { ...@@ -209,6 +210,7 @@ impl EntrypointArgs {
tls_key_path: Option<PathBuf>, tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>, extra_engine_args: Option<PathBuf>,
namespace: Option<String>, namespace: Option<String>,
namespace_prefix: Option<String>,
is_prefill: bool, is_prefill: bool,
migration_limit: u32, migration_limit: u32,
chat_engine_factory: Option<PyObject>, chat_engine_factory: Option<PyObject>,
...@@ -254,6 +256,7 @@ impl EntrypointArgs { ...@@ -254,6 +256,7 @@ impl EntrypointArgs {
tls_key_path, tls_key_path,
extra_engine_args, extra_engine_args,
namespace, namespace,
namespace_prefix,
is_prefill, is_prefill,
migration_limit, migration_limit,
chat_engine_factory, chat_engine_factory,
...@@ -296,7 +299,8 @@ pub fn make_engine<'p>( ...@@ -296,7 +299,8 @@ pub fn make_engine<'p>(
.tls_key_path(args.tls_key_path.clone()) .tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker)) .is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone()) .extra_engine_args(args.extra_engine_args.clone())
.namespace(args.namespace.clone()); .namespace(args.namespace.clone())
.namespace_prefix(args.namespace_prefix.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move { pyo3_async_runtimes::tokio::future_into_py(py, async move {
if let Some(model_path) = args.model_path.clone() { if let Some(model_path) = args.model_path.clone() {
let local_path = if model_path.exists() { let local_path = if model_path.exists() {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
mod model;
pub use model::Model;
mod model_manager; mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError}; pub use model_manager::{ModelManager, ModelManagerError};
mod worker_set;
pub use worker_set::WorkerSet;
pub(crate) mod runtime_configs; pub(crate) mod runtime_configs;
pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch}; pub use runtime_configs::{RuntimeConfigWatch, runtime_config_watch};
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! A WorkerSet represents a group of workers deployed from the same configuration,
//! identified by their shared namespace. Each WorkerSet owns a complete pipeline
//! (engines, KV router, prefill router) built from its specific ModelDeploymentCard.
use std::sync::Arc;
use tokio::sync::watch;
use crate::{
discovery::KvWorkerMonitor,
kv_router::KvRouter,
model_card::ModelDeploymentCard,
types::{
generic::tensor::TensorStreamingEngine,
openai::{
chat_completions::OpenAIChatCompletionsStreamingEngine,
completions::OpenAICompletionsStreamingEngine,
embeddings::OpenAIEmbeddingsStreamingEngine, images::OpenAIImagesStreamingEngine,
videos::OpenAIVideosStreamingEngine,
},
},
};
/// Workers that share one namespace/configuration, bundled with the pipeline
/// that serves them.
pub struct WorkerSet {
    /// Full namespace string of this set, e.g. `"ns-abc12345"`.
    namespace: String,

    /// Checksum of the model deployment card (MDC) describing this set's configuration.
    mdcsum: String,

    /// Model deployment card from which this set's pipeline is built.
    card: ModelDeploymentCard,

    // Engine slots: every `WorkerSet` holds its own pipelines, one optional
    // slot per request kind.
    pub(crate) chat_engine: Option<OpenAIChatCompletionsStreamingEngine>,
    pub(crate) completions_engine: Option<OpenAICompletionsStreamingEngine>,
    pub(crate) embeddings_engine: Option<OpenAIEmbeddingsStreamingEngine>,
    pub(crate) images_engine: Option<OpenAIImagesStreamingEngine>,
    pub(crate) videos_engine: Option<OpenAIVideosStreamingEngine>,
    pub(crate) tensor_engine: Option<TensorStreamingEngine>,

    /// KV router dedicated to this set's workers; present only in KV mode.
    pub(crate) kv_router: Option<Arc<KvRouter>>,

    /// Monitor used to reject requests based on worker load.
    pub(crate) worker_monitor: Option<KvWorkerMonitor>,

    /// Receiver tracking the currently available instance IDs, fed by the
    /// Client's discovery watch. Stays `None` for in-process models (http/grpc),
    /// which have no discovery client.
    instance_count_rx: Option<watch::Receiver<Vec<u64>>>,
}
impl WorkerSet {
    /// Creates a `WorkerSet` for `namespace` from its MDC checksum and card.
    ///
    /// All engine slots, the KV router, the worker monitor and the instance
    /// watcher start out as `None`.
    pub fn new(namespace: String, mdcsum: String, card: ModelDeploymentCard) -> Self {
        Self {
            namespace,
            mdcsum,
            card,
            chat_engine: None,
            completions_engine: None,
            embeddings_engine: None,
            images_engine: None,
            videos_engine: None,
            tensor_engine: None,
            kv_router: None,
            worker_monitor: None,
            instance_count_rx: None,
        }
    }

    /// Full namespace this set belongs to.
    pub fn namespace(&self) -> &str {
        self.namespace.as_str()
    }

    /// MDC checksum identifying this set's configuration.
    pub fn mdcsum(&self) -> &str {
        self.mdcsum.as_str()
    }

    /// The model deployment card backing this set.
    pub fn card(&self) -> &ModelDeploymentCard {
        &self.card
    }

    /// `true` when a chat-completions engine is attached.
    pub fn has_chat_engine(&self) -> bool {
        self.chat_engine.is_some()
    }

    /// `true` when a completions engine is attached.
    pub fn has_completions_engine(&self) -> bool {
        self.completions_engine.is_some()
    }

    /// `true` when an embeddings engine is attached.
    pub fn has_embeddings_engine(&self) -> bool {
        self.embeddings_engine.is_some()
    }

    /// `true` when an images engine is attached.
    pub fn has_images_engine(&self) -> bool {
        self.images_engine.is_some()
    }

    /// `true` when a videos engine is attached.
    pub fn has_videos_engine(&self) -> bool {
        self.videos_engine.is_some()
    }

    /// `true` when a tensor engine is attached.
    pub fn has_tensor_engine(&self) -> bool {
        self.tensor_engine.is_some()
    }

    /// `true` when this set can decode, i.e. it has a chat or a completions engine.
    pub fn has_decode_engine(&self) -> bool {
        self.chat_engine.is_some() || self.completions_engine.is_some()
    }

    /// `true` when no engine of any kind is attached, meaning the set only
    /// tracks a prefill model's lifecycle.
    pub fn is_prefill_set(&self) -> bool {
        let has_any_engine = self.has_decode_engine()
            || self.has_embeddings_engine()
            || self.has_images_engine()
            || self.has_videos_engine()
            || self.has_tensor_engine();
        !has_any_engine
    }

    /// Builds `ParsingOptions` from the tool-call and reasoning parsers named
    /// in this set's card runtime configuration.
    pub fn parsing_options(&self) -> crate::protocols::openai::ParsingOptions {
        let runtime_config = &self.card.runtime_config;
        crate::protocols::openai::ParsingOptions::new(
            runtime_config.tool_call_parser.clone(),
            runtime_config.reasoning_parser.clone(),
        )
    }

    /// Number of active workers, read from the current value of the instance watcher.
    ///
    /// Without a watcher (in-process models) this is always `1`, since such a
    /// model has exactly one local worker.
    pub fn worker_count(&self) -> usize {
        self.instance_count_rx
            .as_ref()
            .map_or(1, |rx| rx.borrow().len())
    }

    /// Attaches the instance watcher obtained from the Client's discovery system.
    ///
    /// Takes `&mut self`, so it has to be called before the `WorkerSet` is
    /// wrapped in an `Arc`.
    pub fn set_instance_watcher(&mut self, rx: watch::Receiver<Vec<u64>>) {
        self.instance_count_rx = Some(rx);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model_card::ModelDeploymentCard;

    /// Builds a `WorkerSet` with a default card and no engines attached.
    fn make_worker_set(namespace: &str, mdcsum: &str) -> WorkerSet {
        WorkerSet::new(
            namespace.to_string(),
            mdcsum.to_string(),
            ModelDeploymentCard::default(),
        )
    }

    #[test]
    fn test_worker_set_basics() {
        let ws = make_worker_set("ns1", "abc123");
        assert_eq!(ws.namespace(), "ns1");
        assert_eq!(ws.mdcsum(), "abc123");
    }

    #[test]
    fn test_no_engines_by_default() {
        let ws = make_worker_set("ns1", "abc123");
        assert!(!ws.has_chat_engine());
        assert!(!ws.has_completions_engine());
        assert!(!ws.has_embeddings_engine());
        assert!(!ws.has_images_engine());
        // The videos engine feeds into `is_prefill_set` as well, so cover it
        // alongside the other engine accessors.
        assert!(!ws.has_videos_engine());
        assert!(!ws.has_tensor_engine());
        assert!(!ws.has_decode_engine());
        assert!(ws.is_prefill_set());
    }

    #[test]
    fn test_worker_count_without_watcher() {
        // In-process models have no discovery watcher; worker_count defaults to 1
        let ws = make_worker_set("ns1", "abc");
        assert_eq!(ws.worker_count(), 1);
    }

    #[test]
    fn test_worker_count_with_watcher() {
        let mut ws = make_worker_set("ns1", "abc");

        // Simulate a discovery watcher with 3 workers
        let (tx, rx) = watch::channel(vec![1, 2, 3]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 3);

        // Workers leave → count drops
        tx.send(vec![1]).unwrap();
        assert_eq!(ws.worker_count(), 1);

        // All workers gone → count is 0
        tx.send(vec![]).unwrap();
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_with_empty_watcher() {
        // Discovery watcher starts empty (no workers have joined yet)
        let mut ws = make_worker_set("ns1", "abc");
        let (_tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);
    }

    #[test]
    fn test_worker_count_updates_on_join() {
        let mut ws = make_worker_set("ns1", "abc");
        let (tx, rx) = watch::channel::<Vec<u64>>(vec![]);
        ws.set_instance_watcher(rx);
        assert_eq!(ws.worker_count(), 0);

        // Workers join one by one
        tx.send(vec![100]).unwrap();
        assert_eq!(ws.worker_count(), 1);

        tx.send(vec![100, 200]).unwrap();
        assert_eq!(ws.worker_count(), 2);

        tx.send(vec![100, 200, 300]).unwrap();
        assert_eq!(ws.worker_count(), 3);
    }
}
...@@ -12,6 +12,7 @@ use crate::{ ...@@ -12,6 +12,7 @@ use crate::{
kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter}, kv_router::{DirectRoutingRouter, KvPushRouter, KvRouter, PrefillRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::ModelDeploymentCard,
namespace::NamespaceFilter,
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
...@@ -82,8 +83,14 @@ pub async fn prepare_engine( ...@@ -82,8 +83,14 @@ pub async fn prepare_engine(
) )
.await?; .await?;
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
local_model.namespace(),
local_model.namespace_prefix(),
);
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(discovery_stream, None).await; inner_watch_obj
.watch(discovery_stream, namespace_filter)
.await;
}); });
tracing::info!("Waiting for remote model.."); tracing::info!("Waiting for remote model..");
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
entrypoint::{EngineConfig, RouterConfig, input::common}, entrypoint::{EngineConfig, RouterConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
http::service::metrics::Metrics, http::service::metrics::Metrics,
namespace::is_global_namespace, namespace::NamespaceFilter,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -38,18 +38,16 @@ pub async fn run( ...@@ -38,18 +38,16 @@ pub async fn run(
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to gRPC service // Listen for models registering themselves, add them to gRPC service
let namespace = model.namespace().unwrap_or(""); let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
let target_namespace = if is_global_namespace(namespace) { model.namespace(),
None model.namespace_prefix(),
} else { );
Some(namespace.to_string())
};
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
target_namespace, namespace_filter,
) )
.await?; .await?;
grpc_service grpc_service
...@@ -113,7 +111,7 @@ async fn run_watcher( ...@@ -113,7 +111,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
target_namespace: Option<String>, namespace_filter: NamespaceFilter,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Create metrics for migration tracking (not exposed via /metrics in gRPC mode) // Create metrics for migration tracking (not exposed via /metrics in gRPC mode)
let metrics = Arc::new(Metrics::new()); let metrics = Arc::new(Metrics::new());
...@@ -140,9 +138,7 @@ async fn run_watcher( ...@@ -140,9 +138,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj watch_obj.watch(discovery_stream, namespace_filter).await;
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -9,7 +9,7 @@ use crate::{ ...@@ -9,7 +9,7 @@ use crate::{
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common}, entrypoint::{ChatEngineFactoryCallback, EngineConfig, RouterConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
namespace::is_global_namespace, namespace::NamespaceFilter,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
...@@ -66,20 +66,17 @@ pub async fn run( ...@@ -66,20 +66,17 @@ pub async fn run(
let router_config = model.router_config(); let router_config = model.router_config();
let migration_limit = model.migration_limit(); let migration_limit = model.migration_limit();
// Listen for models registering themselves, add them to HTTP service // Listen for models registering themselves, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace) // Create namespace filter from model configuration
// Get namespace from the model, fallback to endpoint_id namespace if not set let namespace_filter = NamespaceFilter::from_namespace_and_prefix(
let namespace = model.namespace().unwrap_or(""); model.namespace(),
let target_namespace = if is_global_namespace(namespace) { model.namespace_prefix(),
None );
} else {
Some(namespace.to_string())
};
run_watcher( run_watcher(
distributed_runtime.clone(), distributed_runtime.clone(),
http_service.state().manager_clone(), http_service.state().manager_clone(),
router_config.clone(), router_config.clone(),
migration_limit, migration_limit,
target_namespace, namespace_filter,
Arc::new(http_service.clone()), Arc::new(http_service.clone()),
http_service.state().metrics_clone(), http_service.state().metrics_clone(),
chat_engine_factory.clone(), chat_engine_factory.clone(),
...@@ -157,7 +154,7 @@ async fn run_watcher( ...@@ -157,7 +154,7 @@ async fn run_watcher(
model_manager: Arc<ModelManager>, model_manager: Arc<ModelManager>,
router_config: RouterConfig, router_config: RouterConfig,
migration_limit: u32, migration_limit: u32,
target_namespace: Option<String>, namespace_filter: NamespaceFilter,
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
chat_engine_factory: Option<ChatEngineFactoryCallback>, chat_engine_factory: Option<ChatEngineFactoryCallback>,
...@@ -193,9 +190,7 @@ async fn run_watcher( ...@@ -193,9 +190,7 @@ async fn run_watcher(
// Pass the discovery stream to the watcher // Pass the discovery stream to the watcher
let _watcher_task = tokio::spawn(async move { let _watcher_task = tokio::spawn(async move {
watch_obj watch_obj.watch(discovery_stream, namespace_filter).await;
.watch(discovery_stream, target_namespace.as_deref())
.await;
}); });
Ok(()) Ok(())
......
...@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService { ...@@ -377,10 +377,8 @@ impl GrpcInferenceService for KserveService {
} }
} }
let model = completion_request.inner.model.clone(); let (stream, parsing_options) =
let parsing_options = self.state.manager.get_parsing_options(&model); completion_response_stream(self.state_clone(), completion_request).await?;
let stream = completion_response_stream(self.state_clone(), completion_request).await?;
let completion_response = let completion_response =
NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options) NvCreateCompletionResponse::from_annotated_stream(stream, parsing_options)
...@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService { ...@@ -494,12 +492,9 @@ impl GrpcInferenceService for KserveService {
} }
} }
let model = completion_request.inner.model.clone();
let parsing_options = state.manager.get_parsing_options(&model);
let streaming = completion_request.inner.stream.unwrap_or(false); let streaming = completion_request.inner.stream.unwrap_or(false);
let stream = completion_response_stream(state.clone(), completion_request).await?; let (stream, parsing_options) = completion_response_stream(state.clone(), completion_request).await?;
if streaming { if streaming {
pin_mut!(stream); pin_mut!(stream);
......
...@@ -9,6 +9,7 @@ use dynamo_runtime::{ ...@@ -9,6 +9,7 @@ use dynamo_runtime::{
use futures::{Stream, StreamExt, stream}; use futures::{Stream, StreamExt, stream};
use std::sync::Arc; use std::sync::Arc;
use crate::protocols::openai::ParsingOptions;
use crate::protocols::openai::completions::{ use crate::protocols::openai::completions::{
NvCreateCompletionRequest, NvCreateCompletionResponse, NvCreateCompletionRequest, NvCreateCompletionResponse,
}; };
...@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id"; ...@@ -43,7 +44,13 @@ pub const ANNOTATION_REQUEST_ID: &str = "request_id";
pub async fn completion_response_stream( pub async fn completion_response_stream(
state: Arc<kserve::State>, state: Arc<kserve::State>,
request: NvCreateCompletionRequest, request: NvCreateCompletionRequest,
) -> Result<impl Stream<Item = Annotated<NvCreateCompletionResponse>>, Status> { ) -> Result<
(
impl Stream<Item = Annotated<NvCreateCompletionResponse>>,
ParsingOptions,
),
Status,
> {
// create the context for the request // create the context for the request
// [WIP] from request id. // [WIP] from request id.
let request_id = get_or_create_request_id(request.inner.user.as_deref()); let request_id = get_or_create_request_id(request.inner.user.as_deref());
...@@ -66,9 +73,9 @@ pub async fn completion_response_stream( ...@@ -66,9 +73,9 @@ pub async fn completion_response_stream(
let model = &request.inner.model; let model = &request.inner.model;
// todo - error handling should be more robust // todo - error handling should be more robust
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_completions_engine(model) .get_completions_engine_with_parsing(model)
.map_err(|_| Status::not_found("model not found"))?; .map_err(|_| Status::not_found("model not found"))?;
let http_queue_guard = state.metrics_clone().create_http_queue_guard(model); let http_queue_guard = state.metrics_clone().create_http_queue_guard(model);
...@@ -130,7 +137,7 @@ pub async fn completion_response_stream( ...@@ -130,7 +137,7 @@ pub async fn completion_response_stream(
// without need to be cancelled. // without need to be cancelled.
connection_handle.disarm(); connection_handle.disarm();
Ok(stream) Ok((stream, parsing_options))
} }
/// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation. /// This method will consume an AsyncEngineStream and monitor for disconnects or context cancellation.
......
...@@ -194,9 +194,9 @@ async fn anthropic_messages( ...@@ -194,9 +194,9 @@ async fn anthropic_messages(
tracing::trace!("Getting chat completions engine for model: {}", model); tracing::trace!("Getting chat completions engine for model: {}", model);
let engine = state let (engine, parsing_options) = state
.manager() .manager()
.get_chat_completions_engine(&model) .get_chat_completions_engine_with_parsing(&model)
.map_err(|_| { .map_err(|_| {
anthropic_error( anthropic_error(
StatusCode::NOT_FOUND, StatusCode::NOT_FOUND,
...@@ -205,8 +205,6 @@ async fn anthropic_messages( ...@@ -205,8 +205,6 @@ async fn anthropic_messages(
) )
})?; })?;
let parsing_options = state.manager().get_parsing_options(&model);
let mut response_collector = state.metrics_clone().create_response_collector(&model); let mut response_collector = state.metrics_clone().create_response_collector(&model);
tracing::trace!("Issuing generate call for Anthropic messages"); tracing::trace!("Issuing generate call for Anthropic messages");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment