Unverified Commit 6ffd20a8 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore: Move model_input, model_type from ModelEntry to ModelDeploymentCard (#3292)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 50dfd3af
...@@ -4,10 +4,7 @@ ...@@ -4,10 +4,7 @@
use dynamo_runtime::{protocols, slug::Slug}; use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::local_model::runtime_config::ModelRuntimeConfig;
local_model::runtime_config::ModelRuntimeConfig,
model_type::{ModelInput, ModelType},
};
/// [ModelEntry] contains the information to discover models /// [ModelEntry] contains the information to discover models
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
...@@ -20,17 +17,9 @@ pub struct ModelEntry { ...@@ -20,17 +17,9 @@ pub struct ModelEntry {
#[serde(rename = "endpoint")] #[serde(rename = "endpoint")]
pub endpoint_id: protocols::EndpointId, pub endpoint_id: protocols::EndpointId,
/// Specifies whether the model is a chat, completions, etc model.
pub model_type: ModelType,
/// Runtime configuration specific to this model instance /// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>, pub runtime_config: Option<ModelRuntimeConfig>,
/// Specifies the model input type.
/// `Tokens` for engines that expect pre-processed input.
/// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput,
} }
impl ModelEntry { impl ModelEntry {
...@@ -38,8 +27,4 @@ impl ModelEntry { ...@@ -38,8 +27,4 @@ impl ModelEntry {
pub fn slug(&self) -> Slug { pub fn slug(&self) -> Slug {
Slug::from_string(&self.name) Slug::from_string(&self.name)
} }
pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_input, ModelInput::Tokens)
}
} }
...@@ -11,8 +11,8 @@ use parking_lot::{Mutex, RwLock}; ...@@ -11,8 +11,8 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::discovery::{KV_ROUTERS_ROOT_PATH, ModelEntry};
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector}; use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{ use crate::{
kv_router::KvRouter, kv_router::KvRouter,
types::generic::tensor::TensorStreamingEngine, types::generic::tensor::TensorStreamingEngine,
...@@ -40,7 +40,7 @@ pub struct ModelManager { ...@@ -40,7 +40,7 @@ pub struct ModelManager {
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally // These two are Mutex because we read and write rarely and equally
entries: Mutex<HashMap<String, ModelEntry>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
} }
...@@ -57,13 +57,13 @@ impl ModelManager { ...@@ -57,13 +57,13 @@ impl ModelManager {
chat_completion_engines: RwLock::new(ModelEngines::default()), chat_completion_engines: RwLock::new(ModelEngines::default()),
embeddings_engines: RwLock::new(ModelEngines::default()), embeddings_engines: RwLock::new(ModelEngines::default()),
tensor_engines: RwLock::new(ModelEngines::default()), tensor_engines: RwLock::new(ModelEngines::default()),
entries: Mutex::new(HashMap::new()), cards: Mutex::new(HashMap::new()),
kv_choosers: Mutex::new(HashMap::new()), kv_choosers: Mutex::new(HashMap::new()),
} }
} }
pub fn get_model_entries(&self) -> Vec<ModelEntry> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.entries.lock().values().cloned().collect() self.cards.lock().values().cloned().collect()
} }
pub fn has_model_any(&self, model: &str) -> bool { pub fn has_model_any(&self, model: &str) -> bool {
...@@ -196,15 +196,15 @@ impl ModelManager { ...@@ -196,15 +196,15 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Save a ModelEntry under an instance's etcd `models/` key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd. /// deleted from etcd.
pub fn save_model_entry(&self, key: &str, entry: ModelEntry) { pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) {
self.entries.lock().insert(key.to_string(), entry); self.cards.lock().insert(key.to_string(), entry);
} }
/// Remove and return model entry for this instance's etcd key. We do this when the instance stops. /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
pub fn remove_model_entry(&self, key: &str) -> Option<ModelEntry> { pub fn remove_model_card(&self, key: &str) -> Option<ModelDeploymentCard> {
self.entries.lock().remove(key) self.cards.lock().remove(key)
} }
pub async fn kv_chooser_for( pub async fn kv_chooser_for(
...@@ -279,12 +279,11 @@ impl ModelManager { ...@@ -279,12 +279,11 @@ impl ModelManager {
} }
pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> { pub fn get_model_tool_call_parser(&self, model: &str) -> Option<String> {
self.entries self.cards
.lock() .lock()
.values() .values()
.find(|entry| entry.name == model) .find(|c| c.display_name == model)
.and_then(|entry| entry.runtime_config.as_ref()) .and_then(|c| c.runtime_config.tool_call_parser.as_ref())
.and_then(|config| config.tool_call_parser.clone())
.map(|parser| parser.to_string()) .map(|parser| parser.to_string())
} }
......
...@@ -40,10 +40,10 @@ use crate::{ ...@@ -40,10 +40,10 @@ use crate::{
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager}; use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
use crate::namespace::is_global_namespace; use crate::namespace::is_global_namespace;
#[derive(Debug, Clone, Copy, PartialEq)] #[derive(Debug, Clone)]
pub enum ModelUpdate { pub enum ModelUpdate {
Added(ModelType), Added(ModelDeploymentCard),
Removed(ModelType), Removed(ModelDeploymentCard),
} }
pub struct ModelWatcher { pub struct ModelWatcher {
...@@ -140,25 +140,8 @@ impl ModelWatcher { ...@@ -140,25 +140,8 @@ impl ModelWatcher {
continue; continue;
} }
}; };
self.manager.save_model_entry(key, model_entry.clone());
if let Some(tx) = &self.model_update_tx { match self.handle_put(key, &model_entry).await {
tx.send(ModelUpdate::Added(model_entry.model_type))
.await
.ok();
}
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(
name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters();
continue;
}
match self.handle_put(&model_entry).await {
Ok(()) => { Ok(()) => {
tracing::info!( tracing::info!(
model_name = model_entry.name, model_name = model_entry.name,
...@@ -196,13 +179,13 @@ impl ModelWatcher { ...@@ -196,13 +179,13 @@ impl ModelWatcher {
/// Returns the name of the model we just deleted, if any. /// Returns the name of the model we just deleted, if any.
async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> { async fn handle_delete(&self, kv: &KeyValue) -> anyhow::Result<Option<String>> {
let key = kv.key_str()?; let key = kv.key_str()?;
let model_entry = match self.manager.remove_model_entry(key) { let card = match self.manager.remove_model_card(key) {
Some(entry) => entry, Some(card) => card,
None => { None => {
anyhow::bail!("Missing ModelEntry for {key}"); anyhow::bail!("Missing ModelDeploymentCard for {key}");
} }
}; };
let model_name = model_entry.name; let model_name = card.display_name.clone();
let active_instances = self let active_instances = self
.entries_for_model(&model_name) .entries_for_model(&model_name)
.await .await
...@@ -257,7 +240,7 @@ impl ModelWatcher { ...@@ -257,7 +240,7 @@ impl ModelWatcher {
|| (tensor_model_removed && *model_type == ModelType::TensorBased)) || (tensor_model_removed && *model_type == ModelType::TensorBased))
&& let Some(tx) = &self.model_update_tx && let Some(tx) = &self.model_update_tx
{ {
tx.send(ModelUpdate::Removed(*model_type)).await.ok(); tx.send(ModelUpdate::Removed(card.clone())).await.ok();
} }
} }
} }
...@@ -267,7 +250,7 @@ impl ModelWatcher { ...@@ -267,7 +250,7 @@ impl ModelWatcher {
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models. // models.
async fn handle_put(&self, model_entry: &ModelEntry) -> anyhow::Result<()> { async fn handle_put(&self, key: &str, model_entry: &ModelEntry) -> anyhow::Result<()> {
let endpoint_id = &model_entry.endpoint_id; let endpoint_id = &model_entry.endpoint_id;
let component = self let component = self
.drt .drt
...@@ -294,9 +277,24 @@ impl ModelWatcher { ...@@ -294,9 +277,24 @@ impl ModelWatcher {
} }
}; };
if model_entry.model_input == ModelInput::Tokens self.manager.save_model_card(key, card.clone());
&& (model_entry.model_type.supports_chat()
|| model_entry.model_type.supports_completions()) if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(
name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters();
return Ok(());
}
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok();
}
if card.model_input == ModelInput::Tokens
&& (card.model_type.supports_chat() || card.model_type.supports_completions())
{ {
// Case 1: Tokens + (Chat OR Completions OR Both) // Case 1: Tokens + (Chat OR Completions OR Both)
// A model that expects pre-processed requests meaning it's up to us whether we // A model that expects pre-processed requests meaning it's up to us whether we
...@@ -321,7 +319,7 @@ impl ModelWatcher { ...@@ -321,7 +319,7 @@ impl ModelWatcher {
let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?; let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() { if card.model_type.supports_chat() {
let chat_engine = entrypoint::build_routed_pipeline::< let chat_engine = entrypoint::build_routed_pipeline::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
...@@ -342,7 +340,7 @@ impl ModelWatcher { ...@@ -342,7 +340,7 @@ impl ModelWatcher {
} }
// Add completions engine only if the model supports completions // Add completions engine only if the model supports completions
if model_entry.model_type.supports_completions() { if card.model_type.supports_completions() {
let formatter = PromptFormatter::no_op(); let formatter = PromptFormatter::no_op();
let PromptFormatter::OAI(formatter) = formatter; let PromptFormatter::OAI(formatter) = formatter;
let preprocessor = OpenAIPreprocessor::new_with_parts( let preprocessor = OpenAIPreprocessor::new_with_parts(
...@@ -370,9 +368,7 @@ impl ModelWatcher { ...@@ -370,9 +368,7 @@ impl ModelWatcher {
.context("add_completions_model")?; .context("add_completions_model")?;
tracing::info!("Completions is ready"); tracing::info!("Completions is ready");
} }
} else if model_entry.model_input == ModelInput::Text } else if card.model_input == ModelInput::Text && card.model_type.supports_chat() {
&& model_entry.model_type.supports_chat()
{
// Case 3: Text + Chat // Case 3: Text + Chat
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
...@@ -384,9 +380,7 @@ impl ModelWatcher { ...@@ -384,9 +380,7 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_chat_completions_model(&model_entry.name, engine)?; .add_chat_completions_model(&model_entry.name, engine)?;
} else if model_entry.model_input == ModelInput::Text } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
&& model_entry.model_type.supports_completions()
{
// Case 2: Text + Completions // Case 2: Text + Completions
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
...@@ -398,9 +392,7 @@ impl ModelWatcher { ...@@ -398,9 +392,7 @@ impl ModelWatcher {
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_completions_model(&model_entry.name, engine)?; .add_completions_model(&model_entry.name, engine)?;
} else if model_entry.model_input == ModelInput::Tokens } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
&& model_entry.model_type.supports_embedding()
{
// Case 4: Tokens + Embeddings // Case 4: Tokens + Embeddings
// Create preprocessing pipeline similar to Backend // Create preprocessing pipeline similar to Backend
...@@ -434,9 +426,7 @@ impl ModelWatcher { ...@@ -434,9 +426,7 @@ impl ModelWatcher {
self.manager self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?; .add_embeddings_model(&model_entry.name, embedding_engine)?;
} else if model_entry.model_input == ModelInput::Tensor } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
&& model_entry.model_type.supports_tensor()
{
// Case 5: Tensor + Tensor (non-LLM) // Case 5: Tensor + Tensor (non-LLM)
let push_router = PushRouter::< let push_router = PushRouter::<
NvCreateTensorRequest, NvCreateTensorRequest,
...@@ -452,8 +442,8 @@ impl ModelWatcher { ...@@ -452,8 +442,8 @@ impl ModelWatcher {
anyhow::bail!( anyhow::bail!(
"Unsupported model configuration: {} with {} input. Supported combinations: \ "Unsupported model configuration: {} with {} input. Supported combinations: \
Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased", Tokens+(Chat|Completions), Text+Chat, Text+Completions, Tokens+Embeddings, Tensor+TensorBased",
model_entry.model_type, card.model_type,
model_entry.model_input.as_str() card.model_input.as_str()
); );
} }
......
...@@ -220,9 +220,6 @@ async fn run_watcher( ...@@ -220,9 +220,6 @@ async fn run_watcher(
http_service: Arc<HttpService>, http_service: Arc<HttpService>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Clone model_manager before it's moved into ModelWatcher
let model_manager_clone = model_manager.clone();
let mut watch_obj = ModelWatcher::new( let mut watch_obj = ModelWatcher::new(
runtime, runtime,
model_manager, model_manager,
...@@ -241,20 +238,12 @@ async fn run_watcher( ...@@ -241,20 +238,12 @@ async fn run_watcher(
// Spawn a task to watch for model type changes and update HTTP service endpoints and metrics // Spawn a task to watch for model type changes and update HTTP service endpoints and metrics
let _endpoint_enabler_task = tokio::spawn(async move { let _endpoint_enabler_task = tokio::spawn(async move {
while let Some(model_type) = rx.recv().await { while let Some(model_update) = rx.recv().await {
tracing::debug!("Received model type update: {:?}", model_type);
// Update HTTP endpoints (existing functionality) // Update HTTP endpoints (existing functionality)
update_http_endpoints(http_service.clone(), model_type); update_http_endpoints(http_service.clone(), model_update.clone());
// Update metrics (only for added models) // Update metrics (only for added models)
update_model_metrics( update_model_metrics(model_update, metrics.clone());
model_type,
model_manager_clone.clone(),
metrics.clone(),
Some(etcd_client.clone()),
)
.await;
} }
}); });
...@@ -273,15 +262,15 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) { ...@@ -273,15 +262,15 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
model_type model_type
); );
match model_type { match model_type {
ModelUpdate::Added(model_type) => { ModelUpdate::Added(card) => {
// Handle all supported endpoint types, not just the first one // Handle all supported endpoint types, not just the first one
for endpoint_type in model_type.as_endpoint_types() { for endpoint_type in card.model_type.as_endpoint_types() {
service.enable_model_endpoint(endpoint_type, true); service.enable_model_endpoint(endpoint_type, true);
} }
} }
ModelUpdate::Removed(model_type) => { ModelUpdate::Removed(card) => {
// Handle all supported endpoint types, not just the first one // Handle all supported endpoint types, not just the first one
for endpoint_type in model_type.as_endpoint_types() { for endpoint_type in card.model_type.as_endpoint_types() {
service.enable_model_endpoint(endpoint_type, false); service.enable_model_endpoint(endpoint_type, false);
} }
} }
...@@ -289,42 +278,19 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) { ...@@ -289,42 +278,19 @@ fn update_http_endpoints(service: Arc<HttpService>, model_type: ModelUpdate) {
} }
/// Updates metrics for model type changes /// Updates metrics for model type changes
async fn update_model_metrics( fn update_model_metrics(
model_type: ModelUpdate, model_type: ModelUpdate,
model_manager: Arc<ModelManager>,
metrics: Arc<crate::http::service::metrics::Metrics>, metrics: Arc<crate::http::service::metrics::Metrics>,
etcd_client: Option<etcd::Client>,
) { ) {
match model_type { match model_type {
ModelUpdate::Added(model_type) => { ModelUpdate::Added(card) => {
tracing::debug!("Updating metrics for added model type: {:?}", model_type); tracing::debug!("Updating metrics for added model: {}", card.display_name);
if let Err(err) = metrics.update_metrics_from_mdc(&card) {
// Get all model entries and update metrics for matching types tracing::warn!(%err, model_name=card.display_name, "update_metrics_from_mdc failed");
let model_entries = model_manager.get_model_entries();
for entry in model_entries {
if entry.model_type == model_type {
// Update runtime config metrics if available
if let Some(runtime_config) = &entry.runtime_config {
metrics.update_runtime_config_metrics(&entry.name, runtime_config);
}
// Update MDC metrics if etcd is available
if let Some(ref etcd) = etcd_client
&& let Err(e) = metrics
.update_metrics_from_model_entry_with_mdc(&entry, etcd)
.await
{
tracing::debug!(
model = %entry.name,
error = %e,
"Failed to update MDC metrics for newly added model"
);
}
}
} }
} }
ModelUpdate::Removed(model_type) => { ModelUpdate::Removed(card) => {
tracing::debug!("Model type removed: {:?}", model_type); tracing::debug!(model_name = card.display_name, "Model removed");
// Note: Metrics are typically not removed to preserve historical data // Note: Metrics are typically not removed to preserve historical data
// This matches the behavior in the polling task // This matches the behavior in the polling task
} }
......
...@@ -397,15 +397,14 @@ impl GrpcInferenceService for KserveService { ...@@ -397,15 +397,14 @@ impl GrpcInferenceService for KserveService {
&self, &self,
request: Request<ModelMetadataRequest>, request: Request<ModelMetadataRequest>,
) -> Result<Response<ModelMetadataResponse>, Status> { ) -> Result<Response<ModelMetadataResponse>, Status> {
let entries = self.state.manager().get_model_entries(); let cards = self.state.manager().get_model_cards();
let request_model_name = &request.into_inner().name; let request_model_name = &request.into_inner().name;
if let Some(entry) = entries if let Some(card) = cards
.into_iter() .into_iter()
.find(|entry| request_model_name == &entry.name) .find(|card| request_model_name == &card.display_name)
{ {
if entry.model_type.supports_tensor() { if card.model_type.supports_tensor() {
if let Some(config) = entry.runtime_config.as_ref() if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
&& let Some(tensor_model_config) = config.tensor_model_config.as_ref()
{ {
return Ok(Response::new(ModelMetadataResponse { return Ok(Response::new(ModelMetadataResponse {
name: tensor_model_config.name.clone(), name: tensor_model_config.name.clone(),
...@@ -437,9 +436,9 @@ impl GrpcInferenceService for KserveService { ...@@ -437,9 +436,9 @@ impl GrpcInferenceService for KserveService {
"Model '{}' has type Tensor but no model config is provided", "Model '{}' has type Tensor but no model config is provided",
request_model_name request_model_name
)))? )))?
} else if entry.model_type.supports_completions() { } else if card.model_type.supports_completions() {
return Ok(Response::new(ModelMetadataResponse { return Ok(Response::new(ModelMetadataResponse {
name: entry.name, name: card.display_name,
versions: vec!["1".to_string()], versions: vec!["1".to_string()],
platform: "dynamo".to_string(), platform: "dynamo".to_string(),
inputs: vec![ inputs: vec![
...@@ -479,15 +478,14 @@ impl GrpcInferenceService for KserveService { ...@@ -479,15 +478,14 @@ impl GrpcInferenceService for KserveService {
&self, &self,
request: Request<ModelConfigRequest>, request: Request<ModelConfigRequest>,
) -> Result<Response<ModelConfigResponse>, Status> { ) -> Result<Response<ModelConfigResponse>, Status> {
let entries = self.state.manager().get_model_entries(); let cards = self.state.manager().get_model_cards();
let request_model_name = &request.into_inner().name; let request_model_name = &request.into_inner().name;
if let Some(entry) = entries if let Some(card) = cards
.into_iter() .into_iter()
.find(|entry| request_model_name == &entry.name) .find(|card| request_model_name == &card.display_name)
{ {
if entry.model_type.supports_tensor() { if card.model_type.supports_tensor() {
if let Some(config) = entry.runtime_config.as_ref() if let Some(tensor_model_config) = card.runtime_config.tensor_model_config.as_ref()
&& let Some(tensor_model_config) = config.tensor_model_config.as_ref()
{ {
let model_config = ModelConfig { let model_config = ModelConfig {
name: tensor_model_config.name.clone(), name: tensor_model_config.name.clone(),
...@@ -523,9 +521,9 @@ impl GrpcInferenceService for KserveService { ...@@ -523,9 +521,9 @@ impl GrpcInferenceService for KserveService {
"Model '{}' has type Tensor but no model config is provided", "Model '{}' has type Tensor but no model config is provided",
request_model_name request_model_name
)))? )))?
} else if entry.model_type.supports_completions() { } else if card.model_type.supports_completions() {
let config = ModelConfig { let config = ModelConfig {
name: entry.name, name: card.display_name,
platform: "dynamo".to_string(), platform: "dynamo".to_string(),
backend: "dynamo".to_string(), backend: "dynamo".to_string(),
input: vec![ input: vec![
......
...@@ -52,7 +52,6 @@ async fn live_handler( ...@@ -52,7 +52,6 @@ async fn live_handler(
async fn health_handler( async fn health_handler(
axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>, axum::extract::State(state): axum::extract::State<Arc<service_v2::State>>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let model_entries = state.manager().get_model_entries();
let instances = if let Some(etcd_client) = state.etcd_client() { let instances = if let Some(etcd_client) = state.etcd_client() {
match list_all_instances(etcd_client).await { match list_all_instances(etcd_client).await {
Ok(instances) => instances, Ok(instances) => instances,
...@@ -65,10 +64,12 @@ async fn health_handler( ...@@ -65,10 +64,12 @@ async fn health_handler(
vec![] vec![]
}; };
let endpoints: Vec<String> = model_entries let mut endpoints: Vec<String> = instances
.iter() .iter()
.map(|entry| entry.endpoint_id.as_url()) .map(|instance| instance.endpoint_id().as_url())
.collect(); .collect();
endpoints.sort();
endpoints.dedup();
( (
StatusCode::OK, StatusCode::OK,
Json(json!({ Json(json!({
......
...@@ -18,12 +18,9 @@ use std::{ ...@@ -18,12 +18,9 @@ use std::{
time::{Duration, Instant}, time::{Duration, Instant},
}; };
use crate::discovery::ModelEntry;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_card::{ModelDeploymentCard, ROOT_PATH as MDC_ROOT_PATH}; use crate::model_card::ModelDeploymentCard;
use dynamo_runtime::metrics::prometheus_names::clamp_u64_to_i64; use dynamo_runtime::metrics::prometheus_names::clamp_u64_to_i64;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
pub use prometheus::Registry; pub use prometheus::Registry;
...@@ -472,60 +469,27 @@ impl Metrics { ...@@ -472,60 +469,27 @@ impl Metrics {
} }
} }
/// Update metrics from a ModelEntry and its ModelDeploymentCard /// Update metrics from a ModelDeploymentCard
/// This updates both runtime config metrics and MDC-specific metrics /// This updates both runtime config metrics and MDC-specific metrics
pub async fn update_metrics_from_model_entry_with_mdc( pub fn update_metrics_from_mdc(&self, card: &ModelDeploymentCard) -> anyhow::Result<()> {
&self, self.update_runtime_config_metrics(&card.display_name, &card.runtime_config);
model_entry: &ModelEntry,
etcd_client: &dynamo_runtime::transports::etcd::Client,
) -> anyhow::Result<()> {
// Update runtime config metrics
if let Some(runtime_config) = &model_entry.runtime_config {
self.update_runtime_config_metrics(&model_entry.name, runtime_config);
}
// Load and update MDC metrics self.model_context_length
let model_slug = Slug::from_string(&model_entry.name); .with_label_values(&[&card.display_name])
let store: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); .set(card.context_length as i64);
let card_store = Arc::new(KeyValueStoreManager::new(store));
match card_store self.model_kv_cache_block_size
.load::<ModelDeploymentCard>(MDC_ROOT_PATH, &model_slug) .with_label_values(&[&card.display_name])
.await .set(card.kv_cache_block_size as i64);
{
Ok(Some(mdc)) => { self.model_migration_limit
// Inline MDC metrics update .with_label_values(&[&card.display_name])
self.model_context_length .set(card.migration_limit as i64);
.with_label_values(&[&model_entry.name])
.set(mdc.context_length as i64); tracing::debug!(
model = %card.display_name,
self.model_kv_cache_block_size "Successfully updated MDC metrics"
.with_label_values(&[&model_entry.name]) );
.set(mdc.kv_cache_block_size as i64);
self.model_migration_limit
.with_label_values(&[&model_entry.name])
.set(mdc.migration_limit as i64);
tracing::debug!(
model = %model_entry.name,
"Successfully updated MDC metrics"
);
}
Ok(None) => {
tracing::debug!(
model = %model_entry.name,
"No MDC found in storage, skipping MDC metrics"
);
}
Err(e) => {
tracing::debug!(
model = %model_entry.name,
error = %e,
"Failed to load MDC for metrics update"
);
}
}
Ok(()) Ok(())
} }
......
...@@ -410,6 +410,8 @@ impl LocalModel { ...@@ -410,6 +410,8 @@ impl LocalModel {
let Some(etcd_client) = endpoint.drt().etcd_client() else { let Some(etcd_client) = endpoint.drt().etcd_client() else {
anyhow::bail!("Cannot attach to static endpoint"); anyhow::bail!("Cannot attach to static endpoint");
}; };
self.card.model_type = model_type;
self.card.model_input = model_input;
// Store model config files in NATS object store // Store model config files in NATS object store
let nats_client = endpoint.drt().nats_client(); let nats_client = endpoint.drt().nats_client();
...@@ -431,9 +433,7 @@ impl LocalModel { ...@@ -431,9 +433,7 @@ impl LocalModel {
let model_registration = ModelEntry { let model_registration = ModelEntry {
name: self.display_name().to_string(), name: self.display_name().to_string(),
endpoint_id: endpoint.id(), endpoint_id: endpoint.id(),
model_type,
runtime_config: Some(self.runtime_config.clone()), runtime_config: Some(self.runtime_config.clone()),
model_input,
}; };
etcd_client etcd_client
.kv_create( .kv_create(
......
...@@ -16,10 +16,10 @@ use std::fmt; ...@@ -16,10 +16,10 @@ use std::fmt;
use std::fs::File; use std::fs::File;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration;
use crate::common::checked_file::CheckedFile; use crate::common::checked_file::CheckedFile;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
...@@ -34,9 +34,6 @@ use crate::protocols::TokenIdType; ...@@ -34,9 +34,6 @@ use crate::protocols::TokenIdType;
/// Identify model deployment cards in the key-value store /// Identify model deployment cards in the key-value store
pub const ROOT_PATH: &str = "mdc"; pub const ROOT_PATH: &str = "mdc";
/// If a model deployment card hasn't been refreshed in this much time the worker is likely gone
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum ModelInfoType { pub enum ModelInfoType {
...@@ -118,9 +115,6 @@ pub struct ModelDeploymentCard { ...@@ -118,9 +115,6 @@ pub struct ModelDeploymentCard {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_context: Option<Vec<PromptContextMixin>>, pub prompt_context: Option<Vec<PromptContextMixin>>,
/// When this card was last advertised by a worker. None if not yet published.
pub last_published: Option<chrono::DateTime<chrono::Utc>>,
/// Max context (in number of tokens) this model can handle /// Max context (in number of tokens) this model can handle
pub context_length: u32, pub context_length: u32,
...@@ -132,6 +126,14 @@ pub struct ModelDeploymentCard { ...@@ -132,6 +126,14 @@ pub struct ModelDeploymentCard {
/// connection to the current worker. /// connection to the current worker.
pub migration_limit: u32, pub migration_limit: u32,
/// Specifies whether the model is a chat, completions, etc model.
pub model_type: ModelType,
/// Specifies the model input type.
/// `Tokens` for engines that expect pre-processed input.
/// `Text` for engines that take care of pre-processing themselves.
pub model_input: ModelInput,
/// User-defined metadata for custom worker behavior /// User-defined metadata for custom worker behavior
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub user_data: Option<serde_json::Value>, pub user_data: Option<serde_json::Value>,
...@@ -161,17 +163,6 @@ impl ModelDeploymentCard { ...@@ -161,17 +163,6 @@ impl ModelDeploymentCard {
} }
} }
/// How often we should check if a model deployment card expired because it's workers are gone
pub fn expiry_check_period() -> Duration {
match CARD_MAX_AGE.to_std() {
Ok(duration) => duration / 3,
Err(_) => {
// Only happens if CARD_MAX_AGE is negative, which it isn't
unreachable!("Cannot run card expiry watcher, invalid CARD_MAX_AGE");
}
}
}
/// Load a model deployment card from a JSON file /// Load a model deployment card from a JSON file
pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> { pub fn load_from_json_file<P: AsRef<Path>>(file: P) -> std::io::Result<Self> {
let contents = std::fs::read_to_string(&file)?; let contents = std::fs::read_to_string(&file)?;
...@@ -210,15 +201,6 @@ impl ModelDeploymentCard { ...@@ -210,15 +201,6 @@ impl ModelDeploymentCard {
format!("{}", blake3::hash(json.as_bytes())) format!("{}", blake3::hash(json.as_bytes()))
} }
/// Was this card last published a long time ago, suggesting the worker is gone?
pub fn is_expired(&self) -> bool {
if let Some(last_published) = self.last_published.as_ref() {
chrono::Utc::now() - last_published > CARD_MAX_AGE
} else {
false
}
}
/// Is this a full model card with tokenizer? /// Is this a full model card with tokenizer?
/// There are cases where we have a placeholder card (see `with_name_only`). /// There are cases where we have a placeholder card (see `with_name_only`).
pub fn has_tokenizer(&self) -> bool { pub fn has_tokenizer(&self) -> bool {
...@@ -405,6 +387,10 @@ impl ModelDeploymentCard { ...@@ -405,6 +387,10 @@ impl ModelDeploymentCard {
} }
} }
pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_input, ModelInput::Tokens)
}
/// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use. /// Load a ModelDeploymentCard from storage the DistributedRuntime is configured to use.
/// Card should be fully local and ready to use when the call returns. /// Card should be fully local and ready to use when the call returns.
pub async fn load_from_store( pub async fn load_from_store(
...@@ -491,10 +477,11 @@ impl ModelDeploymentCard { ...@@ -491,10 +477,11 @@ impl ModelDeploymentCard {
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())), prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
chat_template_file: None, chat_template_file: None,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
last_published: None,
context_length, context_length,
kv_cache_block_size: 0, kv_cache_block_size: 0,
migration_limit: 0, migration_limit: 0,
model_type: Default::default(), // set later
model_input: Default::default(), // set later
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
...@@ -554,10 +541,11 @@ impl ModelDeploymentCard { ...@@ -554,10 +541,11 @@ impl ModelDeploymentCard {
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?, prompt_formatter: PromptFormatterArtifact::from_repo(repo_id)?,
chat_template_file, chat_template_file,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
last_published: None,
context_length, context_length,
kv_cache_block_size: 0, // set later kv_cache_block_size: 0, // set later
migration_limit: 0, migration_limit: 0,
model_type: Default::default(), // set later
model_input: Default::default(), // set later
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
...@@ -565,15 +553,19 @@ impl ModelDeploymentCard { ...@@ -565,15 +553,19 @@ impl ModelDeploymentCard {
} }
} }
impl PartialEq for ModelDeploymentCard {
fn eq(&self, other: &ModelDeploymentCard) -> bool {
self.mdcsum() == other.mdcsum()
}
}
/// A ModelDeploymentCard is published a single time per instance and never updated. /// A ModelDeploymentCard is published a single time per instance and never updated.
impl Versioned for ModelDeploymentCard { impl Versioned for ModelDeploymentCard {
fn revision(&self) -> u64 { fn revision(&self) -> u64 {
0 0
} }
fn set_revision(&mut self, _revision: u64) { fn set_revision(&mut self, _revision: u64) {}
self.last_published = Some(chrono::Utc::now());
}
} }
impl fmt::Display for ModelDeploymentCard { impl fmt::Display for ModelDeploymentCard {
......
...@@ -30,7 +30,7 @@ bitflags! { ...@@ -30,7 +30,7 @@ bitflags! {
/// Using bitflags avoids deep branching on a single enum variant, /// Using bitflags avoids deep branching on a single enum variant,
/// simplifies checks like `supports_chat()`, and enables efficient, /// simplifies checks like `supports_chat()`, and enables efficient,
/// type-safe combinations of multiple endpoint types within a single byte. /// type-safe combinations of multiple endpoint types within a single byte.
#[derive(Copy, Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Copy, Debug, Default, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelType: u8 { pub struct ModelType: u8 {
const Chat = 1 << 0; const Chat = 1 << 0;
const Completions = 1 << 1; const Completions = 1 << 1;
...@@ -100,9 +100,10 @@ impl fmt::Display for ModelType { ...@@ -100,9 +100,10 @@ impl fmt::Display for ModelType {
} }
} }
#[derive(Copy, Debug, Clone, Display, Serialize, Deserialize, Eq, PartialEq)] #[derive(Copy, Debug, Default, Clone, Display, Serialize, Deserialize, Eq, PartialEq)]
pub enum ModelInput { pub enum ModelInput {
/// Raw text input /// Raw text input
#[default]
Text, Text,
/// Pre-processed input /// Pre-processed input
Tokens, Tokens,
......
...@@ -293,12 +293,13 @@ async fn test_metrics_with_mock_model() { ...@@ -293,12 +293,13 @@ async fn test_metrics_with_mock_model() {
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use dynamo_llm::{ use dynamo_llm::{
discovery::{ModelEntry, ModelWatcher}, discovery::{MODEL_ROOT_PATH, ModelEntry, ModelWatcher},
engines::make_echo_engine, engines::make_echo_engine,
entrypoint::EngineConfig, entrypoint::EngineConfig,
local_model::LocalModelBuilder, local_model::LocalModelBuilder,
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
use std::sync::Arc; use std::sync::Arc;
#[tokio::test] #[tokio::test]
...@@ -335,8 +336,6 @@ mod integration_tests { ...@@ -335,8 +336,6 @@ mod integration_tests {
// Set up model watcher to discover models from etcd (like production) // Set up model watcher to discover models from etcd (like production)
// This is crucial for the polling task to find model entries // This is crucial for the polling task to find model entries
use dynamo_llm::discovery::{MODEL_ROOT_PATH, ModelWatcher};
use dynamo_runtime::pipeline::RouterMode;
let model_watcher = ModelWatcher::new( let model_watcher = ModelWatcher::new(
distributed_runtime.clone(), distributed_runtime.clone(),
...@@ -531,7 +530,7 @@ mod integration_tests { ...@@ -531,7 +530,7 @@ mod integration_tests {
if let Some(key) = key { if let Some(key) = key {
// Remove from ModelManager first (this returns the ModelEntry) // Remove from ModelManager first (this returns the ModelEntry)
if let Some(_removed_entry) = manager.remove_model_entry(&key) { if let Some(_removed_card) = manager.remove_model_card(&key) {
// Remove engines (following ModelWatcher::handle_delete pattern) // Remove engines (following ModelWatcher::handle_delete pattern)
manager manager
.remove_chat_completions_model(&model_entry.name) .remove_chat_completions_model(&model_entry.name)
......
...@@ -4,109 +4,18 @@ ...@@ -4,109 +4,18 @@
//! Integration tests for HTTP service namespace discovery functionality. //! Integration tests for HTTP service namespace discovery functionality.
//! These tests verify that the HTTP service correctly filters models based on namespace configuration. //! These tests verify that the HTTP service correctly filters models based on namespace configuration.
use dynamo_llm::{ use dynamo_llm::namespace::{GLOBAL_NAMESPACE, is_global_namespace};
discovery::ModelEntry,
model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
// Helper function to create a test ModelEntry // Helper function to create a test ModelDeploymentCard
fn create_test_model_entry( fn create_test_endpoint(namespace: &str, component: &str, endpoint_name: &str) -> EndpointId {
name: &str, EndpointId {
namespace: &str, namespace: namespace.to_string(),
component: &str, component: component.to_string(),
endpoint_name: &str, name: endpoint_name.to_string(),
model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
endpoint_id: EndpointId {
namespace: namespace.to_string(),
component: component.to_string(),
name: endpoint_name.to_string(),
},
model_type,
model_input,
runtime_config: None,
} }
} }
#[test]
fn test_namespace_filtering_behavior() {
// Test the core namespace filtering logic used in HTTP service
let test_models = vec![
create_test_model_entry(
"model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-2",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Completions,
ModelInput::Tokens,
),
create_test_model_entry(
"model-4",
"tensorrt-llm",
"backend",
"generate",
ModelType::Embedding,
ModelInput::Tokens,
),
];
// Test filtering for specific namespace "vllm-agg"
let target_namespace = "vllm-agg";
let is_global = is_global_namespace(target_namespace);
let filtered_models: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models.len(), 1);
assert_eq!(filtered_models[0].name, "model-1");
assert_eq!(filtered_models[0].endpoint_id.namespace, "vllm-agg");
// Test filtering for global namespace (should include all models)
let target_namespace = GLOBAL_NAMESPACE;
let is_global = is_global_namespace(target_namespace);
let filtered_models_global: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models_global.len(), 4); // All models should be included
// Test filtering for empty namespace (treated as global)
let target_namespace = "";
let is_global = is_global_namespace(target_namespace);
let filtered_models_empty: Vec<&ModelEntry> = test_models
.iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_models_empty.len(), 4); // All models should be included
}
#[test] #[test]
fn test_endpoint_id_namespace_extraction() { fn test_endpoint_id_namespace_extraction() {
// Test endpoint ID parsing for different namespace formats // Test endpoint ID parsing for different namespace formats
...@@ -165,62 +74,30 @@ fn test_model_discovery_scoping_scenarios() { ...@@ -165,62 +74,30 @@ fn test_model_discovery_scoping_scenarios() {
// Scenario 1: Frontend configured for specific namespace should only see models from that namespace // Scenario 1: Frontend configured for specific namespace should only see models from that namespace
let frontend_namespace = "vllm-agg"; let frontend_namespace = "vllm-agg";
let available_models = vec![ let available_models = vec![
create_test_model_entry( create_test_endpoint("vllm-agg", "backend", "generate"),
"llama-7b", create_test_endpoint("vllm-agg", "backend", "generate"),
"vllm-agg", create_test_endpoint("sglang-prod", "backend", "generate"),
"backend", create_test_endpoint("dynamo", "backend", "generate"),
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"mistral-7b",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"gpt-3.5",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"claude-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
]; ];
let visible_models: Vec<&ModelEntry> = available_models let visible_models: Vec<&EndpointId> = available_models
.iter() .iter()
.filter(|model| { .filter(|endpoint| {
let is_global = is_global_namespace(frontend_namespace); let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace is_global || endpoint.namespace == frontend_namespace
}) })
.collect(); .collect();
assert_eq!(visible_models.len(), 2); assert_eq!(visible_models.len(), 2);
assert!( assert!(visible_models.iter().all(|m| m.namespace == "vllm-agg"));
visible_models
.iter()
.all(|m| m.endpoint_id.namespace == "vllm-agg")
);
// Scenario 2: Frontend configured for global namespace should see all models // Scenario 2: Frontend configured for global namespace should see all models
let frontend_namespace = GLOBAL_NAMESPACE; let frontend_namespace = GLOBAL_NAMESPACE;
let visible_models_global: Vec<&ModelEntry> = available_models let visible_models_global: Vec<&EndpointId> = available_models
.iter() .iter()
.filter(|model| { .filter(|endpoint| {
let is_global = is_global_namespace(frontend_namespace); let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace is_global || endpoint.namespace == frontend_namespace
}) })
.collect(); .collect();
...@@ -228,11 +105,11 @@ fn test_model_discovery_scoping_scenarios() { ...@@ -228,11 +105,11 @@ fn test_model_discovery_scoping_scenarios() {
// Scenario 3: Frontend configured for non-existent namespace should see no models // Scenario 3: Frontend configured for non-existent namespace should see no models
let frontend_namespace = "non-existent-namespace"; let frontend_namespace = "non-existent-namespace";
let visible_models_none: Vec<&ModelEntry> = available_models let visible_models_none: Vec<&EndpointId> = available_models
.iter() .iter()
.filter(|model| { .filter(|endpoint| {
let is_global = is_global_namespace(frontend_namespace); let is_global = is_global_namespace(frontend_namespace);
is_global || model.endpoint_id.namespace == frontend_namespace is_global || endpoint.namespace == frontend_namespace
}) })
.collect(); .collect();
...@@ -244,30 +121,9 @@ fn test_namespace_boundary_conditions() { ...@@ -244,30 +121,9 @@ fn test_namespace_boundary_conditions() {
// Test edge cases and boundary conditions for namespace handling // Test edge cases and boundary conditions for namespace handling
let test_models = vec![ let test_models = vec![
create_test_model_entry( create_test_endpoint("", "backend", "generate"), // Empty namespace
"model-1", create_test_endpoint("dynamo", "backend", "generate"), // Global namespace
"", create_test_endpoint("ns-with-special-chars_123", "backend", "generate"),
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Empty namespace
create_test_model_entry(
"model-2",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
), // Global namespace
create_test_model_entry(
"model-3",
"ns-with-special-chars_123",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
]; ];
// Test filtering with empty target namespace (should be treated as global) // Test filtering with empty target namespace (should be treated as global)
...@@ -275,9 +131,9 @@ fn test_namespace_boundary_conditions() { ...@@ -275,9 +131,9 @@ fn test_namespace_boundary_conditions() {
let is_global = is_global_namespace(target_namespace); let is_global = is_global_namespace(target_namespace);
assert!(is_global); // Empty namespace should be treated as global assert!(is_global); // Empty namespace should be treated as global
let filtered_empty: Vec<&ModelEntry> = test_models let filtered_empty: Vec<&EndpointId> = test_models
.iter() .iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace) .filter(|model| is_global || model.namespace == target_namespace)
.collect(); .collect();
assert_eq!(filtered_empty.len(), 3); // All models should be visible assert_eq!(filtered_empty.len(), 3); // All models should be visible
...@@ -287,9 +143,9 @@ fn test_namespace_boundary_conditions() { ...@@ -287,9 +143,9 @@ fn test_namespace_boundary_conditions() {
let is_global = is_global_namespace(target_namespace); let is_global = is_global_namespace(target_namespace);
assert!(is_global); assert!(is_global);
let filtered_global: Vec<&ModelEntry> = test_models let filtered_global: Vec<&EndpointId> = test_models
.iter() .iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace) .filter(|model| is_global || model.namespace == target_namespace)
.collect(); .collect();
assert_eq!(filtered_global.len(), 3); // All models should be visible assert_eq!(filtered_global.len(), 3); // All models should be visible
...@@ -299,9 +155,9 @@ fn test_namespace_boundary_conditions() { ...@@ -299,9 +155,9 @@ fn test_namespace_boundary_conditions() {
let is_global = is_global_namespace(target_namespace); let is_global = is_global_namespace(target_namespace);
assert!(!is_global); // Should be case-sensitive assert!(!is_global); // Should be case-sensitive
let filtered_uppercase: Vec<&ModelEntry> = test_models let filtered_uppercase: Vec<&EndpointId> = test_models
.iter() .iter()
.filter(|model| is_global || model.endpoint_id.namespace == target_namespace) .filter(|model| is_global || model.namespace == target_namespace)
.collect(); .collect();
assert_eq!(filtered_uppercase.len(), 0); // No models should be visible assert_eq!(filtered_uppercase.len(), 0); // No models should be visible
......
...@@ -6,11 +6,10 @@ pub mod kserve_test { ...@@ -6,11 +6,10 @@ pub mod kserve_test {
pub mod inference { pub mod inference {
tonic::include_proto!("inference"); tonic::include_proto!("inference");
} }
use dynamo_llm::discovery::ModelEntry;
use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig; use dynamo_llm::local_model::runtime_config::ModelRuntimeConfig;
use dynamo_llm::model_card::ModelDeploymentCard;
use dynamo_llm::model_type::{ModelInput, ModelType}; use dynamo_llm::model_type::{ModelInput, ModelType};
use dynamo_llm::protocols::tensor; use dynamo_llm::protocols::tensor;
use dynamo_runtime::protocols::EndpointId;
use inference::grpc_inference_service_client::GrpcInferenceServiceClient; use inference::grpc_inference_service_client::GrpcInferenceServiceClient;
use inference::{ use inference::{
DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest, DataType, ModelConfigRequest, ModelInferRequest, ModelInferResponse, ModelMetadataRequest,
...@@ -284,20 +283,10 @@ pub mod kserve_test { ...@@ -284,20 +283,10 @@ pub mod kserve_test {
manager manager
.add_completions_model("split", split.clone()) .add_completions_model("split", split.clone())
.unwrap(); .unwrap();
manager.save_model_entry( let mut card = ModelDeploymentCard::with_name_only("split");
"split", card.model_type = ModelType::Completions;
ModelEntry { card.model_input = ModelInput::Text;
name: "split".to_string(), manager.save_model_card("split", card);
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "split".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager manager
.add_chat_completions_model("failure", failure.clone()) .add_chat_completions_model("failure", failure.clone())
...@@ -305,37 +294,17 @@ pub mod kserve_test { ...@@ -305,37 +294,17 @@ pub mod kserve_test {
manager manager
.add_completions_model("failure", failure.clone()) .add_completions_model("failure", failure.clone())
.unwrap(); .unwrap();
manager.save_model_entry( let mut card = ModelDeploymentCard::with_name_only("failure");
"failure", card.model_type = ModelType::Completions | ModelType::Chat;
ModelEntry { card.model_input = ModelInput::Text;
name: "failure".to_string(), manager.save_model_card("failure", card);
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "failure".to_string(),
},
model_type: ModelType::Completions | ModelType::Chat,
model_input: ModelInput::Text,
runtime_config: None,
},
);
manager manager
.add_completions_model("long_running", long_running.clone()) .add_completions_model("long_running", long_running.clone())
.unwrap(); .unwrap();
manager.save_model_entry( let mut card = ModelDeploymentCard::with_name_only("long_running");
"long_running", card.model_type = ModelType::Completions;
ModelEntry { card.model_input = ModelInput::Text;
name: "long_running".to_string(), manager.save_model_card("long_running", card);
endpoint_id: EndpointId {
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "long_running".to_string(),
},
model_type: ModelType::Completions,
model_input: ModelInput::Text,
runtime_config: None,
},
);
(service, split, failure, long_running) (service, split, failure, long_running)
} }
...@@ -1179,21 +1148,13 @@ pub mod kserve_test { ...@@ -1179,21 +1148,13 @@ pub mod kserve_test {
}); });
// Failure, model registered as Tensor but does not provide model config (in runtime config) // Failure, model registered as Tensor but does not provide model config (in runtime config)
let entry = ModelEntry { let mut card = ModelDeploymentCard::with_name_only("tensor");
name: "tensor".to_string(), card.model_type = ModelType::TensorBased;
endpoint_id: EndpointId { card.model_input = ModelInput::Tensor;
namespace: "namespace".to_string(),
component: "component".to_string(),
name: "endpoint".to_string(),
},
model_type: ModelType::TensorBased,
model_input: ModelInput::Tensor,
runtime_config: None,
};
service_with_engines service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_entry("key", entry); .save_model_card("key", card);
let response = client.model_metadata(request).await; let response = client.model_metadata(request).await;
assert!(response.is_err()); assert!(response.is_err());
...@@ -1236,37 +1197,30 @@ pub mod kserve_test { ...@@ -1236,37 +1197,30 @@ pub mod kserve_test {
service_with_engines service_with_engines
.0 .0
.model_manager() .model_manager()
.remove_model_entry("key"); .remove_model_card("key");
let entry = ModelEntry { let mut card = ModelDeploymentCard::with_name_only("tensor");
name: "tensor".to_string(), card.model_type = ModelType::TensorBased;
endpoint_id: EndpointId { card.model_input = ModelInput::Tensor;
namespace: "namespace".to_string(), card.runtime_config = ModelRuntimeConfig {
component: "component".to_string(), tensor_model_config: Some(tensor::TensorModelConfig {
name: "endpoint".to_string(), name: "tensor".to_string(),
}, inputs: vec![tensor::TensorMetadata {
model_type: ModelType::TensorBased, name: "input".to_string(),
model_input: ModelInput::Tensor, data_type: tensor::DataType::Bytes,
runtime_config: Some(ModelRuntimeConfig { shape: vec![1],
tensor_model_config: Some(tensor::TensorModelConfig { }],
name: "tensor".to_string(), outputs: vec![tensor::TensorMetadata {
inputs: vec![tensor::TensorMetadata { name: "output".to_string(),
name: "input".to_string(), data_type: tensor::DataType::Bool,
data_type: tensor::DataType::Bytes, shape: vec![-1],
shape: vec![1], }],
}],
outputs: vec![tensor::TensorMetadata {
name: "output".to_string(),
data_type: tensor::DataType::Bool,
shape: vec![-1],
}],
}),
..Default::default()
}), }),
..Default::default()
}; };
service_with_engines service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_entry("key", entry); .save_model_card("key", card);
// Success // Success
let request = tonic::Request::new(ModelMetadataRequest { let request = tonic::Request::new(ModelMetadataRequest {
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use dynamo_llm::{
discovery::ModelEntry,
model_type::{ModelInput, ModelType},
namespace::{GLOBAL_NAMESPACE, is_global_namespace},
};
use dynamo_runtime::protocols::EndpointId; use dynamo_runtime::protocols::EndpointId;
#[test]
fn test_is_global_namespace_with_global_string() {
assert!(is_global_namespace(GLOBAL_NAMESPACE));
assert!(is_global_namespace("dynamo"));
}
#[test]
fn test_is_global_namespace_with_empty_string() {
assert!(is_global_namespace(""));
}
#[test]
fn test_is_global_namespace_with_specific_namespace() {
assert!(!is_global_namespace("test-namespace"));
assert!(!is_global_namespace("my-custom-namespace"));
}
#[test]
fn test_is_global_namespace_with_whitespace() {
// Whitespace should not be considered global
assert!(!is_global_namespace(" "));
assert!(!is_global_namespace(" "));
assert!(!is_global_namespace("\t"));
assert!(!is_global_namespace("\n"));
}
#[test]
fn test_is_global_namespace_case_sensitivity() {
// Should be case sensitive
assert!(!is_global_namespace("Dynamo"));
assert!(!is_global_namespace("DYNAMO"));
}
#[test]
fn test_global_namespace_constant() {
assert_eq!(GLOBAL_NAMESPACE, "dynamo");
}
// Helper function to create a test ModelEntry
fn create_test_model_entry(
name: &str,
namespace: &str,
component: &str,
endpoint_name: &str,
model_type: ModelType,
model_input: ModelInput,
) -> ModelEntry {
ModelEntry {
name: name.to_string(),
endpoint_id: EndpointId {
namespace: namespace.to_string(),
component: component.to_string(),
name: endpoint_name.to_string(),
},
model_type,
model_input,
runtime_config: None,
}
}
#[test]
fn test_model_entry_creation_with_different_namespaces() {
// Test creating ModelEntry with specific namespace
let model_vllm = create_test_model_entry(
"test-model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
);
assert_eq!(model_vllm.name, "test-model-1");
assert_eq!(model_vllm.endpoint_id.namespace, "vllm-agg");
assert_eq!(model_vllm.endpoint_id.component, "backend");
assert_eq!(model_vllm.endpoint_id.name, "generate");
assert_eq!(model_vllm.model_type, ModelType::Chat);
assert_eq!(model_vllm.model_input, ModelInput::Tokens);
// Test creating ModelEntry with global namespace
let model_global = create_test_model_entry(
"test-model-2",
"dynamo",
"frontend",
"http",
ModelType::Completions,
ModelInput::Text,
);
assert_eq!(model_global.name, "test-model-2");
assert_eq!(model_global.endpoint_id.namespace, "dynamo");
assert_eq!(model_global.endpoint_id.component, "frontend");
assert_eq!(model_global.endpoint_id.name, "http");
assert_eq!(model_global.model_type, ModelType::Completions);
assert_eq!(model_global.model_input, ModelInput::Text);
}
#[test]
fn test_namespace_filtering_logic() {
// Test the core logic that would be used in namespace filtering
let models = vec![
create_test_model_entry(
"model-1",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-2",
"sglang-prod",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-3",
"dynamo",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
create_test_model_entry(
"model-4",
"",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
),
];
// Test filtering for specific namespace "vllm-agg"
let target_namespace = "vllm-agg";
let global_namespace = is_global_namespace(target_namespace);
let filtered_vllm: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_vllm.len(), 1);
assert_eq!(filtered_vllm[0].name, "model-1");
assert_eq!(filtered_vllm[0].endpoint_id.namespace, "vllm-agg");
// Test filtering for global namespace (should include all)
let target_namespace = "dynamo";
let global_namespace = is_global_namespace(target_namespace);
let filtered_global: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_global.len(), 4); // All models should be included
// Test filtering for empty namespace (should include all, treated as global)
let target_namespace = "";
let global_namespace = is_global_namespace(target_namespace);
let filtered_empty: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_empty.len(), 4); // All models should be included
// Test filtering for non-existent namespace
let target_namespace = "non-existent";
let global_namespace = is_global_namespace(target_namespace);
let filtered_none: Vec<&ModelEntry> = models
.iter()
.filter(|model| global_namespace || model.endpoint_id.namespace == target_namespace)
.collect();
assert_eq!(filtered_none.len(), 0); // No models should match
}
#[test]
fn test_model_entry_serialization() {
// Test that ModelEntry can be serialized and deserialized (important for etcd storage)
let model = create_test_model_entry(
"test-model",
"vllm-agg",
"backend",
"generate",
ModelType::Chat,
ModelInput::Tokens,
);
// Serialize to JSON
let json = serde_json::to_string(&model).expect("Failed to serialize ModelEntry");
assert!(json.contains("test-model"));
assert!(json.contains("vllm-agg"));
assert!(json.contains("backend"));
assert!(json.contains("generate"));
// Deserialize from JSON
let deserialized: ModelEntry =
serde_json::from_str(&json).expect("Failed to deserialize ModelEntry");
assert_eq!(deserialized.name, model.name);
assert_eq!(
deserialized.endpoint_id.namespace,
model.endpoint_id.namespace
);
assert_eq!(
deserialized.endpoint_id.component,
model.endpoint_id.component
);
assert_eq!(deserialized.endpoint_id.name, model.endpoint_id.name);
assert_eq!(deserialized.model_type, model.model_type);
assert_eq!(deserialized.model_input, model.model_input);
}
#[test] #[test]
fn test_endpoint_namespace_parsing() { fn test_endpoint_namespace_parsing() {
// Test Endpoint creation from string with namespace // Test Endpoint creation from string with namespace
......
...@@ -687,12 +687,14 @@ dependencies = [ ...@@ -687,12 +687,14 @@ dependencies = [
"once_cell", "once_cell",
"prometheus", "prometheus",
"rand 0.9.1", "rand 0.9.1",
"rayon",
"regex", "regex",
"serde", "serde",
"serde_json", "serde_json",
"socket2", "socket2",
"thiserror 2.0.12", "thiserror 2.0.12",
"tokio", "tokio",
"tokio-rayon",
"tokio-stream", "tokio-stream",
"tokio-util", "tokio-util",
"tower-http", "tower-http",
...@@ -2848,6 +2850,16 @@ dependencies = [ ...@@ -2848,6 +2850,16 @@ dependencies = [
"syn 2.0.100", "syn 2.0.100",
] ]
[[package]]
name = "tokio-rayon"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7cf33a76e0b1dd03b778f83244137bd59887abf25c0e87bc3e7071105f457693"
dependencies = [
"rayon",
"tokio",
]
[[package]] [[package]]
name = "tokio-rustls" name = "tokio-rustls"
version = "0.26.2" version = "0.26.2"
......
...@@ -107,6 +107,13 @@ impl Instance { ...@@ -107,6 +107,13 @@ impl Instance {
pub fn id(&self) -> i64 { pub fn id(&self) -> i64 {
self.instance_id self.instance_id
} }
pub fn endpoint_id(&self) -> EndpointId {
EndpointId {
namespace: self.namespace.clone(),
component: self.component.clone(),
name: self.endpoint.clone(),
}
}
} }
/// A [Component] a discoverable entity in the distributed runtime. /// A [Component] a discoverable entity in the distributed runtime.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment