Unverified Commit 81162dfe authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(discovery): Watch/publish ModelDeploymentCard instead of ModelEntry (#3350)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent ddbb4f50
...@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime):
host: str = "localhost" host: str = "localhost"
port: int = 8000 port: int = 8000
service: HttpService = HttpService(port=port) service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine) service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...") print("Starting service...")
shutdown_signal = service.run(runtime.child_token()) shutdown_signal = service.run(runtime.child_token())
......
...@@ -30,23 +30,29 @@ impl HttpService { ...@@ -30,23 +30,29 @@ impl HttpService {
Ok(Self { inner }) Ok(Self { inner })
} }
pub fn add_completions_model(&self, model: String, engine: HttpAsyncEngine) -> PyResult<()> { pub fn add_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine); let engine = Arc::new(engine);
self.inner self.inner
.model_manager() .model_manager()
.add_completions_model(&model, engine) .add_completions_model(&model, &checksum, engine)
.map_err(to_pyerr) .map_err(to_pyerr)
} }
pub fn add_chat_completions_model( pub fn add_chat_completions_model(
&self, &self,
model: String, model: String,
checksum: String,
engine: HttpAsyncEngine, engine: HttpAsyncEngine,
) -> PyResult<()> { ) -> PyResult<()> {
let engine = Arc::new(engine); let engine = Arc::new(engine);
self.inner self.inner
.model_manager() .model_manager()
.add_chat_completions_model(&model, engine) .add_chat_completions_model(&model, &checksum, engine)
.map_err(to_pyerr) .map_err(to_pyerr)
} }
......
...@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime): ...@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime):
model_name = "test_model" model_name = "test_model"
start_done = asyncio.Event() start_done = asyncio.Event()
child_token = runtime.child_token() child_token = runtime.child_token()
checksum = "abc123" # Checksum of the ModelDeploymentCard for that model
async def worker(): async def worker():
"""The server worker task.""" """The server worker task."""
...@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime): ...@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime):
engine = HttpAsyncEngine(python_engine.generate, loop) engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port) service = HttpService(port=port)
service.add_chat_completions_model(model_name, engine) service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True) service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token) shutdown_signal = service.run(child_token)
......
...@@ -33,7 +33,7 @@ pub struct Checksum { ...@@ -33,7 +33,7 @@ pub struct Checksum {
algorithm: CryptographicHashMethods, algorithm: CryptographicHashMethods,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
pub enum CryptographicHashMethods { pub enum CryptographicHashMethods {
#[serde(rename = "blake3")] #[serde(rename = "blake3")]
BLAKE3, BLAKE3,
...@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum { ...@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum {
} }
} }
impl Default for Checksum {
    /// Builds a checksum with an empty hash string and the BLAKE3 algorithm tag.
    fn default() -> Self {
        let hash = String::new();
        let algorithm = CryptographicHashMethods::BLAKE3;
        Self { hash, algorithm }
    }
}
impl FromStr for CryptographicHashMethods { impl FromStr for CryptographicHashMethods {
type Err = String; type Err = String;
......
...@@ -4,14 +4,8 @@ ...@@ -4,14 +4,8 @@
mod model_manager; mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError}; pub use model_manager::{ModelManager, ModelManagerError};
mod model_entry;
pub use model_entry::ModelEntry;
mod watcher; mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher}; pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations /// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers"; pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize};
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// [ModelEntry] contains the information to discover models.
///
/// Derives `Serialize`/`Deserialize`, so the field-level serde attributes below
/// determine its serialized shape.
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
/// Public name of the model.
/// Used to identify the model in the HTTP service from the value used in an OpenAI ChatRequest.
pub name: String,
/// How to address this on the network.
/// Serialized under the key `endpoint` rather than `endpoint_id`.
#[serde(rename = "endpoint")]
pub endpoint_id: protocols::EndpointId,
/// Runtime configuration specific to this model instance.
/// Left out of the serialized form when `None`; takes the `Default` value
/// (`None`) when the key is missing on deserialization.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
}
impl ModelEntry {
    /// Slug built from the model's public `name`, for use in network storage or
    /// other URL-safe contexts.
    pub fn slug(&self) -> Slug {
        let display_name = &self.name;
        Slug::from_string(display_name)
    }
}
...@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock}; ...@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard}; use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{ use crate::{
kv_router::KvRouter, kv_router::KvRouter,
...@@ -21,6 +20,10 @@ use crate::{ ...@@ -21,6 +20,10 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
}, },
}; };
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
model_type::ModelType,
};
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ModelManagerError { pub enum ModelManagerError {
...@@ -39,7 +42,7 @@ pub struct ModelManager { ...@@ -39,7 +42,7 @@ pub struct ModelManager {
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally // These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
} }
...@@ -62,6 +65,43 @@ impl ModelManager { ...@@ -62,6 +65,43 @@ impl ModelManager {
} }
} }
/// Checks `candidate_checksum` against the checksum stored for `model_name`
/// in each engine set selected by `model_type`.
///
/// Every unit yielded by `model_type.units()` is visited. Chat, Completions,
/// Embedding and TensorBased units are looked up in their matching engine set;
/// any other unit is ignored. Each comparison is logged at debug level.
///
/// Returns `None` when no engine set holds a checksum for the model, otherwise
/// `Some(true)` only if every stored checksum equals the candidate.
pub fn is_valid_checksum(
    &self,
    model_type: ModelType,
    model_name: &str,
    candidate_checksum: &str,
) -> Option<bool> {
    let mut compared_any = false;
    let mut all_match = true;
    for unit in model_type.units() {
        let registered = match unit {
            ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
            ModelType::Completions => self.completion_engines.read().checksum(model_name),
            ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
            ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
            _ => continue,
        };
        // No checksum stored for this unit: it does not take part in the verdict.
        let Some(valid_checksum) = registered else {
            continue;
        };
        tracing::debug!(
            model_name,
            valid_checksum,
            candidate_checksum,
            "is_valid_checksum: check case"
        );
        compared_any = true;
        // No early exit on a mismatch, so every unit is still looked up and logged.
        all_match &= valid_checksum == candidate_checksum;
    }
    // The checksum is valid if it is correct for all the ModelType in the bitflag.
    compared_any.then_some(all_match)
}
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.lock().values().cloned().collect() self.cards.lock().values().cloned().collect()
} }
...@@ -99,37 +139,41 @@ impl ModelManager { ...@@ -99,37 +139,41 @@ impl ModelManager {
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine, engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write(); let mut clients = self.completion_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_chat_completions_model( pub fn add_chat_completions_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write(); let mut clients = self.chat_completion_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_embeddings_model( pub fn add_embeddings_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine, engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write(); let mut clients = self.embeddings_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_tensor_model( pub fn add_tensor_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: TensorStreamingEngine, engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write(); let mut clients = self.tensor_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
...@@ -196,10 +240,11 @@ impl ModelManager { ...@@ -196,10 +240,11 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted from etcd. /// deleted.
pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), entry); self.cards.lock().insert(key.to_string(), card);
Ok(())
} }
/// Remove and return model card for this instance's etcd key. We do this when the instance stops. /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
...@@ -291,6 +336,9 @@ pub struct ModelEngines<E> { ...@@ -291,6 +336,9 @@ pub struct ModelEngines<E> {
/// Optional default model name /// Optional default model name
default: Option<String>, default: Option<String>,
engines: HashMap<String, E>, engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
} }
impl<E> Default for ModelEngines<E> { impl<E> Default for ModelEngines<E> {
...@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> { ...@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> {
Self { Self {
default: None, default: None,
engines: HashMap::new(), engines: HashMap::new(),
checksums: HashMap::new(),
} }
} }
} }
...@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> { ...@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> {
self.default = None; self.default = None;
} }
fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> { fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
if self.engines.contains_key(model) { if self.engines.contains_key(model) {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
} }
self.engines.insert(model.to_string(), engine); self.engines.insert(model.to_string(), engine);
self.checksums
.insert(model.to_string(), checksum.to_string());
Ok(()) Ok(())
} }
...@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> { ...@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> {
if self.engines.remove(model).is_none() { if self.engines.remove(model).is_none() {
return Err(ModelManagerError::ModelNotFound(model.to_string())); return Err(ModelManagerError::ModelNotFound(model.to_string()));
} }
let _ = self.checksums.remove(model);
Ok(()) Ok(())
} }
...@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> { ...@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> {
pub fn list(&self) -> Vec<String> { pub fn list(&self) -> Vec<String> {
self.engines.keys().map(|k| k.to_owned()).collect() self.engines.keys().map(|k| k.to_owned()).collect()
} }
/// Looks up the ModelDeploymentCard checksum recorded for `model`, if any.
///
/// Hands back a freshly allocated `String` for caller convenience: every
/// place this is used needs an owned value.
pub fn checksum(&self, model: &str) -> Option<String> {
    self.checksums.get(model).cloned()
}
} }
This diff is collapsed.
...@@ -5,12 +5,12 @@ use std::pin::Pin; ...@@ -5,12 +5,12 @@ use std::pin::Pin;
use crate::{ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig}, entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter}, kv_router::{KvPushRouter, KvRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::{self, ModelDeploymentCard},
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
...@@ -73,7 +73,9 @@ pub async fn prepare_engine( ...@@ -73,7 +73,9 @@ pub async fn prepare_engine(
None, None,
None, None,
)); ));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime, distributed_runtime,
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
let card = local_model.card(); let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static let dst_config = DistributedConfig::from_settings(true); // true means static
...@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(), tokenizer_hf.clone(),
) )
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = let completions_engine =
entrypoint::build_routed_pipeline::< entrypoint::build_routed_pipeline::<
...@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf) >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service grpc_service
} }
...@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?; let checksum = model.card().mdcsum();
manager.add_chat_completions_model(model.service_name(), engine)?; manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
grpc_service grpc_service
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
...@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => { } => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?; let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline = let chat_pipeline =
...@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone()) >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?; .await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
grpc_service grpc_service
} }
}; };
......
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher}, discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType, endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime, distributed_runtime,
http_service.state().manager_clone(), http_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
let card = local_model.card(); let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static let dst_config = DistributedConfig::from_settings(true); // true means static
...@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(), tokenizer_hf.clone(),
) )
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = let completions_engine =
entrypoint::build_routed_pipeline::< entrypoint::build_routed_pipeline::<
...@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf) >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true); http_service.enable_model_endpoint(endpoint_type, true);
...@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = http_service.model_manager(); let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?; let checksum = model.card().mdcsum();
manager.add_chat_completions_model(model.service_name(), engine)?; manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
...@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => { } => {
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let manager = http_service.model_manager(); let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?; let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline = let chat_pipeline =
...@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone()) >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?; .await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true); http_service.enable_model_endpoint(endpoint_type, true);
......
...@@ -32,7 +32,6 @@ pub mod sequence; ...@@ -32,7 +32,6 @@ pub mod sequence;
pub mod subscriber; pub mod subscriber;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
...@@ -45,6 +44,7 @@ use crate::{ ...@@ -45,6 +44,7 @@ use crate::{
subscriber::start_kv_router_background, subscriber::start_kv_router_background,
}, },
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
}; };
...@@ -247,9 +247,9 @@ impl KvRouter { ...@@ -247,9 +247,9 @@ impl KvRouter {
let runtime_configs_watcher = watch_prefix_with_extraction( let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client, etcd_client,
MODEL_ROOT_PATH, model_card::ROOT_PATH,
key_extractors::lease_id, key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config, |card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(), cancellation_token.clone(),
) )
.await?; .await?;
......
...@@ -15,15 +15,12 @@ use dynamo_runtime::{ ...@@ -15,15 +15,12 @@ use dynamo_runtime::{
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}, storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
}; };
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
pub mod runtime_config; pub mod runtime_config;
use runtime_config::ModelRuntimeConfig; use runtime_config::ModelRuntimeConfig;
...@@ -421,36 +418,13 @@ impl LocalModel { ...@@ -421,36 +418,13 @@ impl LocalModel {
// Publish the Model Deployment Card to KV store // Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string(); let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
// TODO: Next PR will use this let key = Key::from_raw(endpoint.unique_path(lease_id));
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store
.publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card. let _outcome = card_store
// (Why don't we put the model card directly under this key?) .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
let network_name = ModelNetworkName::new(); .await?;
tracing::debug!("Registering with etcd as {network_name}"); Ok(())
let model_registration = ModelEntry {
name: self.display_name().to_string(),
endpoint_id: endpoint.id(),
runtime_config: Some(self.runtime_config.clone()),
};
etcd_client
.kv_create(
&network_name,
serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease
)
.await
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::MODEL_ROOT_PATH;
/// Newtype wrapper around a `String` holding a model's network name.
/// `ModelNetworkName::new` fills it with `MODEL_ROOT_PATH`, a `/`, and a v4 UUID.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
pub fn new() -> Self {
ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
}
}
impl Default for ModelNetworkName {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for ModelNetworkName {
    /// Writes the wrapped name verbatim.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(self.0.as_str())
    }
}
impl AsRef<str> for ModelNetworkName {
    /// Borrows the wrapped name as a string slice.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl std::ops::Deref for ModelNetworkName {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
use std::fmt; use std::fmt;
use std::fs::File; use std::fs::File;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile; use crate::common::checked_file::{CheckedFile, Checksum};
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
...@@ -43,6 +43,15 @@ pub enum ModelInfoType { ...@@ -43,6 +43,15 @@ pub enum ModelInfoType {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl ModelInfoType {
    /// Returns this model info's checksum rendered as a `String`.
    ///
    /// `HfConfigJson` reports the checksum of the value it wraps; `GGUF`
    /// yields the string form of `Checksum::default()`.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfConfigJson(config_file) => config_file.checksum().to_string(),
            Self::GGUF(_path) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TokenizerKind { pub enum TokenizerKind {
...@@ -50,6 +59,15 @@ pub enum TokenizerKind { ...@@ -50,6 +59,15 @@ pub enum TokenizerKind {
GGUF(Box<HfTokenizer>), GGUF(Box<HfTokenizer>),
} }
impl TokenizerKind {
    /// Returns this tokenizer's checksum rendered as a `String`.
    ///
    /// `HfTokenizerJson` reports the checksum of the value it wraps; `GGUF`
    /// yields the string form of `Checksum::default()`.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfTokenizerJson(tokenizer_file) => tokenizer_file.checksum().to_string(),
            Self::GGUF(_tokenizer) => Checksum::default().to_string(),
        }
    }
}
/// Supported types of prompt formatters. /// Supported types of prompt formatters.
/// ///
/// We need a way to associate the prompt formatter template definition with an associated /// We need a way to associate the prompt formatter template definition with an associated
...@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact { ...@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl PromptFormatterArtifact {
    /// Returns this artifact's checksum rendered as a `String`.
    ///
    /// `HfTokenizerConfigJson` and `HfChatTemplate` both report the checksum of
    /// the value they wrap; `GGUF` yields the string form of
    /// `Checksum::default()`.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfTokenizerConfigJson(artifact) => artifact.checksum().to_string(),
            Self::HfChatTemplate(artifact) => artifact.checksum().to_string(),
            Self::GGUF(_path) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptContextMixin { pub enum PromptContextMixin {
...@@ -87,6 +115,15 @@ pub enum GenerationConfig { ...@@ -87,6 +115,15 @@ pub enum GenerationConfig {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl GenerationConfig {
    /// Returns this generation config's checksum rendered as a `String`.
    ///
    /// `HfGenerationConfigJson` reports the checksum of the value it wraps;
    /// `GGUF` yields the string form of `Checksum::default()`.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfGenerationConfigJson(config_file) => config_file.checksum().to_string(),
            Self::GGUF(_path) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)] #[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard { pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct" /// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
...@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard { ...@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard {
#[serde(skip)] #[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>, cache_dir: Option<Arc<tempfile::TempDir>>,
#[serde(skip, default)]
checksum: OnceLock<String>,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
...@@ -189,6 +229,12 @@ impl ModelDeploymentCard { ...@@ -189,6 +229,12 @@ impl ModelDeploymentCard {
Ok(()) Ok(())
} }
/// Returns the card's `display_name` as a string slice.
#[inline]
pub fn name(&self) -> &str {
&self.display_name
}
#[inline]
pub fn slug(&self) -> &Slug { pub fn slug(&self) -> &Slug {
&self.slug &self.slug
} }
...@@ -198,9 +244,45 @@ impl ModelDeploymentCard { ...@@ -198,9 +244,45 @@ impl ModelDeploymentCard {
Ok(serde_json::to_string(self)?) Ok(serde_json::to_string(self)?)
} }
pub fn mdcsum(&self) -> String { pub fn mdcsum(&self) -> &str {
let json = self.to_json().unwrap(); self.checksum
format!("{}", blake3::hash(json.as_bytes())) .get_or_init(|| {
// Only include the important fields
let mut bytes_to_hash: Vec<u8> = Vec::with_capacity(512);
bytes_to_hash.extend(self.display_name.as_bytes());
// The files can be either a URL or a local path, so we ignore that and hash their
// checksum instead, which won't change wherever they are.
if let Some(model_info) = self.model_info.as_ref() {
bytes_to_hash.extend(model_info.checksum().as_bytes());
}
if let Some(tokenizer) = self.tokenizer.as_ref() {
bytes_to_hash.extend(tokenizer.checksum().as_bytes());
}
if let Some(prompt_formatter) = self.prompt_formatter.as_ref() {
bytes_to_hash.extend(prompt_formatter.checksum().as_bytes());
}
if let Some(chat_template) = self.chat_template_file.as_ref() {
bytes_to_hash.extend(chat_template.checksum().as_bytes());
}
if let Some(gen_config) = self.gen_config.as_ref() {
bytes_to_hash.extend(gen_config.checksum().as_bytes());
}
if let Some(prompt_context_vec) = self.prompt_context.as_ref() {
// Paste it as the bytes of the debug format. It's a Vec of enum, so this should be
// fine. If the debug representation changes that only happens in a new release.
bytes_to_hash.extend(format!("{prompt_context_vec:?}").as_bytes());
}
bytes_to_hash.extend(self.context_length.to_be_bytes());
bytes_to_hash.extend(self.kv_cache_block_size.to_be_bytes());
// TODO: Do we want any of user_data or runtime_config?
blake3::hash(&bytes_to_hash).to_string()
})
.as_ref()
} }
/// Is this a full model card with tokenizer? /// Is this a full model card with tokenizer?
...@@ -291,9 +373,7 @@ impl ModelDeploymentCard { ...@@ -291,9 +373,7 @@ impl ModelDeploymentCard {
/// Move the files this MDC uses from the NATS object store to local disk. /// Move the files this MDC uses from the NATS object store to local disk.
/// Updates the URI's to point to the created files. /// Updates the URI's to point to the created files.
/// pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
/// The returned TempDir must be kept alive, it cleans up on drop.
async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
let nats_addr = nats_client.addr(); let nats_addr = nats_client.addr();
let bucket_name = self.slug(); let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?; let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
...@@ -345,7 +425,9 @@ impl ModelDeploymentCard { ...@@ -345,7 +425,9 @@ impl ModelDeploymentCard {
"tokenizer.json" "tokenizer.json"
); );
Ok(target_dir) // This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive.
self.cache_dir = Some(Arc::new(target_dir));
Ok(())
} }
/// Delete this card from the key-value store and it's URLs from the object store /// Delete this card from the key-value store and it's URLs from the object store
...@@ -411,8 +493,7 @@ impl ModelDeploymentCard { ...@@ -411,8 +493,7 @@ impl ModelDeploymentCard {
else { else {
return Ok(None); return Ok(None);
}; };
// This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive. card.move_from_nats(drt.nats_client()).await?;
card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
Ok(Some(card)) Ok(Some(card))
} }
...@@ -487,6 +568,7 @@ impl ModelDeploymentCard { ...@@ -487,6 +568,7 @@ impl ModelDeploymentCard {
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
checksum: OnceLock::new(),
}) })
} }
...@@ -551,6 +633,7 @@ impl ModelDeploymentCard { ...@@ -551,6 +633,7 @@ impl ModelDeploymentCard {
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
checksum: OnceLock::new(),
}) })
} }
} }
......
...@@ -74,6 +74,25 @@ impl ModelType { ...@@ -74,6 +74,25 @@ impl ModelType {
result result
} }
/// Split this bitflag into its individual component units, e.g.
/// `Chat | Completions` becomes `[Chat, Completions]`.
///
/// Units are always reported in the fixed order Chat, Completions,
/// Embedding, TensorBased; unsupported ones are omitted.
pub fn units(&self) -> Vec<ModelType> {
    [
        (self.supports_chat(), ModelType::Chat),
        (self.supports_completions(), ModelType::Completions),
        (self.supports_embedding(), ModelType::Embedding),
        (self.supports_tensor(), ModelType::TensorBased),
    ]
    .into_iter()
    .filter_map(|(supported, unit)| supported.then_some(unit))
    .collect()
}
/// Returns all endpoint types supported by this model type. /// Returns all endpoint types supported by this model type.
/// This properly handles combinations like Chat | Completions. /// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> { pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
......
...@@ -124,7 +124,7 @@ impl OpenAIPreprocessor { ...@@ -124,7 +124,7 @@ impl OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum(); let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer)); let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else { let Some(model_info) = mdc.model_info else {
anyhow::bail!( anyhow::bail!(
......
...@@ -4,17 +4,6 @@ ...@@ -4,17 +4,6 @@
use anyhow::Error; use anyhow::Error;
use async_stream::stream; use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig; use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
},
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, RequestType, Status},
service_v2::HttpService,
},
};
use dynamo_llm::protocols::{ use dynamo_llm::protocols::{
Annotated, Annotated,
codec::SseLineCodec, codec::SseLineCodec,
...@@ -24,6 +13,21 @@ use dynamo_llm::protocols::{ ...@@ -24,6 +13,21 @@ use dynamo_llm::protocols::{
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse}, completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
}; };
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
error::HttpError,
metrics::{Endpoint, RequestType, Status},
service_v2::HttpService,
},
},
model_card::ModelDeploymentCard,
};
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix}; use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{ use dynamo_runtime::{
CancellationToken, CancellationToken,
...@@ -275,15 +279,18 @@ async fn test_http_service() { ...@@ -275,15 +279,18 @@ async fn test_http_service() {
let registry = Registry::new(); let registry = Registry::new();
// TODO: Shouldn't this test know the card before it registers a model?
let card = ModelDeploymentCard::with_name_only("foo");
let counter = Arc::new(CounterEngine {}); let counter = Arc::new(CounterEngine {});
let result = manager.add_chat_completions_model("foo", counter); let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
assert!(result.is_ok()); assert!(result.is_ok());
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let result = manager.add_chat_completions_model("bar", failure.clone()); let card = ModelDeploymentCard::with_name_only("bar");
let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
assert!(result.is_ok()); assert!(result.is_ok());
let result = manager.add_completions_model("bar", failure); let result = manager.add_completions_model("bar", card.mdcsum(), failure);
assert!(result.is_ok()); assert!(result.is_ok());
let metrics = state.metrics_clone(); let metrics = state.metrics_clone();
...@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF ...@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF
let counter = Arc::new(CounterEngine {}); let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager manager
.add_chat_completions_model("foo", counter.clone()) .add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap(); .unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager manager
.add_chat_completions_model("bar", failure.clone()) .add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
manager manager
.add_completions_model("bar", failure.clone()) .add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
(service, counter, failure, port) (service, counter, failure, port)
...@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() { ...@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds) // Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000)); let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager manager
.add_chat_completions_model("slow-model", long_running_engine.clone()) .add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
.unwrap(); .unwrap();
let client = reqwest::Client::new(); let client = reqwest::Client::new();
...@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() { ...@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds) // Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-stream-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000)); let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager manager
.add_chat_completions_model("slow-stream-model", long_running_engine.clone()) .add_chat_completions_model(
"slow-stream-model",
card.mdcsum(),
long_running_engine.clone(),
)
.unwrap(); .unwrap();
let client = reqwest::Client::new(); let client = reqwest::Client::new();
...@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() { ...@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Add a counter engine for this test // Add a counter engine for this test
let card = ModelDeploymentCard::with_name_only("test-model");
let counter_engine = Arc::new(CounterEngine {}); let counter_engine = Arc::new(CounterEngine {});
manager manager
.add_chat_completions_model("test-model", counter_engine) .add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
.unwrap(); .unwrap();
// Create reqwest client directly // Create reqwest client directly
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
use anyhow::Error; use anyhow::Error;
use async_stream::stream; use async_stream::stream;
use dynamo_llm::{ use dynamo_llm::{
http::service::metrics::Endpoint, http::service::{metrics::Endpoint, service_v2::HttpService},
http::service::service_v2::HttpService, model_card::ModelDeploymentCard,
protocols::{ protocols::{
Annotated, Annotated,
openai::chat_completions::{ openai::chat_completions::{
...@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() { ...@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() {
let task = tokio::spawn(async move { service.run(token.clone()).await }); let task = tokio::spawn(async move { service.run(token.clone()).await });
// Add mock model engine // Add mock model engine
let card = ModelDeploymentCard::with_name_only("mockmodel");
let mock_engine = Arc::new(MockModelEngine {}); let mock_engine = Arc::new(MockModelEngine {});
manager manager
.add_chat_completions_model("mockmodel", mock_engine) .add_chat_completions_model("mockmodel", card.mdcsum(), mock_engine)
.unwrap(); .unwrap();
// Wait for service to be ready // Wait for service to be ready
...@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() { ...@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() {
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use dynamo_llm::{ use dynamo_llm::{
discovery::{MODEL_ROOT_PATH, ModelEntry, ModelWatcher}, discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
engines::make_echo_engine, local_model::LocalModelBuilder, model_card,
entrypoint::EngineConfig,
local_model::LocalModelBuilder,
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
...@@ -348,7 +347,7 @@ mod integration_tests { ...@@ -348,7 +347,7 @@ mod integration_tests {
// Start watching etcd for model registrations // Start watching etcd for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() { if let Some(etcd_client) = distributed_runtime.etcd_client() {
let models_watcher = etcd_client let models_watcher = etcd_client
.kv_get_and_watch_prefix(MODEL_ROOT_PATH) .kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await .await
.unwrap(); .unwrap();
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
...@@ -364,10 +363,11 @@ mod integration_tests { ...@@ -364,10 +363,11 @@ mod integration_tests {
panic!("Expected StaticFull config"); panic!("Expected StaticFull config");
}; };
let card = local_model.card().clone();
let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine)); let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine));
let manager = service.model_manager(); let manager = service.model_manager();
manager manager
.add_chat_completions_model(model.service_name(), engine.clone()) .add_chat_completions_model(model.service_name(), card.mdcsum(), engine.clone())
.unwrap(); .unwrap();
// Now do the proper MDC registration via LocalModel::attach() // Now do the proper MDC registration via LocalModel::attach()
...@@ -376,7 +376,7 @@ mod integration_tests { ...@@ -376,7 +376,7 @@ mod integration_tests {
let test_component = namespace.component("test-mdc-component").unwrap(); let test_component = namespace.component("test-mdc-component").unwrap();
let test_endpoint = test_component.endpoint("test-mdc-endpoint"); let test_endpoint = test_component.endpoint("test-mdc-endpoint");
// This will store the MDC in etcd and create the ModelEntry for discovery // This will store the MDC in etcd for discovery
local_model local_model
.attach( .attach(
&test_endpoint, &test_endpoint,
...@@ -388,8 +388,7 @@ mod integration_tests { ...@@ -388,8 +388,7 @@ mod integration_tests {
// Manually save the model card and update metrics // Manually save the model card and update metrics
// This simulates what the ModelWatcher polling task would do in production // This simulates what the ModelWatcher polling task would do in production
let card = local_model.card().clone(); let _ = manager.save_model_card("test-mdc-key", card.clone());
manager.save_model_card("test-mdc-key", card.clone());
if let Err(e) = service if let Err(e) = service
.state() .state()
...@@ -500,8 +499,13 @@ mod integration_tests { ...@@ -500,8 +499,13 @@ mod integration_tests {
assert!(metrics_body.contains("request_type=\"stream\"")); assert!(metrics_body.contains("request_type=\"stream\""));
assert!(metrics_body.contains("status=\"success\"")); assert!(metrics_body.contains("status=\"success\""));
// The etcd lease will ensure everything is deleted from etcd
// Now test the complete lifecycle: remove the model from etcd // Now test the complete lifecycle: remove the model from etcd
// We don't need to cleanup model manager because it's local to this test
/*
// Clean up
// Remove the model using the cleaner ModelWatcher approach // Remove the model using the cleaner ModelWatcher approach
if let Some(etcd_client) = distributed_runtime.etcd_client() { if let Some(etcd_client) = distributed_runtime.etcd_client() {
// Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern) // Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern)
...@@ -514,10 +518,7 @@ mod integration_tests { ...@@ -514,10 +518,7 @@ mod integration_tests {
); );
// Get all model entries for our test model // Get all model entries for our test model
let model_entries = watcher let model_entries = watcher.entries_for_model("test-mdc-model").await.unwrap();
.entries_for_model("test-mdc-model", None, true)
.await
.unwrap();
if !model_entries.is_empty() { if !model_entries.is_empty() {
// For each model entry, we need to find its etcd key and remove it // For each model entry, we need to find its etcd key and remove it
...@@ -566,8 +567,8 @@ mod integration_tests { ...@@ -566,8 +567,8 @@ mod integration_tests {
} }
} }
} }
*/
// Clean up
cancel_token.cancel(); cancel_token.cancel();
task.await.unwrap().unwrap(); task.await.unwrap().unwrap();
} }
......
...@@ -280,31 +280,32 @@ pub mod kserve_test { ...@@ -280,31 +280,32 @@ pub mod kserve_test {
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let long_running = Arc::new(LongRunningEngine::new(1_000)); let long_running = Arc::new(LongRunningEngine::new(1_000));
manager
.add_completions_model("split", split.clone())
.unwrap();
let mut card = ModelDeploymentCard::with_name_only("split"); let mut card = ModelDeploymentCard::with_name_only("split");
card.model_type = ModelType::Completions; card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("split", card);
manager manager
.add_chat_completions_model("failure", failure.clone()) .add_completions_model("split", card.mdcsum(), split.clone())
.unwrap();
manager
.add_completions_model("failure", failure.clone())
.unwrap(); .unwrap();
let _ = manager.save_model_card("split", card.clone());
let mut card = ModelDeploymentCard::with_name_only("failure"); let mut card = ModelDeploymentCard::with_name_only("failure");
card.model_type = ModelType::Completions | ModelType::Chat; card.model_type = ModelType::Completions | ModelType::Chat;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("failure", card);
manager manager
.add_completions_model("long_running", long_running.clone()) .add_chat_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
manager
.add_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
let _ = manager.save_model_card("failure", card);
let mut card = ModelDeploymentCard::with_name_only("long_running"); let mut card = ModelDeploymentCard::with_name_only("long_running");
card.model_type = ModelType::Completions; card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("long_running", card); manager
.add_completions_model("long_running", card.mdcsum(), long_running.clone())
.unwrap();
let _ = manager.save_model_card("long_running", card);
(service, split, failure, long_running) (service, split, failure, long_running)
} }
...@@ -1130,11 +1131,16 @@ pub mod kserve_test { ...@@ -1130,11 +1131,16 @@ pub mod kserve_test {
text_input: inference::model_infer_request::InferInputTensor, text_input: inference::model_infer_request::InferInputTensor,
) { ) {
// add tensor model // add tensor model
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
let tensor = Arc::new(TensorEngine {}); let tensor = Arc::new(TensorEngine {});
service_with_engines service_with_engines
.0 .0
.model_manager() .model_manager()
.add_tensor_model("tensor", tensor.clone()) .add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap(); .unwrap();
// start server // start server
...@@ -1147,11 +1153,7 @@ pub mod kserve_test { ...@@ -1147,11 +1153,7 @@ pub mod kserve_test {
version: "".into(), version: "".into(),
}); });
// Failure, model registered as Tensor but does not provide model config (in runtime config) let _ = service_with_engines
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_card("key", card); .save_model_card("key", card);
...@@ -1217,7 +1219,7 @@ pub mod kserve_test { ...@@ -1217,7 +1219,7 @@ pub mod kserve_test {
}), }),
..Default::default() ..Default::default()
}; };
service_with_engines let _ = service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_card("key", card); .save_model_card("key", card);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment