Unverified Commit 81162dfe authored by Graham King's avatar Graham King Committed by GitHub
Browse files

chore(discovery): Watch/publish ModelDeploymentCard instead of ModelEntry (#3350)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent ddbb4f50
......@@ -69,7 +69,7 @@ async def worker(runtime: DistributedRuntime):
host: str = "localhost"
port: int = 8000
service: HttpService = HttpService(port=port)
service.add_chat_completions_model(served_model_name, engine)
service.add_chat_completions_model(served_model_name, "mdcsum", engine)
print("Starting service...")
shutdown_signal = service.run(runtime.child_token())
......
......@@ -30,23 +30,29 @@ impl HttpService {
Ok(Self { inner })
}
pub fn add_completions_model(&self, model: String, engine: HttpAsyncEngine) -> PyResult<()> {
pub fn add_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_completions_model(&model, engine)
.add_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}
pub fn add_chat_completions_model(
&self,
model: String,
checksum: String,
engine: HttpAsyncEngine,
) -> PyResult<()> {
let engine = Arc::new(engine);
self.inner
.model_manager()
.add_chat_completions_model(&model, engine)
.add_chat_completions_model(&model, &checksum, engine)
.map_err(to_pyerr)
}
......
......@@ -85,6 +85,7 @@ async def http_server(runtime: DistributedRuntime):
model_name = "test_model"
start_done = asyncio.Event()
child_token = runtime.child_token()
checksum = "abc123" # Checksum of ModelDeplomentCard for that model
async def worker():
"""The server worker task."""
......@@ -94,7 +95,7 @@ async def http_server(runtime: DistributedRuntime):
engine = HttpAsyncEngine(python_engine.generate, loop)
service = HttpService(port=port)
service.add_chat_completions_model(model_name, engine)
service.add_chat_completions_model(model_name, checksum, engine)
service.enable_endpoint("chat", True)
shutdown_signal = service.run(child_token)
......
......@@ -33,7 +33,7 @@ pub struct Checksum {
algorithm: CryptographicHashMethods,
}
#[derive(Serialize, Deserialize, Debug, Clone, Eq, PartialEq)]
#[derive(Serialize, Deserialize, Debug, Clone, Copy, Eq, PartialEq)]
pub enum CryptographicHashMethods {
#[serde(rename = "blake3")]
BLAKE3,
......@@ -259,6 +259,15 @@ impl TryFrom<&str> for Checksum {
}
}
impl Default for Checksum {
fn default() -> Self {
Self {
hash: "".to_string(),
algorithm: CryptographicHashMethods::BLAKE3,
}
}
}
impl FromStr for CryptographicHashMethods {
type Err = String;
......
......@@ -4,14 +4,8 @@
mod model_manager;
pub use model_manager::{ModelManager, ModelManagerError};
mod model_entry;
pub use model_entry::ModelEntry;
mod watcher;
pub use watcher::{ModelUpdate, ModelWatcher};
/// The root etcd path for ModelEntry
pub const MODEL_ROOT_PATH: &str = "models";
/// The root etcd path for KV Router registrations
pub const KV_ROUTERS_ROOT_PATH: &str = "kv_routers";
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use dynamo_runtime::{protocols, slug::Slug};
use serde::{Deserialize, Serialize};
use crate::local_model::runtime_config::ModelRuntimeConfig;
/// [ModelEntry] contains the information to discover models
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry {
/// Public name of the model
/// Used to identify the model in the HTTP service from the value used in an OpenAI ChatRequest.
pub name: String,
/// How to address this on the network
#[serde(rename = "endpoint")]
pub endpoint_id: protocols::EndpointId,
/// Runtime configuration specific to this model instance
#[serde(default, skip_serializing_if = "Option::is_none")]
pub runtime_config: Option<ModelRuntimeConfig>,
}
impl ModelEntry {
/// Slugified display name for use in network storage, or URL-safe environments
pub fn slug(&self) -> Slug {
Slug::from_string(&self.name)
}
}
......@@ -11,7 +11,6 @@ use parking_lot::{Mutex, RwLock};
use dynamo_runtime::component::Component;
use dynamo_runtime::prelude::DistributedRuntimeProvider;
use crate::kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector};
use crate::{discovery::KV_ROUTERS_ROOT_PATH, model_card::ModelDeploymentCard};
use crate::{
kv_router::KvRouter,
......@@ -21,6 +20,10 @@ use crate::{
completions::OpenAICompletionsStreamingEngine, embeddings::OpenAIEmbeddingsStreamingEngine,
},
};
use crate::{
kv_router::{KvRouterConfig, scheduler::DefaultWorkerSelector},
model_type::ModelType,
};
#[derive(Debug, thiserror::Error)]
pub enum ModelManagerError {
......@@ -39,7 +42,7 @@ pub struct ModelManager {
embeddings_engines: RwLock<ModelEngines<OpenAIEmbeddingsStreamingEngine>>,
tensor_engines: RwLock<ModelEngines<TensorStreamingEngine>>,
// These two are Mutex because we read and write rarely and equally
// These are Mutex because we read and write rarely and equally
cards: Mutex<HashMap<String, ModelDeploymentCard>>,
kv_choosers: Mutex<HashMap<String, Arc<KvRouter>>>,
}
......@@ -62,6 +65,43 @@ impl ModelManager {
}
}
pub fn is_valid_checksum(
&self,
model_type: ModelType,
model_name: &str,
candidate_checksum: &str,
) -> Option<bool> {
let mut results = vec![];
for unit in model_type.units() {
let maybe_valid_checksum = match unit {
ModelType::Chat => self.chat_completion_engines.read().checksum(model_name),
ModelType::Completions => self.completion_engines.read().checksum(model_name),
ModelType::Embedding => self.embeddings_engines.read().checksum(model_name),
ModelType::TensorBased => self.tensor_engines.read().checksum(model_name),
_ => {
continue;
}
};
if let Some(is_valid) = maybe_valid_checksum.map(|valid_checksum| {
tracing::debug!(
model_name,
valid_checksum,
candidate_checksum,
"is_valid_checksum: check case"
);
valid_checksum == candidate_checksum
}) {
results.push(is_valid)
}
}
if results.is_empty() {
None
} else {
// The checksum is valid if it is correct for all the ModelType in the bitflag.
Some(results.into_iter().all(|x| x))
}
}
pub fn get_model_cards(&self) -> Vec<ModelDeploymentCard> {
self.cards.lock().values().cloned().collect()
}
......@@ -99,37 +139,41 @@ impl ModelManager {
pub fn add_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAICompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_chat_completions_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIChatCompletionsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.chat_completion_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_embeddings_model(
&self,
model: &str,
card_checksum: &str,
engine: OpenAIEmbeddingsStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.embeddings_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn add_tensor_model(
&self,
model: &str,
card_checksum: &str,
engine: TensorStreamingEngine,
) -> Result<(), ModelManagerError> {
let mut clients = self.tensor_engines.write();
clients.add(model, engine)
clients.add(model, card_checksum, engine)
}
pub fn remove_completions_model(&self, model: &str) -> Result<(), ModelManagerError> {
......@@ -196,10 +240,11 @@ impl ModelManager {
.ok_or(ModelManagerError::ModelNotFound(model.to_string()))
}
/// Save a ModelDeploymentCard from an instance's etcd `models/` key so we can fetch it later when the key is
/// deleted from etcd.
pub fn save_model_card(&self, key: &str, entry: ModelDeploymentCard) {
self.cards.lock().insert(key.to_string(), entry);
/// Save a ModelDeploymentCard from an instance's ModelDeploymentCard key so we can fetch it later when the key is
/// deleted.
pub fn save_model_card(&self, key: &str, card: ModelDeploymentCard) -> anyhow::Result<()> {
self.cards.lock().insert(key.to_string(), card);
Ok(())
}
/// Remove and return model card for this instance's etcd key. We do this when the instance stops.
......@@ -291,6 +336,9 @@ pub struct ModelEngines<E> {
/// Optional default model name
default: Option<String>,
engines: HashMap<String, E>,
/// Key: Model name, value: Checksum of the ModelDeploymentCard. New instances must have the
/// same card.
checksums: HashMap<String, String>,
}
impl<E> Default for ModelEngines<E> {
......@@ -298,6 +346,7 @@ impl<E> Default for ModelEngines<E> {
Self {
default: None,
engines: HashMap::new(),
checksums: HashMap::new(),
}
}
}
......@@ -313,11 +362,13 @@ impl<E> ModelEngines<E> {
self.default = None;
}
fn add(&mut self, model: &str, engine: E) -> Result<(), ModelManagerError> {
fn add(&mut self, model: &str, checksum: &str, engine: E) -> Result<(), ModelManagerError> {
if self.engines.contains_key(model) {
return Err(ModelManagerError::ModelAlreadyExists(model.to_string()));
}
self.engines.insert(model.to_string(), engine);
self.checksums
.insert(model.to_string(), checksum.to_string());
Ok(())
}
......@@ -325,6 +376,7 @@ impl<E> ModelEngines<E> {
if self.engines.remove(model).is_none() {
return Err(ModelManagerError::ModelNotFound(model.to_string()));
}
let _ = self.checksums.remove(model);
Ok(())
}
......@@ -339,4 +391,10 @@ impl<E> ModelEngines<E> {
pub fn list(&self) -> Vec<String> {
self.engines.keys().map(|k| k.to_owned()).collect()
}
/// Returns a newly allocated String for called convenience. All the places I use
/// this I need a String.
pub fn checksum(&self, model: &str) -> Option<String> {
self.checksums.get(model).map(|s| s.to_string())
}
}
This diff is collapsed.
......@@ -5,12 +5,12 @@ use std::pin::Pin;
use crate::{
backend::{Backend, ExecutionContext},
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig},
kv_router::{KvPushRouter, KvRouter},
migration::Migration,
model_card::ModelDeploymentCard,
model_card::{self, ModelDeploymentCard},
preprocessor::{OpenAIPreprocessor, prompt::PromptFormatter},
protocols::common::llm_backend::{BackendOutput, LLMEngineOutput, PreprocessedRequest},
request_template::RequestTemplate,
......@@ -73,7 +73,9 @@ pub async fn prepare_engine(
None,
None,
));
let models_watcher = etcd_client.kv_get_and_watch_prefix(MODEL_ROOT_PATH).await?;
let models_watcher = etcd_client
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await?;
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
let inner_watch_obj = watch_obj.clone();
......
......@@ -4,11 +4,12 @@
use std::sync::Arc;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelWatcher},
discovery::{ModelManager, ModelWatcher},
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
grpc::service::kserve,
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -46,7 +47,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime,
grpc_service.state().manager_clone(),
etcd_client.clone(),
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -62,6 +63,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
......@@ -103,7 +105,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
......@@ -111,7 +117,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
grpc_service
}
......@@ -119,8 +129,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let grpc_service = grpc_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = grpc_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;
let checksum = model.card().mdcsum();
manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
grpc_service
}
EngineConfig::StaticCore {
......@@ -130,6 +141,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => {
let grpc_service = grpc_service_builder.build()?;
let manager = grpc_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
......@@ -138,14 +150,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
grpc_service
}
};
......
......@@ -4,12 +4,13 @@
use std::sync::Arc;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelManager, ModelUpdate, ModelWatcher},
discovery::{ModelManager, ModelUpdate, ModelWatcher},
endpoint_type::EndpointType,
engines::StreamingEngineAdapter,
entrypoint::{self, EngineConfig, input::common},
http::service::service_v2::{self, HttpService},
kv_router::KvRouterConfig,
model_card,
namespace::is_global_namespace,
types::openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
......@@ -74,7 +75,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
distributed_runtime,
http_service.state().manager_clone(),
etcd_client.clone(),
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
router_config.router_mode,
Some(router_config.kv_router_config),
router_config.busy_threshold,
......@@ -92,6 +93,8 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
}
EngineConfig::StaticRemote(local_model) => {
let card = local_model.card();
let checksum = card.mdcsum();
let router_mode = local_model.router_config().router_mode;
let dst_config = DistributedConfig::from_settings(true); // true means static
......@@ -133,7 +136,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
tokenizer_hf.clone(),
)
.await?;
manager.add_chat_completions_model(local_model.display_name(), chat_engine)?;
manager.add_chat_completions_model(
local_model.display_name(),
checksum,
chat_engine,
)?;
let completions_engine =
entrypoint::build_routed_pipeline::<
......@@ -141,7 +148,11 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateCompletionResponse,
>(card, &client, router_mode, None, kv_chooser, tokenizer_hf)
.await?;
manager.add_completions_model(local_model.display_name(), completions_engine)?;
manager.add_completions_model(
local_model.display_name(),
checksum,
completions_engine,
)?;
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
......@@ -153,8 +164,9 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
let http_service = http_service_builder.build()?;
let engine = Arc::new(StreamingEngineAdapter::new(engine));
let manager = http_service.model_manager();
manager.add_completions_model(model.service_name(), engine.clone())?;
manager.add_chat_completions_model(model.service_name(), engine)?;
let checksum = model.card().mdcsum();
manager.add_completions_model(model.service_name(), checksum, engine.clone())?;
manager.add_chat_completions_model(model.service_name(), checksum, engine)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
......@@ -169,6 +181,7 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
} => {
let http_service = http_service_builder.build()?;
let manager = http_service.model_manager();
let checksum = model.card().mdcsum();
let tokenizer_hf = model.card().tokenizer_hf()?;
let chat_pipeline =
......@@ -177,14 +190,14 @@ pub async fn run(runtime: Runtime, engine_config: EngineConfig) -> anyhow::Resul
NvCreateChatCompletionStreamResponse,
>(model.card(), inner_engine.clone(), tokenizer_hf.clone())
.await?;
manager.add_chat_completions_model(model.service_name(), chat_pipeline)?;
manager.add_chat_completions_model(model.service_name(), checksum, chat_pipeline)?;
let cmpl_pipeline = common::build_pipeline::<
NvCreateCompletionRequest,
NvCreateCompletionResponse,
>(model.card(), inner_engine, tokenizer_hf)
.await?;
manager.add_completions_model(model.service_name(), cmpl_pipeline)?;
manager.add_completions_model(model.service_name(), checksum, cmpl_pipeline)?;
// Enable all endpoints
for endpoint_type in EndpointType::all() {
http_service.enable_model_endpoint(endpoint_type, true);
......
......@@ -32,7 +32,6 @@ pub mod sequence;
pub mod subscriber;
use crate::{
discovery::{MODEL_ROOT_PATH, ModelEntry},
kv_router::{
approx::ApproxKvIndexer,
indexer::{
......@@ -45,6 +44,7 @@ use crate::{
subscriber::start_kv_router_background,
},
local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
preprocessor::PreprocessedRequest,
protocols::common::llm_backend::LLMEngineOutput,
};
......@@ -247,9 +247,9 @@ impl KvRouter {
let runtime_configs_watcher = watch_prefix_with_extraction(
etcd_client,
MODEL_ROOT_PATH,
model_card::ROOT_PATH,
key_extractors::lease_id,
|model_entry: ModelEntry| model_entry.runtime_config,
|card: ModelDeploymentCard| Some(card.runtime_config),
cancellation_token.clone(),
)
.await?;
......
......@@ -15,15 +15,12 @@ use dynamo_runtime::{
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
};
use crate::discovery::ModelEntry;
use crate::entrypoint::RouterConfig;
use crate::mocker::protocols::MockEngineArgs;
use crate::model_card::{self, ModelDeploymentCard};
use crate::model_type::{ModelInput, ModelType};
use crate::request_template::RequestTemplate;
mod network_name;
pub use network_name::ModelNetworkName;
pub mod runtime_config;
use runtime_config::ModelRuntimeConfig;
......@@ -421,36 +418,13 @@ impl LocalModel {
// Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
let key = self.card.slug().to_string();
// TODO: Next PR will use this
//let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
//let key = Key::from_raw(endpoint.unique_path(lease_id));
card_store
.publish(
model_card::ROOT_PATH,
None,
&Key::from_raw(key),
&mut self.card,
)
.await?;
let lease_id = endpoint.drt().primary_lease().map(|l| l.id()).unwrap_or(0);
let key = Key::from_raw(endpoint.unique_path(lease_id));
// Publish our ModelEntry to etcd. This allows ingress to find the model card.
// (Why don't we put the model card directly under this key?)
let network_name = ModelNetworkName::new();
tracing::debug!("Registering with etcd as {network_name}");
let model_registration = ModelEntry {
name: self.display_name().to_string(),
endpoint_id: endpoint.id(),
runtime_config: Some(self.runtime_config.clone()),
};
etcd_client
.kv_create(
&network_name,
serde_json::to_vec_pretty(&model_registration)?,
None, // use primary lease
)
.await
let _outcome = card_store
.publish(model_card::ROOT_PATH, None, &key, &mut self.card)
.await?;
Ok(())
}
}
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use crate::discovery::MODEL_ROOT_PATH;
#[derive(Debug, Clone)]
pub struct ModelNetworkName(String);
impl ModelNetworkName {
pub fn new() -> Self {
ModelNetworkName(format!("{MODEL_ROOT_PATH}/{}", uuid::Uuid::new_v4()))
}
}
impl Default for ModelNetworkName {
fn default() -> Self {
Self::new()
}
}
impl std::fmt::Display for ModelNetworkName {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", self.0)
}
}
impl AsRef<str> for ModelNetworkName {
fn as_ref(&self) -> &str {
&self.0
}
}
impl std::ops::Deref for ModelNetworkName {
type Target = str;
fn deref(&self) -> &Self::Target {
&self.0
}
}
......@@ -15,9 +15,9 @@
use std::fmt;
use std::fs::File;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::sync::{Arc, OnceLock};
use crate::common::checked_file::CheckedFile;
use crate::common::checked_file::{CheckedFile, Checksum};
use crate::local_model::runtime_config::ModelRuntimeConfig;
use crate::model_type::{ModelInput, ModelType};
use anyhow::{Context, Result};
......@@ -43,6 +43,15 @@ pub enum ModelInfoType {
GGUF(PathBuf),
}
impl ModelInfoType {
pub fn checksum(&self) -> String {
match self {
ModelInfoType::HfConfigJson(c) => c.checksum().to_string(),
ModelInfoType::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum TokenizerKind {
......@@ -50,6 +59,15 @@ pub enum TokenizerKind {
GGUF(Box<HfTokenizer>),
}
impl TokenizerKind {
pub fn checksum(&self) -> String {
match self {
TokenizerKind::HfTokenizerJson(c) => c.checksum().to_string(),
TokenizerKind::GGUF(_) => Checksum::default().to_string(),
}
}
}
/// Supported types of prompt formatters.
///
/// We need a way to associate the prompt formatter template definition with an associated
......@@ -70,6 +88,16 @@ pub enum PromptFormatterArtifact {
GGUF(PathBuf),
}
impl PromptFormatterArtifact {
pub fn checksum(&self) -> String {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplate(c) => c.checksum().to_string(),
PromptFormatterArtifact::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum PromptContextMixin {
......@@ -87,6 +115,15 @@ pub enum GenerationConfig {
GGUF(PathBuf),
}
impl GenerationConfig {
pub fn checksum(&self) -> String {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.checksum().to_string(),
GenerationConfig::GGUF(_) => Checksum::default().to_string(),
}
}
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
......@@ -145,6 +182,9 @@ pub struct ModelDeploymentCard {
#[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>,
#[serde(skip, default)]
checksum: OnceLock<String>,
}
impl ModelDeploymentCard {
......@@ -189,6 +229,12 @@ impl ModelDeploymentCard {
Ok(())
}
#[inline]
pub fn name(&self) -> &str {
&self.display_name
}
#[inline]
pub fn slug(&self) -> &Slug {
&self.slug
}
......@@ -198,9 +244,45 @@ impl ModelDeploymentCard {
Ok(serde_json::to_string(self)?)
}
pub fn mdcsum(&self) -> String {
let json = self.to_json().unwrap();
format!("{}", blake3::hash(json.as_bytes()))
pub fn mdcsum(&self) -> &str {
self.checksum
.get_or_init(|| {
// Only include the important fields
let mut bytes_to_hash: Vec<u8> = Vec::with_capacity(512);
bytes_to_hash.extend(self.display_name.as_bytes());
// The files can be either a URL or a local path, so we ignore that and hash their
// checksum instead, which won't change wherever they are.
if let Some(model_info) = self.model_info.as_ref() {
bytes_to_hash.extend(model_info.checksum().as_bytes());
}
if let Some(tokenizer) = self.tokenizer.as_ref() {
bytes_to_hash.extend(tokenizer.checksum().as_bytes());
}
if let Some(prompt_formatter) = self.prompt_formatter.as_ref() {
bytes_to_hash.extend(prompt_formatter.checksum().as_bytes());
}
if let Some(chat_template) = self.chat_template_file.as_ref() {
bytes_to_hash.extend(chat_template.checksum().as_bytes());
}
if let Some(gen_config) = self.gen_config.as_ref() {
bytes_to_hash.extend(gen_config.checksum().as_bytes());
}
if let Some(prompt_context_vec) = self.prompt_context.as_ref() {
// Paste it as the bytes of the debug format. It's a Vec of enum, so this should be
// fine. If the debug representation changes that only happens in a new release.
bytes_to_hash.extend(format!("{prompt_context_vec:?}").as_bytes());
}
bytes_to_hash.extend(self.context_length.to_be_bytes());
bytes_to_hash.extend(self.kv_cache_block_size.to_be_bytes());
// TODO: Do we want any of user_data or runtime_config?
blake3::hash(&bytes_to_hash).to_string()
})
.as_ref()
}
/// Is this a full model card with tokenizer?
......@@ -291,9 +373,7 @@ impl ModelDeploymentCard {
/// Move the files this MDC uses from the NATS object store to local disk.
/// Updates the URI's to point to the created files.
///
/// The returned TempDir must be kept alive, it cleans up on drop.
async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
......@@ -345,7 +425,9 @@ impl ModelDeploymentCard {
"tokenizer.json"
);
Ok(target_dir)
// This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive.
self.cache_dir = Some(Arc::new(target_dir));
Ok(())
}
/// Delete this card from the key-value store and it's URLs from the object store
......@@ -411,8 +493,7 @@ impl ModelDeploymentCard {
else {
return Ok(None);
};
// This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive.
card.cache_dir = Some(Arc::new(card.move_from_nats(drt.nats_client()).await?));
card.move_from_nats(drt.nats_client()).await?;
Ok(Some(card))
}
......@@ -487,6 +568,7 @@ impl ModelDeploymentCard {
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
checksum: OnceLock::new(),
})
}
......@@ -551,6 +633,7 @@ impl ModelDeploymentCard {
user_data: None,
runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
checksum: OnceLock::new(),
})
}
}
......
......@@ -74,6 +74,25 @@ impl ModelType {
result
}
/// Decompose the bitflag into it's component units:
/// Chat | Completion -> [Chat, Completion]
pub fn units(&self) -> Vec<ModelType> {
let mut result = Vec::new();
if self.supports_chat() {
result.push(ModelType::Chat);
}
if self.supports_completions() {
result.push(ModelType::Completions);
}
if self.supports_embedding() {
result.push(ModelType::Embedding);
}
if self.supports_tensor() {
result.push(ModelType::TensorBased);
}
result
}
/// Returns all endpoint types supported by this model type.
/// This properly handles combinations like Chat | Completions.
pub fn as_endpoint_types(&self) -> Vec<crate::endpoint_type::EndpointType> {
......
......@@ -124,7 +124,7 @@ impl OpenAIPreprocessor {
formatter: Arc<dyn OAIPromptFormatter>,
hf_tokenizer: tokenizers::Tokenizer,
) -> Result<Arc<Self>> {
let mdcsum = mdc.mdcsum();
let mdcsum = mdc.mdcsum().to_string();
let tokenizer = Arc::new(HuggingFaceTokenizer::from_tokenizer(hf_tokenizer));
let Some(model_info) = mdc.model_info else {
anyhow::bail!(
......
......@@ -4,9 +4,20 @@
use anyhow::Error;
use async_stream::stream;
use dynamo_async_openai::config::OpenAIConfig;
use dynamo_llm::http::{
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
};
use dynamo_llm::{
http::{
client::{
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient, PureOpenAIClient,
GenericBYOTClient, HttpClientConfig, HttpRequestContext, NvCustomClient,
PureOpenAIClient,
},
service::{
Metrics,
......@@ -14,15 +25,8 @@ use dynamo_llm::http::{
metrics::{Endpoint, RequestType, Status},
service_v2::HttpService,
},
};
use dynamo_llm::protocols::{
Annotated,
codec::SseLineCodec,
convert_sse_stream,
openai::{
chat_completions::{NvCreateChatCompletionRequest, NvCreateChatCompletionStreamResponse},
completions::{NvCreateCompletionRequest, NvCreateCompletionResponse},
},
model_card::ModelDeploymentCard,
};
use dynamo_runtime::metrics::prometheus_names::{frontend_service, name_prefix};
use dynamo_runtime::{
......@@ -275,15 +279,18 @@ async fn test_http_service() {
let registry = Registry::new();
// TODO: Shouldn't this test know the card before it registers a model?
let card = ModelDeploymentCard::with_name_only("foo");
let counter = Arc::new(CounterEngine {});
let result = manager.add_chat_completions_model("foo", counter);
let result = manager.add_chat_completions_model("foo", card.mdcsum(), counter);
assert!(result.is_ok());
let failure = Arc::new(AlwaysFailEngine {});
let result = manager.add_chat_completions_model("bar", failure.clone());
let card = ModelDeploymentCard::with_name_only("bar");
let result = manager.add_chat_completions_model("bar", card.mdcsum(), failure.clone());
assert!(result.is_ok());
let result = manager.add_completions_model("bar", failure);
let result = manager.add_completions_model("bar", card.mdcsum(), failure);
assert!(result.is_ok());
let metrics = state.metrics_clone();
......@@ -578,14 +585,16 @@ async fn service_with_engines() -> (HttpService, Arc<CounterEngine>, Arc<AlwaysF
let counter = Arc::new(CounterEngine {});
let failure = Arc::new(AlwaysFailEngine {});
let card = ModelDeploymentCard::with_name_only("foo");
manager
.add_chat_completions_model("foo", counter.clone())
.add_chat_completions_model("foo", card.mdcsum(), counter.clone())
.unwrap();
let card = ModelDeploymentCard::with_name_only("bar");
manager
.add_chat_completions_model("bar", failure.clone())
.add_chat_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("bar", failure.clone())
.add_completions_model("bar", card.mdcsum(), failure.clone())
.unwrap();
(service, counter, failure, port)
......@@ -977,9 +986,10 @@ async fn test_client_disconnect_cancellation_unary() {
wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-model", long_running_engine.clone())
.add_chat_completions_model("slow-model", card.mdcsum(), long_running_engine.clone())
.unwrap();
let client = reqwest::Client::new();
......@@ -1068,9 +1078,14 @@ async fn test_client_disconnect_cancellation_streaming() {
wait_for_service_ready(port).await;
// Create a long-running engine (10 seconds)
let card = ModelDeploymentCard::with_name_only("slow-stream-model");
let long_running_engine = Arc::new(LongRunningEngine::new(10_000));
manager
.add_chat_completions_model("slow-stream-model", long_running_engine.clone())
.add_chat_completions_model(
"slow-stream-model",
card.mdcsum(),
long_running_engine.clone(),
)
.unwrap();
let client = reqwest::Client::new();
......@@ -1166,9 +1181,10 @@ async fn test_request_id_annotation() {
wait_for_service_ready(port).await;
// Add a counter engine for this test
let card = ModelDeploymentCard::with_name_only("test-model");
let counter_engine = Arc::new(CounterEngine {});
manager
.add_chat_completions_model("test-model", counter_engine)
.add_chat_completions_model("test-model", card.mdcsum(), counter_engine)
.unwrap();
// Create reqwest client directly
......
......@@ -4,8 +4,8 @@
use anyhow::Error;
use async_stream::stream;
use dynamo_llm::{
http::service::metrics::Endpoint,
http::service::service_v2::HttpService,
http::service::{metrics::Endpoint, service_v2::HttpService},
model_card::ModelDeploymentCard,
protocols::{
Annotated,
openai::chat_completions::{
......@@ -206,9 +206,10 @@ async fn test_metrics_with_mock_model() {
let task = tokio::spawn(async move { service.run(token.clone()).await });
// Add mock model engine
let card = ModelDeploymentCard::with_name_only("mockmodel");
let mock_engine = Arc::new(MockModelEngine {});
manager
.add_chat_completions_model("mockmodel", mock_engine)
.add_chat_completions_model("mockmodel", card.mdcsum(), mock_engine)
.unwrap();
// Wait for service to be ready
......@@ -293,10 +294,8 @@ async fn test_metrics_with_mock_model() {
mod integration_tests {
use super::*;
use dynamo_llm::{
discovery::{MODEL_ROOT_PATH, ModelEntry, ModelWatcher},
engines::make_echo_engine,
entrypoint::EngineConfig,
local_model::LocalModelBuilder,
discovery::ModelWatcher, engines::make_echo_engine, entrypoint::EngineConfig,
local_model::LocalModelBuilder, model_card,
};
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::pipeline::RouterMode;
......@@ -348,7 +347,7 @@ mod integration_tests {
// Start watching etcd for model registrations
if let Some(etcd_client) = distributed_runtime.etcd_client() {
let models_watcher = etcd_client
.kv_get_and_watch_prefix(MODEL_ROOT_PATH)
.kv_get_and_watch_prefix(model_card::ROOT_PATH)
.await
.unwrap();
let (_prefix, _watcher, receiver) = models_watcher.dissolve();
......@@ -364,10 +363,11 @@ mod integration_tests {
panic!("Expected StaticFull config");
};
let card = local_model.card().clone();
let engine = Arc::new(dynamo_llm::engines::StreamingEngineAdapter::new(engine));
let manager = service.model_manager();
manager
.add_chat_completions_model(model.service_name(), engine.clone())
.add_chat_completions_model(model.service_name(), card.mdcsum(), engine.clone())
.unwrap();
// Now do the proper MDC registration via LocalModel::attach()
......@@ -376,7 +376,7 @@ mod integration_tests {
let test_component = namespace.component("test-mdc-component").unwrap();
let test_endpoint = test_component.endpoint("test-mdc-endpoint");
// This will store the MDC in etcd and create the ModelEntry for discovery
// This will store the MDC in etcd for discovery
local_model
.attach(
&test_endpoint,
......@@ -388,8 +388,7 @@ mod integration_tests {
// Manually save the model card and update metrics
// This simulates what the ModelWatcher polling task would do in production
let card = local_model.card().clone();
manager.save_model_card("test-mdc-key", card.clone());
let _ = manager.save_model_card("test-mdc-key", card.clone());
if let Err(e) = service
.state()
......@@ -500,8 +499,13 @@ mod integration_tests {
assert!(metrics_body.contains("request_type=\"stream\""));
assert!(metrics_body.contains("status=\"success\""));
// etcd lease will ensure we everything is deleted from etcd
// Now test the complete lifecycle: remove the model from etcd
// We don't need to cleanup model manager because it's local to this test
/*
// Clean up
// Remove the model using the cleaner ModelWatcher approach
if let Some(etcd_client) = distributed_runtime.etcd_client() {
// Use ModelWatcher to find and remove the model (following ModelWatcher::handle_delete pattern)
......@@ -514,10 +518,7 @@ mod integration_tests {
);
// Get all model entries for our test model
let model_entries = watcher
.entries_for_model("test-mdc-model", None, true)
.await
.unwrap();
let model_entries = watcher.entries_for_model("test-mdc-model").await.unwrap();
if !model_entries.is_empty() {
// For each model entry, we need to find its etcd key and remove it
......@@ -566,8 +567,8 @@ mod integration_tests {
}
}
}
*/
// Clean up
cancel_token.cancel();
task.await.unwrap().unwrap();
}
......
......@@ -280,31 +280,32 @@ pub mod kserve_test {
let failure = Arc::new(AlwaysFailEngine {});
let long_running = Arc::new(LongRunningEngine::new(1_000));
manager
.add_completions_model("split", split.clone())
.unwrap();
let mut card = ModelDeploymentCard::with_name_only("split");
card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text;
manager.save_model_card("split", card);
manager
.add_chat_completions_model("failure", failure.clone())
.unwrap();
manager
.add_completions_model("failure", failure.clone())
.add_completions_model("split", card.mdcsum(), split.clone())
.unwrap();
let _ = manager.save_model_card("split", card.clone());
let mut card = ModelDeploymentCard::with_name_only("failure");
card.model_type = ModelType::Completions | ModelType::Chat;
card.model_input = ModelInput::Text;
manager.save_model_card("failure", card);
manager
.add_completions_model("long_running", long_running.clone())
.add_chat_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
manager
.add_completions_model("failure", card.mdcsum(), failure.clone())
.unwrap();
let _ = manager.save_model_card("failure", card);
let mut card = ModelDeploymentCard::with_name_only("long_running");
card.model_type = ModelType::Completions;
card.model_input = ModelInput::Text;
manager.save_model_card("long_running", card);
manager
.add_completions_model("long_running", card.mdcsum(), long_running.clone())
.unwrap();
let _ = manager.save_model_card("long_running", card);
(service, split, failure, long_running)
}
......@@ -1130,11 +1131,16 @@ pub mod kserve_test {
text_input: inference::model_infer_request::InferInputTensor,
) {
// add tensor model
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
let tensor = Arc::new(TensorEngine {});
service_with_engines
.0
.model_manager()
.add_tensor_model("tensor", tensor.clone())
.add_tensor_model("tensor", card.mdcsum(), tensor.clone())
.unwrap();
// start server
......@@ -1147,11 +1153,7 @@ pub mod kserve_test {
version: "".into(),
});
// Failure, model registered as Tensor but does not provide model config (in runtime config)
let mut card = ModelDeploymentCard::with_name_only("tensor");
card.model_type = ModelType::TensorBased;
card.model_input = ModelInput::Tensor;
service_with_engines
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
......@@ -1217,7 +1219,7 @@ pub mod kserve_test {
}),
..Default::default()
};
service_with_engines
let _ = service_with_engines
.0
.model_manager()
.save_model_card("key", card);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment