Unverified Commit c6becbc8 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: dynamo namespace isolation (#2394)


Signed-off-by: default avatarBiswa Panda <biswa.panda@gmail.com>
parent 6e073516
......@@ -4,6 +4,7 @@
# Usage: `python -m dynamo.mocker --model-path /data/models/Qwen3-0.6B-Q8_0.gguf --extra-engine-args args.json`
import argparse
import os
from pathlib import Path
import uvloop
......@@ -14,7 +15,8 @@ from dynamo.runtime.logging import configure_dynamo_logging
from . import __version__
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DYN_NAMESPACE = os.environ.get("DYN_NAMESPACE", "dynamo")
DEFAULT_ENDPOINT = f"dyn://{DYN_NAMESPACE}.backend.generate"
configure_dynamo_logging()
......
......@@ -115,7 +115,7 @@ def parse_args(args: list[str]) -> Config:
# Dynamo argument processing
# If an endpoint is provided, validate and use it
# otherwise fall back to default endpoints
namespace = os.environ.get("DYNAMO_NAMESPACE", "dynamo")
namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
endpoint = parsed_args.endpoint
if endpoint is None:
......
......@@ -26,4 +26,7 @@ if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--namespace", type=str, required=True)
args = parser.parse_args()
assert (
args.namespace
), "Missing namespace, either pass --namespace or set DYN_NAMESPACE"
asyncio.run(clear_namespace(args.namespace))
......@@ -159,7 +159,6 @@ For complete Kubernetes deployment instructions, configurations, and troubleshoo
vLLM workers are configured through command-line arguments. Key parameters include:
- `--endpoint`: Dynamo endpoint in format `dyn://namespace.component.endpoint`
- `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`)
- `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving
- `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo
......
......@@ -186,7 +186,6 @@ spec:
vLLM workers are configured through command-line arguments. Key parameters include:
- `--endpoint`: Dynamo endpoint in format `dyn://namespace.component.endpoint`
- `--model`: Model to serve (e.g., `Qwen/Qwen3-0.6B`)
- `--is-prefill-worker`: Enable prefill-only mode for disaggregated serving
- `--metrics-endpoint-port`: Port for publishing KV metrics to Dynamo
......
......@@ -4,7 +4,6 @@
import logging
import os
import sys
from typing import Optional
from vllm.config import KVTransferConfig
......@@ -29,7 +28,6 @@ from .ports import (
logger = logging.getLogger(__name__)
DEFAULT_ENDPOINT = "dyn://dynamo.backend.generate"
DEFAULT_MODEL = "Qwen/Qwen3-0.6B"
VALID_CONNECTORS = {"nixl", "lmcache", "kvbm", "null", "none"}
......@@ -72,12 +70,6 @@ def parse_args() -> Config:
parser.add_argument(
"--version", action="version", version=f"Dynamo Backend VLLM {__version__}"
)
parser.add_argument(
"--endpoint",
type=str,
default=DEFAULT_ENDPOINT,
help=f"Dynamo endpoint string in 'dyn://namespace.component.endpoint' format. Default: {DEFAULT_ENDPOINT}",
)
parser.add_argument(
"--is-prefill-worker",
action="store_true",
......@@ -145,27 +137,9 @@ def parse_args() -> Config:
# This becomes an `Option` on the Rust side
config.served_model_name = None
namespace = os.environ.get("DYNAMO_NAMESPACE", "dynamo")
if args.is_prefill_worker:
args.endpoint = f"dyn://{namespace}.prefill.generate"
else:
# For decode workers, also use the provided namespace instead of hardcoded "dynamo"
args.endpoint = f"dyn://{namespace}.backend.generate"
endpoint_str = args.endpoint.replace("dyn://", "", 1)
endpoint_parts = endpoint_str.split(".")
if len(endpoint_parts) != 3:
logger.error(
f"Invalid endpoint format: '{args.endpoint}'. Expected 'dyn://namespace.component.endpoint' or 'namespace.component.endpoint'."
)
sys.exit(1)
parsed_namespace, parsed_component_name, parsed_endpoint_name = endpoint_parts
config.namespace = parsed_namespace
config.component = parsed_component_name
config.endpoint = parsed_endpoint_name
config.namespace = os.environ.get("DYN_NAMESPACE", "dynamo")
config.component = "prefill" if args.is_prefill_worker else "backend"
config.endpoint = "generate"
config.engine_args = engine_args
config.is_prefill_worker = args.is_prefill_worker
config.migration_limit = args.migration_limit
......
......@@ -23,6 +23,7 @@
import argparse
import asyncio
import logging
import os
import pathlib
import re
......@@ -42,6 +43,10 @@ from dynamo.runtime import DistributedRuntime
from . import __version__
DYNAMO_NAMESPACE_ENV_VAR = "DYN_NAMESPACE"
logger = logging.getLogger(__name__)
def validate_static_endpoint(value):
"""Validate that static-endpoint is three words separated by dots."""
......@@ -137,6 +142,12 @@ def parse_args():
default=True,
help="KV Router: Disable KV events. When set, uses ApproxKvRouter for predicting block creation/deletion based only on incoming requests at a timer. By default, KV events are enabled.",
)
parser.add_argument(
"--namespace",
type=str,
default=os.environ.get(DYNAMO_NAMESPACE_ENV_VAR),
help="Dynamo namespace for model discovery scoping. If specified, models will only be discovered from this namespace. If not specified, discovers models from all namespaces (global discovery).",
)
parser.add_argument(
"--router-replica-sync",
action="store_true",
......@@ -240,6 +251,7 @@ async def async_main():
if flags.static_endpoint:
kwargs["endpoint_id"] = flags.static_endpoint
if flags.model_name:
kwargs["model_name"] = flags.model_name
if flags.model_path:
......@@ -248,6 +260,8 @@ async def async_main():
kwargs["tls_cert_path"] = flags.tls_cert_path
if flags.tls_key_path:
kwargs["tls_key_path"] = flags.tls_key_path
if flags.namespace:
kwargs["namespace"] = flags.namespace
if is_static:
# out=dyn://<static_endpoint>
......
......@@ -117,13 +117,14 @@ pub(crate) struct EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
}
#[pymethods]
impl EntrypointArgs {
#[allow(clippy::too_many_arguments)]
#[new]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None))]
#[pyo3(signature = (engine_type, model_path=None, model_name=None, model_config=None, endpoint_id=None, context_length=None, template_file=None, router_config=None, kv_cache_block_size=None, http_host=None, http_port=None, tls_cert_path=None, tls_key_path=None, extra_engine_args=None, namespace=None))]
pub fn new(
engine_type: EngineType,
model_path: Option<PathBuf>,
......@@ -139,6 +140,7 @@ impl EntrypointArgs {
tls_cert_path: Option<PathBuf>,
tls_key_path: Option<PathBuf>,
extra_engine_args: Option<PathBuf>,
namespace: Option<String>,
) -> PyResult<Self> {
let endpoint_id_obj: Option<EndpointId> = endpoint_id.as_deref().map(EndpointId::from);
if (tls_cert_path.is_some() && tls_key_path.is_none())
......@@ -163,6 +165,7 @@ impl EntrypointArgs {
tls_cert_path,
tls_key_path,
extra_engine_args,
namespace,
})
}
}
......@@ -195,7 +198,8 @@ pub fn make_engine<'p>(
.tls_cert_path(args.tls_cert_path.clone())
.tls_key_path(args.tls_key_path.clone())
.is_mocker(matches!(args.engine_type, EngineType::Mocker))
.extra_engine_args(args.extra_engine_args.clone());
.extra_engine_args(args.extra_engine_args.clone())
.namespace(args.namespace.clone());
pyo3_async_runtimes::tokio::future_into_py(py, async move {
let local_model = builder.build().await.map_err(to_pyerr)?;
let inner = select_engine(distributed_runtime, args, local_model)
......
......@@ -36,6 +36,7 @@ use crate::{
};
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
use crate::namespace::is_global_namespace;
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum ModelUpdate {
......@@ -90,8 +91,9 @@ impl ModelWatcher {
}
}
pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>) {
tracing::debug!("model watcher started");
/// Common watch logic with optional namespace filtering
pub async fn watch(&self, mut events_rx: Receiver<WatchEvent>, target_namespace: Option<&str>) {
let global_namespace = target_namespace.is_none_or(is_global_namespace);
while let Some(event) = events_rx.recv().await {
match event {
......@@ -110,6 +112,21 @@ impl ModelWatcher {
continue;
}
};
// Filter by namespace if target_namespace is specified
if !global_namespace
&& let Some(target_ns) = target_namespace
&& model_entry.endpoint_id.namespace != target_ns
{
tracing::debug!(
model_namespace = model_entry.endpoint_id.namespace,
target_namespace = target_ns,
model_name = model_entry.name,
"Skipping model from different namespace"
);
continue;
}
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
......@@ -126,21 +143,30 @@ impl ModelWatcher {
}
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(name = model_entry.name, "New endpoint for existing model");
tracing::trace!(
name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters();
continue;
}
match self.handle_put(&model_entry).await {
Ok(()) => {
tracing::info!(model_name = model_entry.name, "added model");
tracing::info!(
model_name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
"added model"
);
self.notify_on_model.notify_waiters();
}
Err(err) => {
tracing::error!(
error = format!("{err:#}"),
"error adding model {}",
model_entry.name
"error adding model {} from namespace {}",
model_entry.name,
model_entry.endpoint_id.namespace,
);
}
}
......
......@@ -78,7 +78,7 @@ pub async fn prepare_engine(
let inner_watch_obj = watch_obj.clone();
let _watcher_task = tokio::spawn(async move {
inner_watch_obj.watch(receiver).await;
inner_watch_obj.watch(receiver, None).await;
});
tracing::info!("Waiting for remote model..");
......
......@@ -9,6 +9,7 @@ use crate::{
entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -35,6 +36,12 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(ref etcd_client) => {
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to gRPC service
let namespace = engine_config.local_model().namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
run_watcher(
distributed_runtime,
grpc_service.state().manager_clone(),
......@@ -43,6 +50,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
target_namespace,
)
.await?;
}
......@@ -137,6 +145,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
/// Spawns a task that watches for new models in etcd at network_prefix,
/// and registers them with the ModelManager so that the HTTP service can use them.
#[allow(clippy::too_many_arguments)]
async fn run_watcher(
runtime: DistributedRuntime,
model_manager: Arc<ModelManager>,
......@@ -145,6 +154,7 @@ async fn run_watcher(
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
target_namespace: Option<String>,
) -> anyhow::Result<()> {
let watch_obj = ModelWatcher::new(
runtime,
......@@ -163,7 +173,7 @@ async fn run_watcher(
// Pass the sender to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await;
watch_obj.watch(receiver, target_namespace.as_deref()).await;
});
Ok(())
......
......@@ -11,6 +11,7 @@ use crate::{
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_type::ModelType,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
......@@ -62,6 +63,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
Some(ref etcd_client) => {
let router_config = engine_config.local_model().router_config();
// Listen for models registering themselves in etcd, add them to HTTP service
// Check if we should filter by namespace (based on the local model's namespace)
// Get namespace from the model, fallback to endpoint_id namespace if not set
let namespace = engine_config.local_model().namespace().unwrap_or("");
let target_namespace = if is_global_namespace(namespace) {
None
} else {
Some(namespace.to_string())
};
run_watcher(
distributed_runtime,
http_service.state().manager_clone(),
......@@ -70,6 +79,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
target_namespace,
Arc::new(http_service.clone()),
)
.await?;
......@@ -195,6 +205,7 @@ async fn run_watcher(
router_mode: RouterMode,
kv_router_config: Option<KvRouterConfig>,
busy_threshold: Option<f64>,
target_namespace: Option<String>,
http_service: Arc<HttpService>,
) -> anyhow::Result<()> {
let mut watch_obj = ModelWatcher::new(
......@@ -223,7 +234,7 @@ async fn run_watcher(
// Pass the sender to the watcher
let _watcher_task = tokio::spawn(async move {
watch_obj.watch(receiver).await;
watch_obj.watch(receiver, target_namespace.as_deref()).await;
});
Ok(())
......
......@@ -28,6 +28,7 @@ pub mod migration;
pub mod mocker;
pub mod model_card;
pub mod model_type;
pub mod namespace;
pub mod perf;
pub mod preprocessor;
pub mod protocols;
......
......@@ -59,6 +59,7 @@ pub struct LocalModelBuilder {
extra_engine_args: Option<PathBuf>,
runtime_config: ModelRuntimeConfig,
user_data: Option<serde_json::Value>,
namespace: Option<String>,
}
impl Default for LocalModelBuilder {
......@@ -81,6 +82,7 @@ impl Default for LocalModelBuilder {
extra_engine_args: Default::default(),
runtime_config: Default::default(),
user_data: Default::default(),
namespace: Default::default(),
}
}
}
......@@ -142,6 +144,11 @@ impl LocalModelBuilder {
self
}
pub fn namespace(&mut self, namespace: Option<String>) -> &mut Self {
self.namespace = namespace;
self
}
pub fn request_template(&mut self, template_file: Option<PathBuf>) -> &mut Self {
self.template_file = template_file;
self
......@@ -189,6 +196,7 @@ impl LocalModelBuilder {
.endpoint_id
.take()
.unwrap_or_else(|| internal_endpoint("local_model"));
let template = self
.template_file
.as_deref()
......@@ -215,6 +223,7 @@ impl LocalModelBuilder {
tls_key_path: self.tls_key_path.take(),
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
});
}
......@@ -290,6 +299,7 @@ impl LocalModelBuilder {
tls_key_path: self.tls_key_path.take(),
router_config: self.router_config.take().unwrap_or_default(),
runtime_config: self.runtime_config.clone(),
namespace: self.namespace.clone(),
})
}
}
......@@ -306,6 +316,7 @@ pub struct LocalModel {
tls_key_path: Option<PathBuf>,
router_config: RouterConfig,
runtime_config: ModelRuntimeConfig,
namespace: Option<String>,
}
impl LocalModel {
......@@ -356,6 +367,10 @@ impl LocalModel {
&self.runtime_config
}
pub fn namespace(&self) -> Option<&str> {
self.namespace.as_deref()
}
pub fn is_gguf(&self) -> bool {
// GGUF is the only file (not-folder) we accept, so we don't need to check the extension
// We will error when we come to parse it
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
/// The global namespace for all models
pub const GLOBAL_NAMESPACE: &str = "dynamo";
pub fn is_global_namespace(namespace: &str) -> bool {
namespace == GLOBAL_NAMESPACE || namespace.is_empty()
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//! Integration tests for HTTP service namespace discovery functionality.
//! These tests verify that the HTTP service correctly filters models based on namespace configuration.
use dynamo_llm::{
discovery::ModelEntry,
model_type::ModelType,
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId;
// Helper function to create a test ModelEntry
fn create_test_model_entry(
name: &str,
namespace: &str,
component: &str,
endpoint_name: &str,
model_type: ModelType,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
endpoint_id: EndpointId {
namespace: namespace.to_string(),
component: component.to_string(),
name: endpoint_name.to_string(),
},
model_type,
runtime_config: None,
}
}
#[test]
fn test_namespace_filtering_behavior() {
// Test the core namespace filtering logic used in HTTP service
let test_models = vec![
create_test_model_entry(
"model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry(
"model-2",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Completion,
),
create_test_model_entry(
"model-4",
"tensorrt-llm",
"backend",
"generate",
ModelType::Embedding,
),
];
// Test filtering for specific namespace "vllm-agg"
let target_namespace = "vllm-agg";
let is_global = is_global_namespace(target_namespace);
let filtered_models: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models.len(), 1);
assert_eq!(filtered_models[0].name, "model-1");
assert_eq!(filtered_models[0].endpoint_id.namespace, "vllm-agg");
// Test filtering for global namespace (should include all models)
let target_namespace = GLOBAL_NAMESPACE;
let is_global = is_global_namespace(target_namespace);
let filtered_models_global: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models_global.len(), 4); // All models should be included
// Test filtering for empty namespace (treated as global)
let target_namespace = "";
let is_global = is_global_namespace(target_namespace);
let filtered_models_empty: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models_empty.len(), 4); // All models should be included
}
#[test]
fn test_endpoint_id_namespace_extraction() {
// Test endpoint ID parsing for different namespace formats
let test_cases = vec![
("vllm-agg.frontend.http", "vllm-agg", "frontend", "http"),
(
"sglang-prod.backend.generate",
"sglang-prod",
"backend",
"generate",
),
("dynamo.frontend.http", "dynamo", "frontend", "http"),
(
"tensorrt-llm.backend.inference",
"tensorrt-llm",
"backend",
"inference",
),
(
"test-namespace.component.endpoint",
"test-namespace",
"component",
"endpoint",
),
];
for (endpoint_str, expected_namespace, expected_component, expected_name) in test_cases {
let endpoint: EndpointId = endpoint_str.parse().expect("Failed to parse endpoint");
assert_eq!(endpoint.namespace, expected_namespace);
assert_eq!(endpoint.component, expected_component);
assert_eq!(endpoint.name, expected_name);
// Test namespace classification
let is_global = is_global_namespace(&endpoint.namespace);
if expected_namespace == GLOBAL_NAMESPACE {
assert!(
is_global,
"Namespace '{}' should be classified as global",
expected_namespace
);
} else {
assert!(
!is_global,
"Namespace '{}' should not be classified as global",
expected_namespace
);
}
}
}
#[test]
fn test_model_discovery_scoping_scenarios() {
// Test various scenarios for model discovery scoping
// Scenario 1: Frontend configured for specific namespace should only see models from that namespace
let frontend_namespace = "vllm-agg";
let available_models = vec![
create_test_model_entry(
"llama-7b",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry(
"mistral-7b",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry(
"gpt-3.5",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry("claude-3", "dynamo", "backend", "generate", ModelType::Chat),
];
let visible_models: Vec<&ModelEntry> = available_models
.iter()
.filter(|model| {
let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace
})
.collect();
assert_eq!(visible_models.len(), 2);
assert!(
visible_models
.iter()
.all(|m| m.endpoint_id.namespace == "vllm-agg")
);
// Scenario 2: Frontend configured for global namespace should see all models
let frontend_namespace = GLOBAL_NAMESPACE;
let visible_models_global: Vec<&ModelEntry> = available_models
.iter()
.filter(|model| {
let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace
})
.collect();
assert_eq!(visible_models_global.len(), 4); // Should see all models
// Scenario 3: Frontend configured for non-existent namespace should see no models
let frontend_namespace = "non-existent-namespace";
let visible_models_none: Vec<&ModelEntry> = available_models
.iter()
.filter(|model| {
let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace
})
.collect();
assert_eq!(visible_models_none.len(), 0); // Should see no models
}
#[test]
fn test_namespace_boundary_conditions() {
// Test edge cases and boundary conditions for namespace handling
let test_models = vec![
create_test_model_entry("model-1", "", "backend", "generate", ModelType::Chat), // Empty namespace
create_test_model_entry("model-2", "dynamo", "backend", "generate", ModelType::Chat), // Global namespace
create_test_model_entry(
"model-3",
"ns-with-special-chars_123",
"backend",
"generate",
ModelType::Chat,
),
];
// Test filtering with empty target namespace (should be treated as global)
let target_namespace = "";
let is_global = is_global_namespace(target_namespace);
assert!(is_global); // Empty namespace should be treated as global
let filtered_empty: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_empty.len(), 3); // All models should be visible
// Test filtering with exact "dynamo" namespace
let target_namespace = "dynamo";
let is_global = is_global_namespace(target_namespace);
assert!(is_global);
let filtered_global: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_global.len(), 3); // All models should be visible
// Test case sensitivity - "GLOBAL" should not be treated as global
let target_namespace = "DYNAMO";
let is_global = is_global_namespace(target_namespace);
assert!(!is_global); // Should be case-sensitive
let filtered_uppercase: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_uppercase.len(), 0); // No models should be visible
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_llm::{
discovery::ModelEntry,
model_type::ModelType,
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId;
#[test]
fn test_is_global_namespace_with_global_string() {
assert!(is_global_namespace(GLOBAL_NAMESPACE));
assert!(is_global_namespace("dynamo"));
}
#[test]
fn test_is_global_namespace_with_empty_string() {
assert!(is_global_namespace(""));
}
#[test]
fn test_is_global_namespace_with_specific_namespace() {
assert!(!is_global_namespace("test-namespace"));
assert!(!is_global_namespace("my-custom-namespace"));
}
#[test]
fn test_is_global_namespace_with_whitespace() {
// Whitespace should not be considered global
assert!(!is_global_namespace(" "));
assert!(!is_global_namespace(" "));
assert!(!is_global_namespace("\t"));
assert!(!is_global_namespace("\n"));
}
#[test]
fn test_is_global_namespace_case_sensitivity() {
// Should be case sensitive
assert!(!is_global_namespace("Dynamo"));
assert!(!is_global_namespace("DYNAMO"));
}
#[test]
fn test_global_namespace_constant() {
assert_eq!(GLOBAL_NAMESPACE, "dynamo");
}
// Helper function to create a test ModelEntry
fn create_test_model_entry(
name: &str,
namespace: &str,
component: &str,
endpoint_name: &str,
model_type: ModelType,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
endpoint_id: EndpointId {
namespace: namespace.to_string(),
component: component.to_string(),
name: endpoint_name.to_string(),
},
model_type,
runtime_config: None,
}
}
#[test]
fn test_model_entry_creation_with_different_namespaces() {
// Test creating ModelEntry with specific namespace
let model_vllm = create_test_model_entry(
"test-model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
);
assert_eq!(model_vllm.name, "test-model-1");
assert_eq!(model_vllm.endpoint_id.namespace, "vllm-agg");
assert_eq!(model_vllm.endpoint_id.component, "backend");
assert_eq!(model_vllm.endpoint_id.name, "generate");
assert_eq!(model_vllm.model_type, ModelType::Chat);
// Test creating ModelEntry with global namespace
let model_global = create_test_model_entry(
"test-model-2",
"dynamo",
"frontend",
"http",
ModelType::Completion,
);
assert_eq!(model_global.name, "test-model-2");
assert_eq!(model_global.endpoint_id.namespace, "dynamo");
assert_eq!(model_global.endpoint_id.component, "frontend");
assert_eq!(model_global.endpoint_id.name, "http");
assert_eq!(model_global.model_type, ModelType::Completion);
}
#[test]
fn test_namespace_filtering_logic() {
// Test the core logic that would be used in namespace filtering
let models = vec![
create_test_model_entry(
"model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry(
"model-2",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
),
create_test_model_entry("model-3", "dynamo", "backend", "generate", ModelType::Chat),
create_test_model_entry("model-4", "", "backend", "generate", ModelType::Chat),
];
// Test filtering for specific namespace "vllm-agg"
let target_namespace = "vllm-agg";
let global_namespace = is_global_namespace(target_namespace);
let filtered_vllm: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_vllm.len(), 1);
assert_eq!(filtered_vllm[0].name, "model-1");
assert_eq!(filtered_vllm[0].endpoint_id.namespace, "vllm-agg");
// Test filtering for global namespace (should include all)
let target_namespace = "dynamo";
let global_namespace = is_global_namespace(target_namespace);
let filtered_global: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_global.len(), 4); // All models should be included
// Test filtering for empty namespace (should include all, treated as global)
let target_namespace = "";
let global_namespace = is_global_namespace(target_namespace);
let filtered_empty: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_empty.len(), 4); // All models should be included
// Test filtering for non-existent namespace
let target_namespace = "non-existent";
let global_namespace = is_global_namespace(target_namespace);
let filtered_none: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_none.len(), 0); // No models should match
}
#[test]
fn test_model_entry_serialization() {
// Test that ModelEntry can be serialized and deserialized (important for etcd storage)
let model = create_test_model_entry(
"test-model",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
);
// Serialize to JSON
let json = serde_json::to_string(&model).expect("Failed to serialize ModelEntry");
assert!(json.contains("test-model"));
assert!(json.contains("vllm-agg"));
assert!(json.contains("backend"));
assert!(json.contains("generate"));
// Deserialize from JSON
let deserialized: ModelEntry =
serde_json::from_str(&json).expect("Failed to deserialize ModelEntry");
assert_eq!(deserialized.name, model.name);
assert_eq!(
deserialized.endpoint_id.namespace,
model.endpoint_id.namespace
);
assert_eq!(
deserialized.endpoint_id.component,
model.endpoint_id.component
);
assert_eq!(deserialized.endpoint_id.name, model.endpoint_id.name);
assert_eq!(deserialized.model_type, model.model_type);
}
#[test]
fn test_endpoint_namespace_parsing() {
// Test Endpoint creation from string with namespace
let endpoint1 = EndpointId::from("vllm-agg.backend.generate");
assert_eq!(endpoint1.namespace, "vllm-agg");
assert_eq!(endpoint1.component, "backend");
assert_eq!(endpoint1.name, "generate");
let endpoint2 = EndpointId::from("global.frontend.http");
assert_eq!(endpoint2.namespace, "global");
assert_eq!(endpoint2.component, "frontend");
assert_eq!(endpoint2.name, "http");
// Test with forward slash separator
let endpoint3 = EndpointId::from("sglang-prod/backend/generate");
assert_eq!(endpoint3.namespace, "sglang-prod");
assert_eq!(endpoint3.component, "backend");
assert_eq!(endpoint3.name, "generate");
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment