Unverified Commit 81162dfe authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(discovery): Watch/publish ModelDeploymentCard instead of ModelEntry (#3350)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent ddbb4f50
......@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime):
host: str = "localhost"
port: int = 8000
service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine)
service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...")
shutdown_signal = service.run(runtime.child_token())
......
......@@ -30,23 +30,29 @@ impl HttpService {
Ok(Self { inner })
}
pub fn add_completions_model(&self, model: String, engine: HttpAsyncEngine) -> PyResult<()> {
pub fn add_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_completions_model(&model, engine)
.add_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}
pub fn add_chat_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_chat_completions_model(&model, engine)
.add_chat_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}
......
......@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime):
model_name = "test_model"
start_done = asyncio.Event()
child_token = runtime.child_token()
checksum = "abc123" # Checksum of ModelDeploymentCard for that model
async def worker():
"""The server worker task."""
......@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime):
engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port)
service.add_chat_completions_model(model_name, engine)
service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token)
......
......@@ -33,7 +33,7 @@ pub struct Checksum {
algorithm: CryptographicHashMethods,
}
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
pub enum CryptographicHashMethods {
#[serde(rename = "blake3")]
BLAKE3,
......@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum {
}
}
/// The default `Checksum` carries an empty hash string tagged as BLAKE3.
impl Default for Checksum {
    /// Builds a `Checksum` with an empty `hash` and the `BLAKE3` algorithm.
    fn default() -> Self {
        let hash = String::new();
        let algorithm = CryptographicHashMethods::BLAKE3;
        Self { hash, algorithm }
    }
}
impl FromStr for CryptographicHashMethods {
type Err = String;
......
......@@ -4,14 +4,8 @@
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
mod model_entry;
pub use model_entry::ModelEntry;
mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize};
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// [ModelEntry] contains the information to discover models
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
/// Public name of the model
/// Used to identify the model in the HTTP service from the value used in an OpenAI ChatRequest.
pub name: String,
/// How to address this on the network
#[serde(rename = "endpoint")]
pub endpoint_id: protocols::EndpointId,
/// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
}
impl ModelEntry {
/// Slugified display name for use in network storage, or URL-safe environments
pub fn slug(&self) -> Slug {
Slug::from_string(&self.name)
}
}
......@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{
kv_router::KvRouter,
......@@ -21,6 +20,10 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
model_type::ModelType,
};
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
......@@ -39,7 +42,7 @@ pub struct ModelManager {
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally
// These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
......@@ -62,6 +65,43 @@ impl ModelManager {
}
}
/// Compare `candidate_checksum` with the checksum already stored for `model_name`.
///
/// Each unit of `model_type` is looked up in the matching engine collection
/// (chat, completions, embeddings or tensor); any other unit is ignored.
///
/// Returns `None` when none of the consulted collections has a checksum for
/// `model_name`. Otherwise returns `Some(true)` only if every stored checksum that
/// was found equals `candidate_checksum`, and `Some(false)` if at least one differs.
pub fn is_valid_checksum(
    &self,
    model_type: ModelType,
    model_name: &str,
    candidate_checksum: &str,
) -> Option<bool> {
    // Stays `None` until a stored checksum is found, then holds the running AND of
    // all comparisons. Every unit is visited so each comparison is still logged.
    let mut verdict: Option<bool> = None;
    for unit in model_type.units() {
        let stored = match unit {
            ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
            ModelType::Completions => self.completion_engines.read().checksum(model_name),
            ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
            ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
            _ => continue,
        };
        let Some(valid_checksum) = stored else {
            continue;
        };
        tracing::debug!(
            model_name,
            valid_checksum,
            candidate_checksum,
            "is_valid_checksum: check case"
        );
        let is_match = valid_checksum == candidate_checksum;
        // The checksum is valid if it is correct for all the ModelType in the bitflag.
        verdict = Some(verdict.unwrap_or(true) && is_match);
    }
    verdict
}
/// Returns a cloned snapshot of every `ModelDeploymentCard` currently saved.
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
    let saved = self.cards.lock();
    saved.values().cloned().collect()
}
......@@ -99,37 +139,41 @@ impl ModelManager {
pub fn add_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_chat_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_embeddings_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_tensor_model(
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
......@@ -196,10 +240,11 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd.
pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) {
self.cards.lock().insert(key.to_string(), entry);
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card);
Ok(())
}
/// Remove and return model card for this instance's etcd key. We do this when the instance stops.
......@@ -291,6 +336,9 @@ pub struct ModelEngines<E> {
/// Optional default model name
default: Option<String>,
engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
}
impl<E> Default for ModelEngines<E> {
......@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> {
Self {
default: None,
engines: HashMap::new(),
checksums: HashMap::new(),
}
}
}
......@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> {
self.default = None;
}
fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
if self.engines.contains_key(model) {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
self.engines.insert(model.to_string(), engine);
self.checksums
.insert(model.to_string(), checksum.to_string());
Ok(())
}
......@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> {
if self.engines.remove(model).is_none() {
return Err(ModelManagerError::ModelNotFound(model.to_string()));
}
let _ = self.checksums.remove(model);
Ok(())
}
......@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> {
/// Names of all models that currently have an engine registered, as owned strings.
pub fn list(&self) -> Vec<String> {
    self.engines.keys().cloned().collect()
}
/// Looks up the ModelDeploymentCard checksum recorded for `model`, if any.
///
/// Returns a newly allocated `String` for caller convenience: the call sites
/// all need an owned value.
pub fn checksum(&self, model: &str) -> Option<String> {
    self.checksums.get(model).cloned()
}
}
......@@ -13,16 +13,15 @@ use dynamo_runtime::{
ManyOut, Operator, RouterMode, SegmentSource, ServiceBackend, SingleIn, Source,
network::egress::push_router::PushRouter,
},
protocols::annotated::Annotated,
storage::key_value_store::Key,
transports::etcd::{KeyValue, WatchEvent},
protocols::{EndpointId, annotated::Annotated},
transports::etcd::WatchEvent,
};
use crate::{
backend::Backend,
entrypoint,
kv_router::KvRouterConfig,
model_card::ModelDeploymentCard,
model_card::{self, ModelDeploymentCard},
model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{
......@@ -38,7 +37,7 @@ use crate::{
},
};
use super::{MODEL_ROOT_PATH, ModelEntry, ModelManager};
use super::ModelManager;
use crate::namespace::is_global_namespace;
#[derive(Debug, Clone)]
......@@ -105,64 +104,97 @@ impl ModelWatcher {
while let Some(event) = events_rx.recv().await {
match event {
WatchEvent::Put(kv) => {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
let mut card = match serde_json::from_slice::<ModelDeploymentCard>(kv.value()) {
Ok(card) => card,
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry")
tracing::error!(%err, value, "Invalid JSON in model card")
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON")
}
}
continue;
}
};
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model card key, skipping");
continue;
}
};
let endpoint_id = match etcd_key_extract(key) {
Ok((eid, _)) => eid,
Err(err) => {
tracing::error!(%key, model_name = card.name(), %err, "Failed extracting EndpointId from key. Ignoring instance.");
continue;
}
};
// Filter by namespace if target_namespace is specified
if !global_namespace
&& let Some(target_ns) = target_namespace
&& model_entry.endpoint_id.namespace != target_ns
&& endpoint_id.namespace != target_ns
{
tracing::debug!(
model_namespace = model_entry.endpoint_id.namespace,
model_namespace = endpoint_id.namespace,
target_namespace = target_ns,
model_name = model_entry.name,
model_name = card.name(),
"Skipping model from different namespace"
);
continue;
}
let key = match kv.key_str() {
Ok(k) => k,
Err(err) => {
tracing::error!(%err, ?kv, "Invalid UTF-8 string in model entry key, skipping");
// If we already have a worker for this model, and the ModelDeploymentCard
// cards don't match, alert, and don't add the new instance
let can_add =
self.manager
.is_valid_checksum(card.model_type, card.name(), card.mdcsum());
if can_add.is_some_and(|is_valid| !is_valid) {
tracing::error!(
model_name = card.name(),
"Checksum for new model does not match existing model."
);
// TODO: mark that instance down in clients
// Not obvious how to do that given the current design
// Instances come from an `InstanceSource` in a `Client` in a `PushRouter`.
// Calling `report_instance_down` on the Client should do it (although
// needs more testing).
// The `PushRouter` is in `ModelManager` (`self.manager` here), but inside
// interface `AsyncEngine` which only has a `generate` method.
continue;
}
};
match self.handle_put(key, &model_entry).await {
match self.handle_put(key, &endpoint_id, &mut card).await {
Ok(()) => {
tracing::info!(
model_name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
model_name = card.name(),
namespace = endpoint_id.namespace,
"added model"
);
self.notify_on_model.notify_waiters();
}
Err(err) => {
tracing::error!(
model_name = card.name(),
namespace = endpoint_id.namespace,
error = format!("{err:#}"),
"error adding model {} from namespace {}",
model_entry.name,
model_entry.endpoint_id.namespace,
"Error adding model from discovery",
);
}
}
}
WatchEvent::Delete(kv) => match self
.handle_delete(&kv, target_namespace, global_namespace)
WatchEvent::Delete(kv) => {
let Ok(deleted_key) = kv.key_str() else {
tracing::warn!("Invalid UTF-8 in etcd delete notification key: {kv:?}");
continue;
};
match self
.handle_delete(deleted_key, target_namespace, global_namespace)
.await
{
Ok(Some(model_name)) => {
......@@ -174,7 +206,8 @@ impl ModelWatcher {
Err(e) => {
tracing::error!(error = %e, "error removing model");
}
},
}
}
}
}
}
......@@ -183,20 +216,19 @@ impl ModelWatcher {
/// Returns the name of the model we just deleted, if any.
async fn handle_delete(
&self,
kv: &KeyValue,
key: &str,
target_namespace: Option<&str>,
is_global_namespace: bool,
) -> anyhow::Result<Option<String>> {
let key = kv.key_str()?;
let card = match self.manager.remove_model_card(key) {
Some(card) => card,
None => {
anyhow::bail!("Missing ModelDeploymentCard for {key}");
}
};
let model_name = card.display_name.clone();
let model_name = card.name().to_string();
let active_instances = self
.entries_for_model(&model_name, target_namespace, is_global_namespace)
.cards_for_model(&model_name, target_namespace, is_global_namespace)
.await
.with_context(|| model_name.clone())?;
if !active_instances.is_empty() {
......@@ -265,53 +297,35 @@ impl ModelWatcher {
// Handles a PUT event from etcd, this usually means adding a new model to the list of served
// models.
async fn handle_put(&self, key: &str, model_entry: &ModelEntry) -> anyhow::Result<()> {
let endpoint_id = &model_entry.endpoint_id;
async fn handle_put(
&self,
key: &str,
endpoint_id: &EndpointId,
card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> {
card.move_from_nats(self.drt.nats_client()).await?;
let component = self
.drt
.namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug();
let card = match ModelDeploymentCard::load_from_store(
&Key::from_raw(model_slug.to_string()),
&self.drt,
)
.await
{
Ok(Some(mut card)) => {
tracing::debug!(card.display_name, "adding model");
// Ensure runtime_config is populated
if let Some(rc) = model_entry.runtime_config.clone() {
card.runtime_config = rc;
}
card
}
Ok(None) => {
anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
}
Err(err) => {
anyhow::bail!(
"Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
);
}
};
self.manager.save_model_card(key, card.clone());
tracing::debug!(model_name = card.name(), "adding model");
self.manager.save_model_card(key, card.clone())?;
if self.manager.has_model_any(&model_entry.name) {
tracing::trace!(
name = model_entry.name,
namespace = model_entry.endpoint_id.namespace,
if self.manager.has_model_any(card.name()) {
tracing::debug!(
model_name = card.name(),
namespace = endpoint_id.namespace,
"New endpoint for existing model"
);
self.notify_on_model.notify_waiters();
//self.notify_on_model.notify_waiters();
return Ok(());
}
if let Some(tx) = &self.model_update_tx {
tx.send(ModelUpdate::Added(card.clone())).await.ok();
}
let checksum = card.mdcsum();
if card.model_input == ModelInput::Tokens
&& (card.model_type.supports_chat() || card.model_type.supports_completions())
......@@ -324,7 +338,7 @@ impl ModelWatcher {
Some(
self.manager
.kv_chooser_for(
&model_entry.name,
card.name(),
&component,
card.kv_cache_block_size,
self.kv_router_config,
......@@ -344,7 +358,7 @@ impl ModelWatcher {
NvCreateChatCompletionRequest,
NvCreateChatCompletionStreamResponse,
>(
&card,
card,
&client,
self.router_mode,
self.busy_threshold,
......@@ -354,7 +368,7 @@ impl ModelWatcher {
.await
.context("build_routed_pipeline")?;
self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)
.add_chat_completions_model(card.name(), checksum, chat_engine)
.context("add_chat_completions_model")?;
tracing::info!("Chat completions is ready");
}
......@@ -373,7 +387,7 @@ impl ModelWatcher {
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(
&card,
card,
&client,
self.router_mode,
self.busy_threshold,
......@@ -384,7 +398,7 @@ impl ModelWatcher {
.await
.context("build_routed_pipeline_with_preprocessor")?;
self.manager
.add_completions_model(&model_entry.name, completions_engine)
.add_completions_model(card.name(), checksum, completions_engine)
.context("add_completions_model")?;
tracing::info!("Completions is ready");
}
......@@ -411,7 +425,7 @@ impl ModelWatcher {
.await?;
let engine = Arc::new(push_router);
self.manager
.add_chat_completions_model(&model_entry.name, engine)?;
.add_chat_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Text && card.model_type.supports_completions() {
// Case 2: Text + Completions
let push_router = PushRouter::<
......@@ -423,7 +437,7 @@ impl ModelWatcher {
.await?;
let engine = Arc::new(push_router);
self.manager
.add_completions_model(&model_entry.name, engine)?;
.add_completions_model(card.name(), checksum, engine)?;
} else if card.model_input == ModelInput::Tokens && card.model_type.supports_embedding() {
// Case 4: Tokens + Embeddings
......@@ -434,7 +448,7 @@ impl ModelWatcher {
>::new();
let preprocessor = OpenAIPreprocessor::new(card.clone())?.into_operator();
let backend = Backend::from_mdc(&card).into_operator();
let backend = Backend::from_mdc(card).into_operator();
let router = PushRouter::<
PreprocessedEmbeddingRequest,
......@@ -457,7 +471,7 @@ impl ModelWatcher {
.link(frontend)?;
self.manager
.add_embeddings_model(&model_entry.name, embedding_engine)?;
.add_embeddings_model(card.name(), checksum, embedding_engine)?;
} else if card.model_input == ModelInput::Tensor && card.model_type.supports_tensor() {
// Case 5: Tensor + Tensor (non-LLM)
let push_router = PushRouter::<
......@@ -468,7 +482,8 @@ impl ModelWatcher {
)
.await?;
let engine = Arc::new(push_router);
self.manager.add_tensor_model(&model_entry.name, engine)?;
self.manager
.add_tensor_model(card.name(), checksum, engine)?;
} else {
// Reject unsupported combinations
anyhow::bail!(
......@@ -482,49 +497,116 @@ impl ModelWatcher {
Ok(())
}
/// All the registered ModelEntry, one per instance
pub async fn all_entries(&self) -> anyhow::Result<Vec<ModelEntry>> {
/// All the registered ModelDeploymentCard with the EndpointId they are attached to, one per instance
pub async fn all_cards(&self) -> anyhow::Result<Vec<(EndpointId, ModelDeploymentCard)>> {
let Some(etcd_client) = self.drt.etcd_client() else {
anyhow::bail!("all_entries: Missing etcd client");
anyhow::bail!("all_cards: Missing etcd client");
};
let kvs = etcd_client.kv_get_prefix(MODEL_ROOT_PATH).await?;
let mut entries = Vec::with_capacity(kvs.len());
let kvs = etcd_client.kv_get_prefix(model_card::ROOT_PATH).await?;
let mut results = Vec::with_capacity(kvs.len());
for kv in kvs {
let model_entry = match serde_json::from_slice::<ModelEntry>(kv.value()) {
Ok(model_entry) => model_entry,
let maybe_convert = serde_json::from_slice::<ModelDeploymentCard>(kv.value());
let r = match maybe_convert {
Ok(card) => {
let maybe_endpoint_id = kv.key_str().map_err(|err| err.into()).and_then(|k| {
etcd_key_extract(k).map(|(endpoint_id, _instance_id)| endpoint_id)
});
let endpoint_id = match maybe_endpoint_id {
Ok(eid) => eid,
Err(err) => {
tracing::error!(%err, "Skipping invalid etcd key, not string or not EndpointId");
continue;
}
};
(endpoint_id, card)
}
Err(err) => {
match kv.value_str() {
Ok(value) => {
tracing::error!(%err, value, "Invalid JSON in model entry")
tracing::error!(%err, value, "Invalid JSON in model card");
}
Err(value_str_err) => {
tracing::error!(original_error = %err, %value_str_err, "Invalid UTF-8 string in model entry, expected JSON")
tracing::error!(original_error=%err, %value_str_err, "Invalid UTF-8 string in model card, expected JSON");
}
}
continue;
}
};
entries.push(model_entry);
results.push(r);
}
Ok(entries)
Ok(results)
}
pub async fn entries_for_model(
pub async fn cards_for_model(
&self,
model_name: &str,
target_namespace: Option<&str>,
is_global_namespace: bool,
) -> anyhow::Result<Vec<ModelEntry>> {
let mut all = self.all_entries().await?;
all.retain(|entry| {
let matches_name = entry.name == model_name;
) -> anyhow::Result<Vec<ModelDeploymentCard>> {
let mut all = self.all_cards().await?;
all.retain(|(endpoint_id, card)| {
let matches_name = card.name() == model_name;
let matches_namespace = match (is_global_namespace, target_namespace) {
(true, _) => true,
(false, None) => true,
(false, Some(target_ns)) => entry.endpoint_id.namespace == target_ns,
(false, Some(target_ns)) => endpoint_id.namespace == target_ns,
};
matches_name && matches_namespace
});
Ok(all)
Ok(all.into_iter().map(|(_eid, card)| card).collect())
}
}
/// The ModelDeploymentCard is published in etcd with a key like "v1/mdc/dynamo/backend/generate/694d9981145a61ad".
/// Extract the EndpointId and instance_id from that.
fn etcd_key_extract(s: &str) -> anyhow::Result<(EndpointId, String)> {
let parts: Vec<&str> = s.split('/').collect();
let start_idx = if !parts.is_empty() && parts[0] == "v1" {
1
} else {
0
};
// Need at least prefix model_card::ROOT_PATH + 3 parts: namespace, component, name
if parts.len() <= start_idx + 3 {
anyhow::bail!("Invalid format: not enough path segments in {s}");
}
if parts.get(start_idx) != Some(&model_card::ROOT_PATH) {
anyhow::bail!("Invalid format: expected model card ROOT_PATH segment in {s}");
}
let endpoint_id = EndpointId {
namespace: parts[start_idx + 1].to_string(),
component: parts[start_idx + 2].to_string(),
name: parts[start_idx + 3].to_string(),
};
Ok((endpoint_id, parts[parts.len() - 1].to_string()))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks that the parsed endpoint is dynamo/backend/generate, as used by both fixture keys.
    fn assert_expected_endpoint(endpoint_id: &EndpointId) {
        assert_eq!(endpoint_id.namespace, "dynamo");
        assert_eq!(endpoint_id.component, "backend");
        assert_eq!(endpoint_id.name, "generate");
    }

    #[test]
    fn test_etcd_key_extract() {
        let root = model_card::ROOT_PATH;

        // Key with the leading "v1" version segment: endpoint and instance id are both checked.
        let versioned = format!("v1/{root}/dynamo/backend/generate/694d9981145a61ad");
        let (endpoint_id, instance_id) = etcd_key_extract(&versioned).unwrap();
        assert_expected_endpoint(&endpoint_id);
        assert_eq!(instance_id, "694d9981145a61ad");

        // Key without the version segment: only the endpoint is checked.
        let unversioned = format!("{root}/dynamo/backend/generate/694d9981145a61ad");
        let (endpoint_id, _) = etcd_key_extract(&unversioned).unwrap();
        assert_expected_endpoint(&endpoint_id);
    }
}
......@@ -5,12 +5,12 @@ use std::pin::Pin;
use crate::{
backend::{Backend, ExecutionContext},
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter},
migration::Migration,
model_card::ModelDeploymentCard,
model_card::{self, ModelDeploymentCard},
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
......@@ -73,7 +73,9 @@ pub async fn prepare_engine(
None,
None,
));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone();
......
......@@ -4,11 +4,12 @@
use std::sync::Arc;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime,
grpc_service.state().manager_clone(),
etcd_client.clone(),
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
......@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
......@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service
}
......@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = grpc_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;
let checksum = model.card().mdcsum();
manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
grpc_service
}
EngineConfig::StaticCore {
......@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => {
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
......@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
grpc_service
}
};
......
......@@ -4,12 +4,13 @@
use std::sync::Arc;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime,
http_service.state().manager_clone(),
etcd_client.clone(),
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
......@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
......@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
......@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;
let checksum = model.card().mdcsum();
manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
......@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => {
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
......@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
......
......@@ -32,7 +32,6 @@ pub mod sequence;
pub mod subscriber;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{
approx::ApproxKvIndexer,
indexer::{
......@@ -45,6 +44,7 @@ use crate::{
subscriber::start_kv_router_background,
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
};
......@@ -247,9 +247,9 @@ impl KvRouter {
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config,
|card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(),
)
.await?;
......
......@@ -15,15 +15,12 @@ use dynamo_runtime::{
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
};
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
pub mod runtime_config;
use runtime_config::ModelRuntimeConfig;
......@@ -421,36 +418,13 @@ impl LocalModel {
// Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();
// TODO: Next PR will use this
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store
.publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?;
let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
let key = Key::from_raw(endpoint.unique_path(lease_id));
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::new();
tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry {
name: self.display_name().to_string(),
endpoint_id: endpoint.id(),
runtime_config: Some(self.runtime_config.clone()),
};
etcd_client
.kv_create(
&network_name,
serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease
)
.await
let _outcome = card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
Ok(())
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::MODEL_ROOT_PATH;
/// String newtype holding a key of the form `{MODEL_ROOT_PATH}/{uuid}`.
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
pub fn new() -> Self {
ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
}
}
impl Default for ModelNetworkName {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl AsRef<str> for ModelNetworkName {
    /// Borrows the wrapped string as a `&str`.
    fn as_ref(&self) -> &str {
        self.0.as_str()
    }
}
impl std::ops::Deref for ModelNetworkName {
    type Target = str;

    /// Dereferences to the wrapped string, so `str` methods can be called on the name directly.
    fn deref(&self) -> &Self::Target {
        self.0.as_str()
    }
}
......@@ -15,9 +15,9 @@
use std::fmt;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile;
use crate::common::checked_file::{CheckedFile, Checksum};
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result};
......@@ -43,6 +43,15 @@ pub enum ModelInfoType {
GGUF(PathBuf),
}
impl ModelInfoType {
    /// Checksum of the model info, rendered as a string.
    ///
    /// For `HfConfigJson` this is the checksum reported by the wrapped value. The `GGUF`
    /// variant yields the string form of `Checksum::default()` instead.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfConfigJson(file) => file.checksum().to_string(),
            Self::GGUF(_) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
......@@ -50,6 +59,15 @@ pub enum TokenizerKind {
GGUF(Box<HfTokenizer>),
}
impl TokenizerKind {
    /// Checksum of the tokenizer, rendered as a string.
    ///
    /// For `HfTokenizerJson` this is the checksum reported by the wrapped value. The `GGUF`
    /// variant yields the string form of `Checksum::default()` instead.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfTokenizerJson(file) => file.checksum().to_string(),
            Self::GGUF(_) => Checksum::default().to_string(),
        }
    }
}
/// Supported types of prompt formatters.
///
/// We need a way to associate the prompt formatter template definition with an associated
......@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact {
GGUF(PathBuf),
}
impl PromptFormatterArtifact {
    /// Checksum of the prompt formatter artifact, rendered as a string.
    ///
    /// `HfTokenizerConfigJson` and `HfChatTemplate` report the checksum of the value they wrap.
    /// The `GGUF` variant yields the string form of `Checksum::default()` instead.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfTokenizerConfigJson(file) => file.checksum().to_string(),
            Self::HfChatTemplate(file) => file.checksum().to_string(),
            Self::GGUF(_) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
......@@ -87,6 +115,15 @@ pub enum GenerationConfig {
GGUF(PathBuf),
}
impl GenerationConfig {
    /// Checksum of the generation config, rendered as a string.
    ///
    /// For `HfGenerationConfigJson` this is the checksum reported by the wrapped value. The
    /// `GGUF` variant yields the string form of `Checksum::default()` instead.
    pub fn checksum(&self) -> String {
        match self {
            Self::HfGenerationConfigJson(file) => file.checksum().to_string(),
            Self::GGUF(_) => Checksum::default().to_string(),
        }
    }
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
......@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard {
#[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>,
#[serde(skip, default)]
checksum: OnceLock<String>,
}
impl ModelDeploymentCard {
......@@ -189,6 +229,12 @@ impl ModelDeploymentCard {
Ok(())
}
/// Returns the card's `display_name`.
#[inline]
pub fn name(&self) -> &str {
&self.display_name
}
/// Returns a reference to the card's `slug`.
#[inline]
pub fn slug(&self) -> &Slug {
&self.slug
}
......@@ -198,9 +244,45 @@ impl ModelDeploymentCard {
Ok(serde_json::to_string(self)?)
}
pub fn mdcsum(&self) -> String {
let json = self.to_json().unwrap();
format!("{}", blake3::hash(json.as_bytes()))
/// Checksum identifying this card: a blake3 hash rendered via `to_string()`.
///
/// The value is computed on the first call and cached in `self.checksum`; later calls
/// return the cached string. Only a subset of the fields is hashed, in this order:
/// `display_name`, then the checksums of `model_info`, `tokenizer`, `prompt_formatter`,
/// `chat_template_file` and `gen_config` (each only when present), then the debug
/// rendering of `prompt_context` (when present), then `context_length` and
/// `kv_cache_block_size` as big-endian bytes.
pub fn mdcsum(&self) -> &str {
    let cached: &String = self.checksum.get_or_init(|| {
        // Only the important fields contribute to the hash.
        let mut material: Vec<u8> = Vec::with_capacity(512);
        material.extend(self.display_name.as_bytes());

        // A file may be referenced by URL or by local path, so the location is ignored.
        // Its checksum is hashed instead, because that stays the same wherever the file is.
        if let Some(info) = &self.model_info {
            material.extend(info.checksum().as_bytes());
        }
        if let Some(tok) = &self.tokenizer {
            material.extend(tok.checksum().as_bytes());
        }
        if let Some(formatter) = &self.prompt_formatter {
            material.extend(formatter.checksum().as_bytes());
        }
        if let Some(template) = &self.chat_template_file {
            material.extend(template.checksum().as_bytes());
        }
        if let Some(generation) = &self.gen_config {
            material.extend(generation.checksum().as_bytes());
        }
        if let Some(prompt_context_vec) = &self.prompt_context {
            // Hash the debug rendering. It is a Vec of enum, so this should be fine; the
            // debug representation can only change with a new release.
            material.extend(format!("{prompt_context_vec:?}").as_bytes());
        }
        material.extend(self.context_length.to_be_bytes());
        material.extend(self.kv_cache_block_size.to_be_bytes());

        // TODO: Do we want any of user_data or runtime_config?

        blake3::hash(&material).to_string()
    });
    cached.as_str()
}
/// Is this a full model card with tokenizer?
......@@ -291,9 +373,7 @@ impl ModelDeploymentCard {
/// Move the files this MDC uses from the NATS object store to local disk.
/// Updates the URIs to point to the created files.
///
/// The returned TempDir must be kept alive, it cleans up on drop.
async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
......@@ -345,7 +425,9 @@ impl ModelDeploymentCard {
"tokenizer.json"
);
Ok(target_dir)
// This cache_dir is a tempfile::TempDir, which will be deleted on drop, so keep it alive.
self.cache_dir = Some(Arc::new(target_dir));
Ok(())
}
/// Delete this card from the key-value store and its URLs from the object store
......@@ -411,8 +493,7 @@ impl ModelDeploymentCard {
else {
return Ok(None);
};
// This cache_dir is a tempfile::TempDir, which will be deleted on drop, so keep it alive.
card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
card.move_from_nats(drt.nats_client()).await?;
Ok(Some(card))
}
......@@ -487,6 +568,7 @@ impl ModelDeploymentCard {
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
checksum: OnceLock::new(),
})
}
......@@ -551,6 +633,7 @@ impl ModelDeploymentCard {
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
checksum: OnceLock::new(),
})
}
}
......
......@@ -74,6 +74,25 @@ impl ModelType {
result
}
/// Decompose the bitflag into its component units:
/// Chat | Completion -> [Chat, Completion]
///
/// The result lists only the units this value supports, always in the order
/// Chat, Completions, Embedding, TensorBased.
pub fn units(&self) -> Vec<ModelType> {
    let candidates = [
        (self.supports_chat(), ModelType::Chat),
        (self.supports_completions(), ModelType::Completions),
        (self.supports_embedding(), ModelType::Embedding),
        (self.supports_tensor(), ModelType::TensorBased),
    ];

    let mut supported_units = Vec::new();
    for (is_supported, unit) in candidates {
        if is_supported {
            supported_units.push(unit);
        }
    }
    supported_units
}
/// Returns all endpoint types supported by this model type.
/// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
......
......@@ -124,7 +124,7 @@ impl OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else {
anyhow::bail!(
......
......@@ -4,9 +4,20 @@
use anyhow::Error;
use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::http::{
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
......@@ -14,15 +25,8 @@ use dynamo_llm::http::{
metrics::{Endpoint, RequestType, Status},
service_v2::HttpService,
},
};
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
model_card::ModelDeploymentCard,
};
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{
......@@ -275,15 +279,18 @@ async fn test_http_service() {
let registry = Registry::new();
// TODO: Shouldn't this test know the card before it registers a model?
let card = ModelDeploymentCard::with_name_only("foo");
let counter = Arc::new(CounterEngine {});
let result = manager.add_chat_completions_model("foo", counter);
let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
assert!(result.is_ok());
let failure = Arc::new(AlwaysFailEngine {});
let result = manager.add_chat_completions_model("bar", failure.clone());
let card = ModelDeploymentCard::with_name_only("bar");
let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
assert!(result.is_ok());
let result = manager.add_completions_model("bar", failure);
let result = manager.add_completions_model("bar", card.mdcsum(), failure);
assert!(result.is_ok());
let metrics = state.metrics_clone();
......@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager
.add_chat_completions_model("foo", counter.clone())
.add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager
.add_chat_completions_model("bar", failure.clone())
.add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("bar", failure.clone())
.add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
(service, counter, failure, port)
......@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() {
wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-model", long_running_engine.clone())
.add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
......@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() {
wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-stream-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-stream-model", long_running_engine.clone())
.add_chat_completions_model(
"slow-stream-model",
card.mdcsum(),
long_running_engine.clone(),
)
.unwrap();
let client = reqwest::Client::new();
......@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() {
wait_for_service_ready(port).await;
// Add a counter engine for this test
let card = ModelDeploymentCard::with_name_only("test-model");
let counter_engine = Arc::new(CounterEngine {});
manager
.add_chat_completions_model("test-model", counter_engine)
.add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
.unwrap();
// Create reqwest client directly
......
......@@ -4,8 +4,8 @@
use anyhow::Error;
use async_stream::stream;
use dynamo_llm::{
http::service::metrics::Endpoint,
http::service::service_v2::HttpService,
http::service::{metrics::Endpoint, service_v2::HttpService},
model_card::ModelDeploymentCard,
protocols::{
Annotated,
openai::chat_completions::{
......@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() {
let task = tokio::spawn(async move { service.run(token.clone()).await });
// Add mock model engine
let card = ModelDeploymentCard::with_name_only("mockmodel");
let mock_engine = Arc::new(MockModelEngine {});
manager
.add_chat_completions_model("mockmodel", mock_engine)
.add_chat_completions_model("mockmodel", card.mdcsum(), mock_engine)
.unwrap();
// Wait for service to be ready
......@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() {
mod integration_tests {
use super::*;
use dynamo_llm::{
discovery::{MODEL_ROOT_PATH, ModelEntry, ModelWatcher},
engines::make_echo_engine,
entrypoint::EngineConfig,
local_model::LocalModelBuilder,
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder, model_card,
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
......@@ -348,7 +347,7 @@ mod integration_tests {
// Start watching etcd for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() {
let models_watcher = etcd_client
.kv_get_and_watch_prefix(MODEL_ROOT_PATH)
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await
.unwrap();
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......@@ -364,10 +363,11 @@ mod integration_tests {
panic!("Expected StaticFull config");
};
let card = local_model.card().clone();
let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine));
let manager = service.model_manager();
manager
.add_chat_completions_model(model.service_name(), engine.clone())
.add_chat_completions_model(model.service_name(), card.mdcsum(), engine.clone())
.unwrap();
// Now do the proper MDC registration via LocalModel::attach()
......@@ -376,7 +376,7 @@ mod integration_tests {
let test_component = namespace.component("test-mdc-component").unwrap();
let test_endpoint = test_component.endpoint("test-mdc-endpoint");
// This will store the MDC in etcd and create the ModelEntry for discovery
// This will store the MDC in etcd for discovery
local_model
.attach(
&test_endpoint,
......@@ -388,8 +388,7 @@ mod integration_tests {
// Manually save the model card and update metrics
// This simulates what the ModelWatcher polling task would do in production
let card = local_model.card().clone();
manager.save_model_card("test-mdc-key", card.clone());
let _ = manager.save_model_card("test-mdc-key", card.clone());
if let Err(e) = service
.state()
......@@ -500,8 +499,13 @@ mod integration_tests {
assert!(metrics_body.contains("request_type=\"stream\""));
assert!(metrics_body.contains("status=\"success\""));
// etcd lease will ensure everything is deleted from etcd
// Now test the complete lifecycle: remove the model from etcd
// We don't need to cleanup model manager because it's local to this test
/*
// Clean up
// Remove the model using the cleaner ModelWatcher approach
if let Some(etcd_client) = distributed_runtime.etcd_client() {
// Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern)
......@@ -514,10 +518,7 @@ mod integration_tests {
);
// Get all model entries for our test model
let model_entries = watcher
.entries_for_model("test-mdc-model", None, true)
.await
.unwrap();
let model_entries = watcher.entries_for_model("test-mdc-model").await.unwrap();
if !model_entries.is_empty() {
// For each model entry, we need to find its etcd key and remove it
......@@ -566,8 +567,8 @@ mod integration_tests {
}
}
}
*/
// Clean up
cancel_token.cancel();
task.await.unwrap().unwrap();
}
......
......@@ -280,31 +280,32 @@ pub mod kserve_test {
let failure = Arc::new(AlwaysFailEngine {});
let long_running = Arc::new(LongRunningEngine::new(1_000));
manager
.add_completions_model("split", split.clone())
.unwrap();
let mut card = ModelDeploymentCard::with_name_only("split");
card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text;
manager.save_model_card("split", card);
manager
.add_chat_completions_model("failure", failure.clone())
.unwrap();
manager
.add_completions_model("failure", failure.clone())
.add_completions_model("split", card.mdcsum(), split.clone())
.unwrap();
let _ = manager.save_model_card("split", card.clone());
let mut card = ModelDeploymentCard::with_name_only("failure");
card.model_type = ModelType::Completions | ModelType::Chat;
card.model_input = ModelInput::Text;
manager.save_model_card("failure", card);
manager
.add_completions_model("long_running", long_running.clone())
.add_chat_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
let _ = manager.save_model_card("failure", card);
let mut card = ModelDeploymentCard::with_name_only("long_running");
card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text;
manager.save_model_card("long_running", card);
manager
.add_completions_model("long_running", card.mdcsum(), long_running.clone())
.unwrap();
let _ = manager.save_model_card("long_running", card);
(service, split, failure, long_running)
}
......@@ -1130,11 +1131,16 @@ pub mod kserve_test {
text_input: inference::model_infer_request::InferInputTensor,
) {
// add tensor model
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", tensor.clone())
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
// start server
......@@ -1147,11 +1153,7 @@ pub mod kserve_test {
version: "".into(),
});
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
service_with_engines
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
......@@ -1217,7 +1219,7 @@ pub mod kserve_test {
}),
..Default::default()
};
service_with_engines
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment