Unverified Commit 81162dfe authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(discovery): Watch/publish ModelDeploymentCard instead of ModelEntry (#3350)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent ddbb4f50
...@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime): ...@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime):
host: str = "localhost" host: str = "localhost"
port: int = 8000 port: int = 8000
service: HttpService = HttpService(port=port) service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine) service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...") print("Starting service...")
shutdown_signal = service.run(runtime.child_token()) shutdown_signal = service.run(runtime.child_token())
......
...@@ -30,23 +30,29 @@ impl HttpService { ...@@ -30,23 +30,29 @@ impl HttpService {
Ok(Self { inner }) Ok(Self { inner })
} }
pub fn add_completions_model(&self, model: String, engine: HttpAsyncEngine) -> PyResult<()> { pub fn add_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine); let engine = Arc::new(engine);
self.inner self.inner
.model_manager() .model_manager()
.add_completions_model(&model, engine) .add_completions_model(&model, &checksum, engine)
.map_err(to_pyerr) .map_err(to_pyerr)
} }
pub fn add_chat_completions_model( pub fn add_chat_completions_model(
&self, &self,
model: String, model: String,
checksum: String,
engine: HttpAsyncEngine, engine: HttpAsyncEngine,
) -> PyResult<()> { ) -> PyResult<()> {
let engine = Arc::new(engine); let engine = Arc::new(engine);
self.inner self.inner
.model_manager() .model_manager()
.add_chat_completions_model(&model, engine) .add_chat_completions_model(&model, &checksum, engine)
.map_err(to_pyerr) .map_err(to_pyerr)
} }
......
...@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime): ...@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime):
model_name = "test_model" model_name = "test_model"
start_done = asyncio.Event() start_done = asyncio.Event()
child_token = runtime.child_token() child_token = runtime.child_token()
checksum = "abc123" # Checksum of ModelDeplomentCard for that model
async def worker(): async def worker():
"""The server worker task.""" """The server worker task."""
...@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime): ...@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime):
engine = HttpAsyncEngine(python_engine.generate, loop) engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port) service = HttpService(port=port)
service.add_chat_completions_model(model_name, engine) service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True) service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token) shutdown_signal = service.run(child_token)
......
...@@ -33,7 +33,7 @@ pub struct Checksum { ...@@ -33,7 +33,7 @@ pub struct Checksum {
algorithm: CryptographicHashMethods, algorithm: CryptographicHashMethods,
} }
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)] #[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
pub enum CryptographicHashMethods { pub enum CryptographicHashMethods {
#[serde(rename = "blake3")] #[serde(rename = "blake3")]
BLAKE3, BLAKE3,
...@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum { ...@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum {
} }
} }
impl Default for Checksum {
fn default() -> Self {
Self {
hash: "".to_string(),
algorithm: CryptographicHashMethods::BLAKE3,
}
}
}
impl FromStr for CryptographicHashMethods { impl FromStr for CryptographicHashMethods {
type Err = String; type Err = String;
......
...@@ -4,14 +4,8 @@ ...@@ -4,14 +4,8 @@
mod model_manager; mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError}; pub use model_manager::{ModelManager, ModelManagerError};
mod model_entry;
pub use model_entry::ModelEntry;
mod watcher; mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher}; pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations /// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers"; pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize};
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// [ModelEntry] contains the information to discover models
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
/// Public name of the model
/// Used to identify the model in the HTTP service from the value used in an OpenAI ChatRequest.
pub name: String,
/// How to address this on the network
#[serde(rename = "endpoint")]
pub endpoint_id: protocols::EndpointId,
/// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
}
impl ModelEntry {
/// Slugified display name for use in network storage, or URL-safe environments
pub fn slug(&self) -> Slug {
Slug::from_string(&self.name)
}
}
...@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock}; ...@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component; use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider; use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard}; use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{ use crate::{
kv_router::KvRouter, kv_router::KvRouter,
...@@ -21,6 +20,10 @@ use crate::{ ...@@ -21,6 +20,10 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine, completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
}, },
}; };
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
model_type::ModelType,
};
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
pub enum ModelManagerError { pub enum ModelManagerError {
...@@ -39,7 +42,7 @@ pub struct ModelManager { ...@@ -39,7 +42,7 @@ pub struct ModelManager {
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>, embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>, tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally // These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>, cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>, kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
} }
...@@ -62,6 +65,43 @@ impl ModelManager { ...@@ -62,6 +65,43 @@ impl ModelManager {
} }
} }
pub fn is_valid_checksum(
&self,
model_type: ModelType,
model_name: &str,
candidate_checksum: &str,
) -> Option<bool> {
let mut results = vec![];
for unit in model_type.units() {
let maybe_valid_checksum = match unit {
ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
_ => {
continue;
}
};
if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
tracing::debug!(
model_name,
valid_checksum,
candidate_checksum,
"is_valid_checksum: check case"
);
valid_checksum == candidate_checksum
}) {
results.push(is_valid)
}
}
if results.is_empty() {
None
} else {
// The checksum is valid if it is correct for all the ModelType in the bitflag.
Some(results.into_iter().all(|x| x))
}
}
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> { pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.lock().values().cloned().collect() self.cards.lock().values().cloned().collect()
} }
...@@ -99,37 +139,41 @@ impl ModelManager { ...@@ -99,37 +139,41 @@ impl ModelManager {
pub fn add_completions_model( pub fn add_completions_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine, engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write(); let mut clients = self.completion_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_chat_completions_model( pub fn add_chat_completions_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine, engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write(); let mut clients = self.chat_completion_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_embeddings_model( pub fn add_embeddings_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine, engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write(); let mut clients = self.embeddings_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn add_tensor_model( pub fn add_tensor_model(
&self, &self,
model: &str, model: &str,
card_checksum: &str,
engine: TensorStreamingEngine, engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> { ) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write(); let mut clients = self.tensor_engines.write();
clients.add(model, engine) clients.add(model, card_checksum, engine)
} }
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> { pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
...@@ -196,10 +240,11 @@ impl ModelManager { ...@@ -196,10 +240,11 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string())) .ok_or(ModelManagerError::ModelNotFound(model.to_string()))
} }
/// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is /// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted from etcd. /// deleted.
pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) { pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), entry); self.cards.lock().insert(key.to_string(), card);
Ok(())
} }
/// Remove and return model card for this instance's etcd key. We do this when the instance stops. /// Remove and return model card for this instance's etcd key. We do this when the instance stops.
...@@ -291,6 +336,9 @@ pub struct ModelEngines<E> { ...@@ -291,6 +336,9 @@ pub struct ModelEngines<E> {
/// Optional default model name /// Optional default model name
default: Option<String>, default: Option<String>,
engines: HashMap<String, E>, engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
} }
impl<E> Default for ModelEngines<E> { impl<E> Default for ModelEngines<E> {
...@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> { ...@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> {
Self { Self {
default: None, default: None,
engines: HashMap::new(), engines: HashMap::new(),
checksums: HashMap::new(),
} }
} }
} }
...@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> { ...@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> {
self.default = None; self.default = None;
} }
fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> { fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
if self.engines.contains_key(model) { if self.engines.contains_key(model) {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string())); return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
} }
self.engines.insert(model.to_string(), engine); self.engines.insert(model.to_string(), engine);
self.checksums
.insert(model.to_string(), checksum.to_string());
Ok(()) Ok(())
} }
...@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> { ...@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> {
if self.engines.remove(model).is_none() { if self.engines.remove(model).is_none() {
return Err(ModelManagerError::ModelNotFound(model.to_string())); return Err(ModelManagerError::ModelNotFound(model.to_string()));
} }
let _ = self.checksums.remove(model);
Ok(()) Ok(())
} }
...@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> { ...@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> {
pub fn list(&self) -> Vec<String> { pub fn list(&self) -> Vec<String> {
self.engines.keys().map(|k| k.to_owned()).collect() self.engines.keys().map(|k| k.to_owned()).collect()
} }
/// Returns a newly allocated String for called convenience. All the places I use
/// this I need a String.
pub fn checksum(&self, model: &str) -> Option<String> {
self.checksums.get(model).map(|s| s.to_string())
}
} }
...@@ -13,16 +13,15 @@ use dynamo_runtime::{ ...@@ -13,16 +13,15 @@ use dynamo_runtime::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source, ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter, network::egress::push_router::PushRouter,
}, },
protocols::annotated::Annotated, protocols::{EndpointId, annotated::Annotated},
storage::key_value_store::Key, transports::etcd::WatchEvent,
transports::etcd::{KeyValue, WatchEvent},
}; };
use crate::{ use crate::{
backend::Backend, backend::Backend,
entrypoint, entrypoint,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card::ModelDeploymentCard, model_card::{self, ModelDeploymentCard},
model_type::{ModelInput, ModelType}, model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{ protocols::{
...@@ -38,7 +37,7 @@ use crate::{ ...@@ -38,7 +37,7 @@ use crate::{
}, },
}; };
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager}; use super::ModelManager;
use crate::namespace::is_global_namespace; use crate::namespace::is_global_namespace;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
...@@ -105,64 +104,97 @@ impl ModelWatcher { ...@@ -105,64 +104,97 @@ impl ModelWatcher {
while let Some(event) = events_rx.recv().await { while let Some(event) = events_rx.recv().await {
match event { match event {
WatchEvent::Put(kv) => { WatchEvent::Put(kv) => {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) { let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(model_entry) => model_entry, Ok(card) => card,
Err(err) => { Err(err) => {
match kv.value_str() { match kv.value_str() {
Ok(value) => { Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry") tracing::error!(%err, value, "Invalid JSON in model card")
} }
Err(value_str_err) => { Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON") tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
} }
} }
continue; continue;
} }
}; };
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
continue;
}
};
let endpoint_id = match etcd_key_extract(key) {
Ok((eid, _)) => eid,
Err(err) => {
tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance.");
continue;
}
};
// Filter by namespace if target_namespace is specified // Filter by namespace if target_namespace is specified
if !global_namespace if !global_namespace
&& let Some(target_ns) = target_namespace && let Some(target_ns) = target_namespace
&& model_entry.endpoint_id.namespace != target_ns && endpoint_id.namespace != target_ns
{ {
tracing::debug!( tracing::debug!(
model_namespace = model_entry.endpoint_id.namespace, model_namespace = endpoint_id.namespace,
target_namespace = target_ns, target_namespace = target_ns,
model_name = model_entry.name, model_name = card.name(),
"Skipping model from different namespace" "Skipping model from different namespace"
); );
continue; continue;
} }
let key = match kv.key_str() { // If we already have a worker for this model, and the ModelDeploymentCard
Ok(k) => k, // cards don't match, alert, and don't add the new instance
Err(err) => { let can_add =
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping"); self.manager
.is_valid_checksum(card.model_type, card.name(), card.mdcsum());
if can_add.is_some_and(|is_valid| !is_valid) {
tracing::error!(
model_name = card.name(),
"Checksum for new model does not match existing model."
);
// TODO: mark that instance down in clients
// Not obvious how to do that given the current design
// Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
// Calling `report_instance_down` on the Client should do it (although
// needs more testing).
// The `PushRouter` is in `ModelMananger` (`self.manager` here), but inside
// interface `AsyncEngine` which only has a `generate` method.
continue; continue;
} }
};
match self.handle_put(key, &model_entry).await { match self.handle_put(key, &endpoint_id, &mut card).await {
Ok(()) => { Ok(()) => {
tracing::info!( tracing::info!(
model_name = model_entry.name, model_name = card.name(),
namespace = model_entry.endpoint_id.namespace, namespace = endpoint_id.namespace,
"added model" "added model"
); );
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
} }
Err(err) => { Err(err) => {
tracing::error!( tracing::error!(
model_name = card.name(),
namespace = endpoint_id.namespace,
error = format!("{err:#}"), error = format!("{err:#}"),
"error adding model {} from namespace {}", "Error adding model from discovery",
model_entry.name,
model_entry.endpoint_id.namespace,
); );
} }
} }
} }
WatchEvent::Delete(kv) => match self WatchEvent::Delete(kv) => {
.handle_delete(&kv, target_namespace, global_namespace) let Ok(deleted_key) = kv.key_str() else {
tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
continue;
};
match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.await .await
{ {
Ok(Some(model_name)) => { Ok(Some(model_name)) => {
...@@ -174,7 +206,8 @@ impl ModelWatcher { ...@@ -174,7 +206,8 @@ impl ModelWatcher {
Err(e) => { Err(e) => {
tracing::error!(error = %e, "error removing model"); tracing::error!(error = %e, "error removing model");
} }
}, }
}
} }
} }
} }
...@@ -183,20 +216,19 @@ impl ModelWatcher { ...@@ -183,20 +216,19 @@ impl ModelWatcher {
/// Returns the name of the model we just deleted, if any. /// Returns the name of the model we just deleted, if any.
async fn handle_delete( async fn handle_delete(
&self, &self,
kv: &KeyValue, key: &str,
target_namespace: Option<&str>, target_namespace: Option<&str>,
is_global_namespace: bool, is_global_namespace: bool,
) -> anyhow::Result<Option<String>> { ) -> anyhow::Result<Option<String>> {
let key = kv.key_str()?;
let card = match self.manager.remove_model_card(key) { let card = match self.manager.remove_model_card(key) {
Some(card) => card, Some(card) => card,
None => { None => {
anyhow::bail!("Missing ModelDeploymentCard for {key}"); anyhow::bail!("Missing ModelDeploymentCard for {key}");
} }
}; };
let model_name = card.display_name.clone(); let model_name = card.name().to_string();
let active_instances = self let active_instances = self
.entries_for_model(&model_name, target_namespace, is_global_namespace) .cards_for_model(&model_name, target_namespace, is_global_namespace)
.await .await
.with_context(|| model_name.clone())?; .with_context(|| model_name.clone())?;
if !active_instances.is_empty() { if !active_instances.is_empty() {
...@@ -265,53 +297,35 @@ impl ModelWatcher { ...@@ -265,53 +297,35 @@ impl ModelWatcher {
// Handles a PUT event from etcd, this usually means adding a new model to the list of served // Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models. // models.
async fn handle_put(&self, key: &str, model_entry: &ModelEntry) -> anyhow::Result<()> { async fn handle_put(
let endpoint_id = &model_entry.endpoint_id; &self,
key: &str,
endpoint_id: &EndpointId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.move_from_nats(self.drt.nats_client()).await?;
let component = self let component = self
.drt .drt
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?; .component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?; let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug(); tracing::debug!(model_name = card.name(), "adding model");
let card = match ModelDeploymentCard::load_from_store( self.manager.save_model_card(key, card.clone())?;
&Key::from_raw(model_slug.to_string()),
&self.drt,
)
.await
{
Ok(Some(mut card)) => {
tracing::debug!(card.display_name, "adding model");
// Ensure runtime_config is populated
if let Some(rc) = model_entry.runtime_config.clone() {
card.runtime_config = rc;
}
card
}
Ok(None) => {
anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
}
Err(err) => {
anyhow::bail!(
"Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
);
}
};
self.manager.save_model_card(key, card.clone());
if self.manager.has_model_any(&model_entry.name) { if self.manager.has_model_any(card.name()) {
tracing::trace!( tracing::debug!(
name = model_entry.name, model_name = card.name(),
namespace = model_entry.endpoint_id.namespace, namespace = endpoint_id.namespace,
"New endpoint for existing model" "New endpoint for existing model"
); );
self.notify_on_model.notify_waiters(); //self.notify_on_model.notify_waiters();
return Ok(()); return Ok(());
} }
if let Some(tx) = &self.model_update_tx { if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok(); tx.send(ModelUpdate::Added(card.clone())).await.ok();
} }
let checksum = card.mdcsum();
if card.model_input == ModelInput::Tokens if card.model_input == ModelInput::Tokens
&& (card.model_type.supports_chat() || card.model_type.supports_completions()) && (card.model_type.supports_chat() || card.model_type.supports_completions())
...@@ -324,7 +338,7 @@ impl ModelWatcher { ...@@ -324,7 +338,7 @@ impl ModelWatcher {
Some( Some(
self.manager self.manager
.kv_chooser_for( .kv_chooser_for(
&model_entry.name, card.name(),
&component, &component,
card.kv_cache_block_size, card.kv_cache_block_size,
self.kv_router_config, self.kv_router_config,
...@@ -344,7 +358,7 @@ impl ModelWatcher { ...@@ -344,7 +358,7 @@ impl ModelWatcher {
NvCreateChatCompletionRequest, NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>( >(
&card, card,
&client, &client,
self.router_mode, self.router_mode,
self.busy_threshold, self.busy_threshold,
...@@ -354,7 +368,7 @@ impl ModelWatcher { ...@@ -354,7 +368,7 @@ impl ModelWatcher {
.await .await
.context("build_routed_pipeline")?; .context("build_routed_pipeline")?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine) .add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?; .context("add_chat_completions_model")?;
tracing::info!("Chat completions is ready"); tracing::info!("Chat completions is ready");
} }
...@@ -373,7 +387,7 @@ impl ModelWatcher { ...@@ -373,7 +387,7 @@ impl ModelWatcher {
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>( >(
&card, card,
&client, &client,
self.router_mode, self.router_mode,
self.busy_threshold, self.busy_threshold,
...@@ -384,7 +398,7 @@ impl ModelWatcher { ...@@ -384,7 +398,7 @@ impl ModelWatcher {
.await .await
.context("build_routed_pipeline_with_preprocessor")?; .context("build_routed_pipeline_with_preprocessor")?;
self.manager self.manager
.add_completions_model(&model_entry.name, completions_engine) .add_completions_model(card.name(), checksum, completions_engine)
.context("add_completions_model")?; .context("add_completions_model")?;
tracing::info!("Completions is ready"); tracing::info!("Completions is ready");
} }
...@@ -411,7 +425,7 @@ impl ModelWatcher { ...@@ -411,7 +425,7 @@ impl ModelWatcher {
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_chat_completions_model(&model_entry.name, engine)?; .add_chat_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() { } else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case 2: Text + Completions // Case 2: Text + Completions
let push_router = PushRouter::< let push_router = PushRouter::<
...@@ -423,7 +437,7 @@ impl ModelWatcher { ...@@ -423,7 +437,7 @@ impl ModelWatcher {
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager self.manager
.add_completions_model(&model_entry.name, engine)?; .add_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() { } else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
// Case 4: Tokens + Embeddings // Case 4: Tokens + Embeddings
...@@ -434,7 +448,7 @@ impl ModelWatcher { ...@@ -434,7 +448,7 @@ impl ModelWatcher {
>::new(); >::new();
let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator(); let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
let backend = Backend::from_mdc(&card).into_operator(); let backend = Backend::from_mdc(card).into_operator();
let router = PushRouter::< let router = PushRouter::<
PreprocessedEmbeddingRequest, PreprocessedEmbeddingRequest,
...@@ -457,7 +471,7 @@ impl ModelWatcher { ...@@ -457,7 +471,7 @@ impl ModelWatcher {
.link(frontend)?; .link(frontend)?;
self.manager self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?; .add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() { } else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM) // Case 5: Tensor + Tensor (non-LLM)
let push_router = PushRouter::< let push_router = PushRouter::<
...@@ -468,7 +482,8 @@ impl ModelWatcher { ...@@ -468,7 +482,8 @@ impl ModelWatcher {
) )
.await?; .await?;
let engine = Arc::new(push_router); let engine = Arc::new(push_router);
self.manager.add_tensor_model(&model_entry.name, engine)?; self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else { } else {
// Reject unsupported combinations // Reject unsupported combinations
anyhow::bail!( anyhow::bail!(
...@@ -482,49 +497,116 @@ impl ModelWatcher { ...@@ -482,49 +497,116 @@ impl ModelWatcher {
Ok(()) Ok(())
} }
/// All the registered ModelEntry, one per instance /// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> { pub async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let Some(etcd_client) = self.drt.etcd_client() else { let Some(etcd_client) = self.drt.etcd_client() else {
anyhow::bail!("all_entries: Missing etcd client"); anyhow::bail!("all_cards: Missing etcd client");
}; };
let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?; let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
let mut entries = Vec::with_capacity(kvs.len()); let mut results = Vec::with_capacity(kvs.len());
for kv in kvs { for kv in kvs {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) { let maybe_convert = serde_json::from_slice::<ModelDeploymentCard>(kv.value());
Ok(model_entry) => model_entry, let r = match maybe_convert {
Ok(card) => {
let maybe_endpoint_id = kv.key_str().map_err(|err| err.into()).and_then(|k| {
etcd_key_extract(k).map(|(endpoint_id, _instance_id)| endpoint_id)
});
let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid,
Err(err) => {
tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId");
continue;
}
};
(endpoint_id, card)
}
Err(err) => { Err(err) => {
match kv.value_str() { match kv.value_str() {
Ok(value) => { Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry") tracing::error!(%err, value, "Invalid JSON in model card");
} }
Err(value_str_err) => { Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON") tracing::error!(original_error=%err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON");
} }
} }
continue; continue;
} }
}; };
entries.push(model_entry); results.push(r);
} }
Ok(entries) Ok(results)
} }
pub async fn entries_for_model( pub async fn cards_for_model(
&self, &self,
model_name: &str, model_name: &str,
target_namespace: Option<&str>, target_namespace: Option<&str>,
is_global_namespace: bool, is_global_namespace: bool,
) -> anyhow::Result<Vec<ModelEntry>> { ) -> anyhow::Result<Vec<ModelDeploymentCard>> {
let mut all = self.all_entries().await?; let mut all = self.all_cards().await?;
all.retain(|entry| { all.retain(|(endpoint_id, card)| {
let matches_name = entry.name == model_name; let matches_name = card.name() == model_name;
let matches_namespace = match (is_global_namespace, target_namespace) { let matches_namespace = match (is_global_namespace, target_namespace) {
(true, _) => true, (true, _) => true,
(false, None) => true, (false, None) => true,
(false, Some(target_ns)) => entry.endpoint_id.namespace == target_ns, (false, Some(target_ns)) => endpoint_id.namespace == target_ns,
}; };
matches_name && matches_namespace matches_name && matches_namespace
}); });
Ok(all) Ok(all.into_iter().map(|(_eid, card)| card).collect())
}
}
/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
let parts: Vec<&str> = s.split('/').collect();
let start_idx = if !parts.is_empty() && parts[0] == "v1" {
1
} else {
0
};
// Need at least prefix model_card::ROOT_PATH + 3 parts: namespace, component, name
if parts.len() <= start_idx + 3 {
anyhow::bail!("Invalid format: not enough path segments in {s}");
}
if parts.get(start_idx) != Some(&model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
let endpoint_id = EndpointId {
namespace: parts[start_idx + 1].to_string(),
component: parts[start_idx + 2].to_string(),
name: parts[start_idx + 3].to_string(),
};
Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_etcd_key_extract() {
let input = format!(
"v1/{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH
);
let (endpoint_id, instance_id) = etcd_key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate");
assert_eq!(instance_id, "694d9981145a61ad");
let input = format!(
"{}/dynamo/backend/generate/694d9981145a61ad",
model_card::ROOT_PATH
);
let (endpoint_id, _) = etcd_key_extract(&input).unwrap();
assert_eq!(endpoint_id.namespace, "dynamo");
assert_eq!(endpoint_id.component, "backend");
assert_eq!(endpoint_id.name, "generate");
} }
} }
...@@ -5,12 +5,12 @@ use std::pin::Pin; ...@@ -5,12 +5,12 @@ use std::pin::Pin;
use crate::{ use crate::{
backend::{Backend, ExecutionContext}, backend::{Backend, ExecutionContext},
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig}, entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter}, kv_router::{KvPushRouter, KvRouter},
migration::Migration, migration::Migration,
model_card::ModelDeploymentCard, model_card::{self, ModelDeploymentCard},
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest}, protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate, request_template::RequestTemplate,
...@@ -73,7 +73,9 @@ pub async fn prepare_engine( ...@@ -73,7 +73,9 @@ pub async fn prepare_engine(
None, None,
None, None,
)); ));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?; let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone(); let inner_watch_obj = watch_obj.clone();
......
...@@ -4,11 +4,12 @@ ...@@ -4,11 +4,12 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher}, discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve, grpc::service::kserve,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime, distributed_runtime,
grpc_service.state().manager_clone(), grpc_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
let card = local_model.card(); let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static let dst_config = DistributedConfig::from_settings(true); // true means static
...@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(), tokenizer_hf.clone(),
) )
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = let completions_engine =
entrypoint::build_routed_pipeline::< entrypoint::build_routed_pipeline::<
...@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf) >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service grpc_service
} }
...@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?; let checksum = model.card().mdcsum();
manager.add_chat_completions_model(model.service_name(), engine)?; manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
grpc_service grpc_service
} }
EngineConfig::StaticCore { EngineConfig::StaticCore {
...@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => { } => {
let grpc_service = grpc_service_builder.build()?; let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager(); let manager = grpc_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?; let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline = let chat_pipeline =
...@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone()) >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?; .await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
grpc_service grpc_service
} }
}; };
......
...@@ -4,12 +4,13 @@ ...@@ -4,12 +4,13 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher}, discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType, endpoint_type::EndpointType,
engines::StreamingEngineAdapter, engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common}, entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService}, http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace, namespace::is_global_namespace,
types::openai::{ types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse}, chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
...@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime, distributed_runtime,
http_service.state().manager_clone(), http_service.state().manager_clone(),
etcd_client.clone(), etcd_client.clone(),
MODEL_ROOT_PATH, model_card::ROOT_PATH,
router_config.router_mode, router_config.router_mode,
Some(router_config.kv_router_config), Some(router_config.kv_router_config),
router_config.busy_threshold, router_config.busy_threshold,
...@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} }
EngineConfig::StaticRemote(local_model) => { EngineConfig::StaticRemote(local_model) => {
let card = local_model.card(); let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode; let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static let dst_config = DistributedConfig::from_settings(true); // true means static
...@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(), tokenizer_hf.clone(),
) )
.await?; .await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?; manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine = let completions_engine =
entrypoint::build_routed_pipeline::< entrypoint::build_routed_pipeline::<
...@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf) >(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?; manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true); http_service.enable_model_endpoint(endpoint_type, true);
...@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine)); let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = http_service.model_manager(); let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?; let checksum = model.card().mdcsum();
manager.add_chat_completions_model(model.service_name(), engine)?; manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
...@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => { } => {
let http_service = http_service_builder.build()?; let http_service = http_service_builder.build()?;
let manager = http_service.model_manager(); let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?; let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline = let chat_pipeline =
...@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul ...@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse, NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone()) >(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?; .await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?; manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::< let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf) >(model.card(), inner_engine, tokenizer_hf)
.await?; .await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?; manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
// Enable all endpoints // Enable all endpoints
for endpoint_type in EndpointType::all() { for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true); http_service.enable_model_endpoint(endpoint_type, true);
......
...@@ -32,7 +32,6 @@ pub mod sequence; ...@@ -32,7 +32,6 @@ pub mod sequence;
pub mod subscriber; pub mod subscriber;
use crate::{ use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{ kv_router::{
approx::ApproxKvIndexer, approx::ApproxKvIndexer,
indexer::{ indexer::{
...@@ -45,6 +44,7 @@ use crate::{ ...@@ -45,6 +44,7 @@ use crate::{
subscriber::start_kv_router_background, subscriber::start_kv_router_background,
}, },
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
preprocessor::PreprocessedRequest, preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput, protocols::common::llm_backend::LLMEngineOutput,
}; };
...@@ -247,9 +247,9 @@ impl KvRouter { ...@@ -247,9 +247,9 @@ impl KvRouter {
let runtime_configs_watcher = watch_prefix_with_extraction( let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client, etcd_client,
MODEL_ROOT_PATH, model_card::ROOT_PATH,
key_extractors::lease_id, key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config, |card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(), cancellation_token.clone(),
) )
.await?; .await?;
......
...@@ -15,15 +15,12 @@ use dynamo_runtime::{ ...@@ -15,15 +15,12 @@ use dynamo_runtime::{
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager}, storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
}; };
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig; use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs; use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard}; use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate; use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
pub mod runtime_config; pub mod runtime_config;
use runtime_config::ModelRuntimeConfig; use runtime_config::ModelRuntimeConfig;
...@@ -421,36 +418,13 @@ impl LocalModel { ...@@ -421,36 +418,13 @@ impl LocalModel {
// Publish the Model Deployment Card to KV store // Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string(); let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
// TODO: Next PR will use this let key = Key::from_raw(endpoint.unique_path(lease_id));
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store
.publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?;
// Publish our ModelEntry to etcd. This allows ingress to find the model card. let _outcome = card_store
// (Why don't we put the model card directly under this key?) .publish(model_card::ROOT_PATH, None, &key, &mut self.card)
let network_name = ModelNetworkName::new(); .await?;
tracing::debug!("Registering with etcd as {network_name}"); Ok(())
let model_registration = ModelEntry {
name: self.display_name().to_string(),
endpoint_id: endpoint.id(),
runtime_config: Some(self.runtime_config.clone()),
};
etcd_client
.kv_create(
&network_name,
serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease
)
.await
} }
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::MODEL_ROOT_PATH;
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
pub fn new() -> Self {
ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
}
}
impl Default for ModelNetworkName {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl AsRef<str> for ModelNetworkName {
fn as_ref(&self) -> &str {
&self.0
}
}
impl std::ops::Deref for ModelNetworkName {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
...@@ -15,9 +15,9 @@ ...@@ -15,9 +15,9 @@
use std::fmt; use std::fmt;
use std::fs::File; use std::fs::File;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use std::sync::Arc; use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile; use crate::common::checked_file::{CheckedFile, Checksum};
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType}; use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
...@@ -43,6 +43,15 @@ pub enum ModelInfoType { ...@@ -43,6 +43,15 @@ pub enum ModelInfoType {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl ModelInfoType {
pub fn checksum(&self) -> String {
match self {
ModelInfoType::HfConfigJson(c) => c.checksum().to_string(),
ModelInfoType::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum TokenizerKind { pub enum TokenizerKind {
...@@ -50,6 +59,15 @@ pub enum TokenizerKind { ...@@ -50,6 +59,15 @@ pub enum TokenizerKind {
GGUF(Box<HfTokenizer>), GGUF(Box<HfTokenizer>),
} }
impl TokenizerKind {
pub fn checksum(&self) -> String {
match self {
TokenizerKind::HfTokenizerJson(c) => c.checksum().to_string(),
TokenizerKind::GGUF(_) => Checksum::default().to_string(),
}
}
}
/// Supported types of prompt formatters. /// Supported types of prompt formatters.
/// ///
/// We need a way to associate the prompt formatter template definition with an associated /// We need a way to associate the prompt formatter template definition with an associated
...@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact { ...@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl PromptFormatterArtifact {
pub fn checksum(&self) -> String {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplate(c) => c.checksum().to_string(),
PromptFormatterArtifact::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptContextMixin { pub enum PromptContextMixin {
...@@ -87,6 +115,15 @@ pub enum GenerationConfig { ...@@ -87,6 +115,15 @@ pub enum GenerationConfig {
GGUF(PathBuf), GGUF(PathBuf),
} }
impl GenerationConfig {
pub fn checksum(&self) -> String {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.checksum().to_string(),
GenerationConfig::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)] #[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard { pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct" /// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
...@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard { ...@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard {
#[serde(skip)] #[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>, cache_dir: Option<Arc<tempfile::TempDir>>,
#[serde(skip, default)]
checksum: OnceLock<String>,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
...@@ -189,6 +229,12 @@ impl ModelDeploymentCard { ...@@ -189,6 +229,12 @@ impl ModelDeploymentCard {
Ok(()) Ok(())
} }
#[inline]
pub fn name(&self) -> &str {
&self.display_name
}
#[inline]
pub fn slug(&self) -> &Slug { pub fn slug(&self) -> &Slug {
&self.slug &self.slug
} }
...@@ -198,9 +244,45 @@ impl ModelDeploymentCard { ...@@ -198,9 +244,45 @@ impl ModelDeploymentCard {
Ok(serde_json::to_string(self)?) Ok(serde_json::to_string(self)?)
} }
pub fn mdcsum(&self) -> String { pub fn mdcsum(&self) -> &str {
let json = self.to_json().unwrap(); self.checksum
format!("{}", blake3::hash(json.as_bytes())) .get_or_init(|| {
// Only include the important fields
let mut bytes_to_hash: Vec<u8> = Vec::with_capacity(512);
bytes_to_hash.extend(self.display_name.as_bytes());
// The files can be either a URL or a local path, so we ignore that and hash their
// checksum instead, which won't change wherever they are.
if let Some(model_info) = self.model_info.as_ref() {
bytes_to_hash.extend(model_info.checksum().as_bytes());
}
if let Some(tokenizer) = self.tokenizer.as_ref() {
bytes_to_hash.extend(tokenizer.checksum().as_bytes());
}
if let Some(prompt_formatter) = self.prompt_formatter.as_ref() {
bytes_to_hash.extend(prompt_formatter.checksum().as_bytes());
}
if let Some(chat_template) = self.chat_template_file.as_ref() {
bytes_to_hash.extend(chat_template.checksum().as_bytes());
}
if let Some(gen_config) = self.gen_config.as_ref() {
bytes_to_hash.extend(gen_config.checksum().as_bytes());
}
if let Some(prompt_context_vec) = self.prompt_context.as_ref() {
// Paste it as the bytes of the debug format. It's a Vec of enum, so this should be
// fine. If the debug representation changes that only happens in a new release.
bytes_to_hash.extend(format!("{prompt_context_vec:?}").as_bytes());
}
bytes_to_hash.extend(self.context_length.to_be_bytes());
bytes_to_hash.extend(self.kv_cache_block_size.to_be_bytes());
// TODO: Do we want any of user_data or runtime_config?
blake3::hash(&bytes_to_hash).to_string()
})
.as_ref()
} }
/// Is this a full model card with tokenizer? /// Is this a full model card with tokenizer?
...@@ -291,9 +373,7 @@ impl ModelDeploymentCard { ...@@ -291,9 +373,7 @@ impl ModelDeploymentCard {
/// Move the files this MDC uses from the NATS object store to local disk. /// Move the files this MDC uses from the NATS object store to local disk.
/// Updates the URI's to point to the created files. /// Updates the URI's to point to the created files.
/// pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
/// The returned TempDir must be kept alive, it cleans up on drop.
async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
let nats_addr = nats_client.addr(); let nats_addr = nats_client.addr();
let bucket_name = self.slug(); let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?; let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
...@@ -345,7 +425,9 @@ impl ModelDeploymentCard { ...@@ -345,7 +425,9 @@ impl ModelDeploymentCard {
"tokenizer.json" "tokenizer.json"
); );
Ok(target_dir) // This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive.
self.cache_dir = Some(Arc::new(target_dir));
Ok(())
} }
/// Delete this card from the key-value store and it's URLs from the object store /// Delete this card from the key-value store and it's URLs from the object store
...@@ -411,8 +493,7 @@ impl ModelDeploymentCard { ...@@ -411,8 +493,7 @@ impl ModelDeploymentCard {
else { else {
return Ok(None); return Ok(None);
}; };
// This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive. card.move_from_nats(drt.nats_client()).await?;
card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
Ok(Some(card)) Ok(Some(card))
} }
...@@ -487,6 +568,7 @@ impl ModelDeploymentCard { ...@@ -487,6 +568,7 @@ impl ModelDeploymentCard {
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
checksum: OnceLock::new(),
}) })
} }
...@@ -551,6 +633,7 @@ impl ModelDeploymentCard { ...@@ -551,6 +633,7 @@ impl ModelDeploymentCard {
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None, cache_dir: None,
checksum: OnceLock::new(),
}) })
} }
} }
......
...@@ -74,6 +74,25 @@ impl ModelType { ...@@ -74,6 +74,25 @@ impl ModelType {
result result
} }
/// Decompose the bitflag into it's component units:
/// Chat | Completion -> [Chat, Completion]
pub fn units(&self) -> Vec<ModelType> {
let mut result = Vec::new();
if self.supports_chat() {
result.push(ModelType::Chat);
}
if self.supports_completions() {
result.push(ModelType::Completions);
}
if self.supports_embedding() {
result.push(ModelType::Embedding);
}
if self.supports_tensor() {
result.push(ModelType::TensorBased);
}
result
}
/// Returns all endpoint types supported by this model type. /// Returns all endpoint types supported by this model type.
/// This properly handles combinations like Chat | Completions. /// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> { pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
......
...@@ -124,7 +124,7 @@ impl OpenAIPreprocessor { ...@@ -124,7 +124,7 @@ impl OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>, formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer, hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> { ) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum(); let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer)); let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else { let Some(model_info) = mdc.model_info else {
anyhow::bail!( anyhow::bail!(
......
...@@ -4,9 +4,20 @@ ...@@ -4,9 +4,20 @@
use anyhow::Error; use anyhow::Error;
use async_stream::stream; use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig; use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::http::{ use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_llm::{
http::{
client::{ client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient, GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
}, },
service::{ service::{
Metrics, Metrics,
...@@ -14,15 +25,8 @@ use dynamo_llm::http::{ ...@@ -14,15 +25,8 @@ use dynamo_llm::http::{
metrics::{Endpoint, RequestType, Status}, metrics::{Endpoint, RequestType, Status},
service_v2::HttpService, service_v2::HttpService,
}, },
};
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
}, },
model_card::ModelDeploymentCard,
}; };
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix}; use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{ use dynamo_runtime::{
...@@ -275,15 +279,18 @@ async fn test_http_service() { ...@@ -275,15 +279,18 @@ async fn test_http_service() {
let registry = Registry::new(); let registry = Registry::new();
// TODO: Shouldn't this test know the card before it registers a model?
let card = ModelDeploymentCard::with_name_only("foo");
let counter = Arc::new(CounterEngine {}); let counter = Arc::new(CounterEngine {});
let result = manager.add_chat_completions_model("foo", counter); let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
assert!(result.is_ok()); assert!(result.is_ok());
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let result = manager.add_chat_completions_model("bar", failure.clone()); let card = ModelDeploymentCard::with_name_only("bar");
let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
assert!(result.is_ok()); assert!(result.is_ok());
let result = manager.add_completions_model("bar", failure); let result = manager.add_completions_model("bar", card.mdcsum(), failure);
assert!(result.is_ok()); assert!(result.is_ok());
let metrics = state.metrics_clone(); let metrics = state.metrics_clone();
...@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF ...@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF
let counter = Arc::new(CounterEngine {}); let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager manager
.add_chat_completions_model("foo", counter.clone()) .add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap(); .unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager manager
.add_chat_completions_model("bar", failure.clone()) .add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
manager manager
.add_completions_model("bar", failure.clone()) .add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
(service, counter, failure, port) (service, counter, failure, port)
...@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() { ...@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds) // Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000)); let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager manager
.add_chat_completions_model("slow-model", long_running_engine.clone()) .add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
.unwrap(); .unwrap();
let client = reqwest::Client::new(); let client = reqwest::Client::new();
...@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() { ...@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds) // Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-stream-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000)); let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager manager
.add_chat_completions_model("slow-stream-model", long_running_engine.clone()) .add_chat_completions_model(
"slow-stream-model",
card.mdcsum(),
long_running_engine.clone(),
)
.unwrap(); .unwrap();
let client = reqwest::Client::new(); let client = reqwest::Client::new();
...@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() { ...@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() {
wait_for_service_ready(port).await; wait_for_service_ready(port).await;
// Add a counter engine for this test // Add a counter engine for this test
let card = ModelDeploymentCard::with_name_only("test-model");
let counter_engine = Arc::new(CounterEngine {}); let counter_engine = Arc::new(CounterEngine {});
manager manager
.add_chat_completions_model("test-model", counter_engine) .add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
.unwrap(); .unwrap();
// Create reqwest client directly // Create reqwest client directly
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
use anyhow::Error; use anyhow::Error;
use async_stream::stream; use async_stream::stream;
use dynamo_llm::{ use dynamo_llm::{
http::service::metrics::Endpoint, http::service::{metrics::Endpoint, service_v2::HttpService},
http::service::service_v2::HttpService, model_card::ModelDeploymentCard,
protocols::{ protocols::{
Annotated, Annotated,
openai::chat_completions::{ openai::chat_completions::{
...@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() { ...@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() {
let task = tokio::spawn(async move { service.run(token.clone()).await }); let task = tokio::spawn(async move { service.run(token.clone()).await });
// Add mock model engine // Add mock model engine
let card = ModelDeploymentCard::with_name_only("mockmodel");
let mock_engine = Arc::new(MockModelEngine {}); let mock_engine = Arc::new(MockModelEngine {});
manager manager
.add_chat_completions_model("mockmodel", mock_engine) .add_chat_completions_model("mockmodel", card.mdcsum(), mock_engine)
.unwrap(); .unwrap();
// Wait for service to be ready // Wait for service to be ready
...@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() { ...@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() {
mod integration_tests { mod integration_tests {
use super::*; use super::*;
use dynamo_llm::{ use dynamo_llm::{
discovery::{MODEL_ROOT_PATH, ModelEntry, ModelWatcher}, discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
engines::make_echo_engine, local_model::LocalModelBuilder, model_card,
entrypoint::EngineConfig,
local_model::LocalModelBuilder,
}; };
use dynamo_runtime::DistributedRuntime; use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode; use dynamo_runtime::pipeline::RouterMode;
...@@ -348,7 +347,7 @@ mod integration_tests { ...@@ -348,7 +347,7 @@ mod integration_tests {
// Start watching etcd for model registrations // Start watching etcd for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() { if let Some(etcd_client) = distributed_runtime.etcd_client() {
let models_watcher = etcd_client let models_watcher = etcd_client
.kv_get_and_watch_prefix(MODEL_ROOT_PATH) .kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await .await
.unwrap(); .unwrap();
let (_prefix, _watcher, receiver) = models_watcher.dissolve(); let (_prefix, _watcher, receiver) = models_watcher.dissolve();
...@@ -364,10 +363,11 @@ mod integration_tests { ...@@ -364,10 +363,11 @@ mod integration_tests {
panic!("Expected StaticFull config"); panic!("Expected StaticFull config");
}; };
let card = local_model.card().clone();
let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine)); let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine));
let manager = service.model_manager(); let manager = service.model_manager();
manager manager
.add_chat_completions_model(model.service_name(), engine.clone()) .add_chat_completions_model(model.service_name(), card.mdcsum(), engine.clone())
.unwrap(); .unwrap();
// Now do the proper MDC registration via LocalModel::attach() // Now do the proper MDC registration via LocalModel::attach()
...@@ -376,7 +376,7 @@ mod integration_tests { ...@@ -376,7 +376,7 @@ mod integration_tests {
let test_component = namespace.component("test-mdc-component").unwrap(); let test_component = namespace.component("test-mdc-component").unwrap();
let test_endpoint = test_component.endpoint("test-mdc-endpoint"); let test_endpoint = test_component.endpoint("test-mdc-endpoint");
// This will store the MDC in etcd and create the ModelEntry for discovery // This will store the MDC in etcd for discovery
local_model local_model
.attach( .attach(
&test_endpoint, &test_endpoint,
...@@ -388,8 +388,7 @@ mod integration_tests { ...@@ -388,8 +388,7 @@ mod integration_tests {
// Manually save the model card and update metrics // Manually save the model card and update metrics
// This simulates what the ModelWatcher polling task would do in production // This simulates what the ModelWatcher polling task would do in production
let card = local_model.card().clone(); let _ = manager.save_model_card("test-mdc-key", card.clone());
manager.save_model_card("test-mdc-key", card.clone());
if let Err(e) = service if let Err(e) = service
.state() .state()
...@@ -500,8 +499,13 @@ mod integration_tests { ...@@ -500,8 +499,13 @@ mod integration_tests {
assert!(metrics_body.contains("request_type=\"stream\"")); assert!(metrics_body.contains("request_type=\"stream\""));
assert!(metrics_body.contains("status=\"success\"")); assert!(metrics_body.contains("status=\"success\""));
// etcd lease will ensure we everything is deleted from etcd
// Now test the complete lifecycle: remove the model from etcd // Now test the complete lifecycle: remove the model from etcd
// We don't need to cleanup model manager because it's local to this test
/*
// Clean up
// Remove the model using the cleaner ModelWatcher approach // Remove the model using the cleaner ModelWatcher approach
if let Some(etcd_client) = distributed_runtime.etcd_client() { if let Some(etcd_client) = distributed_runtime.etcd_client() {
// Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern) // Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern)
...@@ -514,10 +518,7 @@ mod integration_tests { ...@@ -514,10 +518,7 @@ mod integration_tests {
); );
// Get all model entries for our test model // Get all model entries for our test model
let model_entries = watcher let model_entries = watcher.entries_for_model("test-mdc-model").await.unwrap();
.entries_for_model("test-mdc-model", None, true)
.await
.unwrap();
if !model_entries.is_empty() { if !model_entries.is_empty() {
// For each model entry, we need to find its etcd key and remove it // For each model entry, we need to find its etcd key and remove it
...@@ -566,8 +567,8 @@ mod integration_tests { ...@@ -566,8 +567,8 @@ mod integration_tests {
} }
} }
} }
*/
// Clean up
cancel_token.cancel(); cancel_token.cancel();
task.await.unwrap().unwrap(); task.await.unwrap().unwrap();
} }
......
...@@ -280,31 +280,32 @@ pub mod kserve_test { ...@@ -280,31 +280,32 @@ pub mod kserve_test {
let failure = Arc::new(AlwaysFailEngine {}); let failure = Arc::new(AlwaysFailEngine {});
let long_running = Arc::new(LongRunningEngine::new(1_000)); let long_running = Arc::new(LongRunningEngine::new(1_000));
manager
.add_completions_model("split", split.clone())
.unwrap();
let mut card = ModelDeploymentCard::with_name_only("split"); let mut card = ModelDeploymentCard::with_name_only("split");
card.model_type = ModelType::Completions; card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("split", card);
manager manager
.add_chat_completions_model("failure", failure.clone()) .add_completions_model("split", card.mdcsum(), split.clone())
.unwrap();
manager
.add_completions_model("failure", failure.clone())
.unwrap(); .unwrap();
let _ = manager.save_model_card("split", card.clone());
let mut card = ModelDeploymentCard::with_name_only("failure"); let mut card = ModelDeploymentCard::with_name_only("failure");
card.model_type = ModelType::Completions | ModelType::Chat; card.model_type = ModelType::Completions | ModelType::Chat;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("failure", card);
manager manager
.add_completions_model("long_running", long_running.clone()) .add_chat_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap(); .unwrap();
manager
.add_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
let _ = manager.save_model_card("failure", card);
let mut card = ModelDeploymentCard::with_name_only("long_running"); let mut card = ModelDeploymentCard::with_name_only("long_running");
card.model_type = ModelType::Completions; card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text; card.model_input = ModelInput::Text;
manager.save_model_card("long_running", card); manager
.add_completions_model("long_running", card.mdcsum(), long_running.clone())
.unwrap();
let _ = manager.save_model_card("long_running", card);
(service, split, failure, long_running) (service, split, failure, long_running)
} }
...@@ -1130,11 +1131,16 @@ pub mod kserve_test { ...@@ -1130,11 +1131,16 @@ pub mod kserve_test {
text_input: inference::model_infer_request::InferInputTensor, text_input: inference::model_infer_request::InferInputTensor,
) { ) {
// add tensor model // add tensor model
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
let tensor = Arc::new(TensorEngine {}); let tensor = Arc::new(TensorEngine {});
service_with_engines service_with_engines
.0 .0
.model_manager() .model_manager()
.add_tensor_model("tensor", tensor.clone()) .add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap(); .unwrap();
// start server // start server
...@@ -1147,11 +1153,7 @@ pub mod kserve_test { ...@@ -1147,11 +1153,7 @@ pub mod kserve_test {
version: "".into(), version: "".into(),
}); });
// Failure, model registered as Tensor but does not provide model config (in runtime config) let _ = service_with_engines
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_card("key", card); .save_model_card("key", card);
...@@ -1217,7 +1219,7 @@ pub mod kserve_test { ...@@ -1217,7 +1219,7 @@ pub mod kserve_test {
}), }),
..Default::default() ..Default::default()
}; };
service_with_engines let _ = service_with_engines
.0 .0
.model_manager() .model_manager()
.save_model_card("key", card); .save_model_card("key", card);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment