Unverified commit 9060ce12, authored by Graham King, committed by GitHub
Browse files

feat: Make part of discovery re-usable (#3073)


Signed-off-by: Graham King <grahamk@nvidia.com>
parent 045b61dd
...@@ -17,7 +17,7 @@ impl ModelDeploymentCard { ...@@ -17,7 +17,7 @@ impl ModelDeploymentCard {
// Previously called "from_local_path" // Previously called "from_local_path"
#[staticmethod] #[staticmethod]
fn load(path: String, model_name: String) -> PyResult<ModelDeploymentCard> { fn load(path: String, model_name: String) -> PyResult<ModelDeploymentCard> {
let mut card = RsModelDeploymentCard::load(&path, None).map_err(to_pyerr)?; let mut card = RsModelDeploymentCard::load_from_disk(&path, None).map_err(to_pyerr)?;
card.set_name(&model_name); card.set_name(&model_name);
Ok(ModelDeploymentCard { inner: card }) Ok(ModelDeploymentCard { inner: card })
} }
......
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
use std::sync::Arc; use dynamo_runtime::{protocols, slug::Slug};
use dynamo_runtime::transports::etcd;
use dynamo_runtime::{
protocols,
slug::Slug,
storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager},
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use crate::{ use crate::{
local_model::runtime_config::ModelRuntimeConfig, local_model::runtime_config::ModelRuntimeConfig,
model_card::{self, ModelDeploymentCard},
model_type::{ModelInput, ModelType}, model_type::{ModelInput, ModelType},
}; };
/// [ModelEntry] contains the information to discover models from the etcd cluster. /// [ModelEntry] contains the information to discover models
#[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)] #[derive(Debug, Clone, Serialize, Deserialize, Eq, PartialEq)]
pub struct ModelEntry { pub struct ModelEntry {
/// Public name of the model /// Public name of the model
...@@ -42,7 +34,7 @@ pub struct ModelEntry { ...@@ -42,7 +34,7 @@ pub struct ModelEntry {
} }
impl ModelEntry { impl ModelEntry {
/// Slugified display name for use in etcd and NATS /// Slugified display name for use in network storage, or URL-safe environments
pub fn slug(&self) -> Slug { pub fn slug(&self) -> Slug {
Slug::from_string(&self.name) Slug::from_string(&self.name)
} }
...@@ -50,29 +42,4 @@ impl ModelEntry { ...@@ -50,29 +42,4 @@ impl ModelEntry {
pub fn requires_preprocessing(&self) -> bool { pub fn requires_preprocessing(&self) -> bool {
matches!(self.model_input, ModelInput::Tokens) matches!(self.model_input, ModelInput::Tokens)
} }
/// Look up this entry's ModelDeploymentCard in etcd.
///
/// The card is read from `model_card::ROOT_PATH` using this entry's slug as the key.
/// The card is returned exactly as stored: none of its fields are modified here, so
/// you may still need to call move_from_nats on it.
///
/// # Errors
///
/// Fails if the store has no card under the slug, or if the lookup itself fails.
pub async fn load_mdc(
    &self,
    etcd_client: &etcd::Client,
) -> anyhow::Result<ModelDeploymentCard> {
    // Wrap the etcd client in the generic key-value store manager.
    let backend: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client.clone()));
    let manager = Arc::new(KeyValueStoreManager::new(backend));
    let card_key = self.slug();

    let fetched = manager
        .load::<ModelDeploymentCard>(model_card::ROOT_PATH, &card_key)
        .await;

    // Separate "lookup failed" from "lookup succeeded but found nothing".
    let maybe_card = match fetched {
        Ok(maybe_card) => maybe_card,
        Err(err) => {
            anyhow::bail!(
                "Error fetching ModelDeploymentCard from etcd under key {card_key}. {err}"
            );
        }
    };
    let Some(mdc) = maybe_card else {
        anyhow::bail!("Missing ModelDeploymentCard in etcd under key {card_key}");
    };
    Ok(mdc)
}
} }
...@@ -21,6 +21,7 @@ use crate::{ ...@@ -21,6 +21,7 @@ use crate::{
backend::Backend, backend::Backend,
entrypoint, entrypoint,
kv_router::KvRouterConfig, kv_router::KvRouterConfig,
model_card::ModelDeploymentCard,
model_type::{ModelInput, ModelType}, model_type::{ModelInput, ModelType},
preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter}, preprocessor::{OpenAIPreprocessor, PreprocessedEmbeddingRequest, prompt::PromptFormatter},
protocols::{ protocols::{
...@@ -260,19 +261,16 @@ impl ModelWatcher { ...@@ -260,19 +261,16 @@ impl ModelWatcher {
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
.component(&endpoint_id.component)?; .component(&endpoint_id.component)?;
let client = component.endpoint(&endpoint_id.name).client().await?; let client = component.endpoint(&endpoint_id.name).client().await?;
let model_slug = model_entry.slug();
let Some(etcd_client) = self.drt.etcd_client() else { let card = match ModelDeploymentCard::load_from_store(&model_slug, &self.drt).await {
// Should be impossible because we only get here on an etcd event Ok(Some(card)) => card,
anyhow::bail!("Missing etcd_client"); Ok(None) => {
}; anyhow::bail!("Missing ModelDeploymentCard in storage under key {model_slug}");
let card = match model_entry.load_mdc(&etcd_client).await {
Ok(card) => {
tracing::debug!(card.display_name, "adding model");
Some(card)
} }
Err(err) => { Err(err) => {
tracing::info!(error = %err, "load_mdc did not complete"); anyhow::bail!(
None "Error fetching ModelDeploymentCard from storage under key {model_slug}. {err}"
);
} }
}; };
...@@ -284,15 +282,6 @@ impl ModelWatcher { ...@@ -284,15 +282,6 @@ impl ModelWatcher {
// A model that expects pre-processed requests meaning it's up to us whether we // A model that expects pre-processed requests meaning it's up to us whether we
// handle Chat or Completions requests, so handle whatever the model supports. // handle Chat or Completions requests, so handle whatever the model supports.
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card");
};
// Download tokenizer.json etc to local disk
// This cache_dir is a tempfile::TempDir will be deleted on drop. I _think_
// OpenAIPreprocessor::new loads the files, so we can delete them after this
// function. Needs checking carefully, possibly we need to store it in state.
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
let kv_chooser = if self.router_mode == RouterMode::KV { let kv_chooser = if self.router_mode == RouterMode::KV {
Some( Some(
self.manager self.manager
...@@ -309,7 +298,7 @@ impl ModelWatcher { ...@@ -309,7 +298,7 @@ impl ModelWatcher {
}; };
// This is expensive, we are loading ~10MiB JSON, so only do it once // This is expensive, we are loading ~10MiB JSON, so only do it once
let tokenizer_hf = card.tokenizer_hf()?; let tokenizer_hf = card.tokenizer_hf().context("tokenizer_hf")?;
// Add chat engine only if the model supports chat // Add chat engine only if the model supports chat
if model_entry.model_type.supports_chat() { if model_entry.model_type.supports_chat() {
...@@ -324,9 +313,11 @@ impl ModelWatcher { ...@@ -324,9 +313,11 @@ impl ModelWatcher {
kv_chooser.clone(), kv_chooser.clone(),
tokenizer_hf.clone(), tokenizer_hf.clone(),
) )
.await?; .await
.context("build_routed_pipeline")?;
self.manager self.manager
.add_chat_completions_model(&model_entry.name, chat_engine)?; .add_chat_completions_model(&model_entry.name, chat_engine)
.context("add_chat_completions_model")?;
tracing::info!("Chat completions is ready"); tracing::info!("Chat completions is ready");
} }
...@@ -338,7 +329,8 @@ impl ModelWatcher { ...@@ -338,7 +329,8 @@ impl ModelWatcher {
card.clone(), card.clone(),
formatter, formatter,
tokenizer_hf.clone(), tokenizer_hf.clone(),
)?; )
.context("OpenAIPreprocessor::new_with_parts")?;
let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::< let completions_engine = entrypoint::build_routed_pipeline_with_preprocessor::<
NvCreateCompletionRequest, NvCreateCompletionRequest,
NvCreateCompletionResponse, NvCreateCompletionResponse,
...@@ -351,9 +343,11 @@ impl ModelWatcher { ...@@ -351,9 +343,11 @@ impl ModelWatcher {
preprocessor, preprocessor,
tokenizer_hf, tokenizer_hf,
) )
.await?; .await
.context("build_routed_pipeline_with_preprocessor")?;
self.manager self.manager
.add_completions_model(&model_entry.name, completions_engine)?; .add_completions_model(&model_entry.name, completions_engine)
.context("add_completions_model")?;
tracing::info!("Completions is ready"); tracing::info!("Completions is ready");
} }
} else if model_entry.model_input == ModelInput::Text } else if model_entry.model_input == ModelInput::Text
...@@ -388,12 +382,6 @@ impl ModelWatcher { ...@@ -388,12 +382,6 @@ impl ModelWatcher {
&& model_entry.model_type.supports_embedding() && model_entry.model_type.supports_embedding()
{ {
// Case 4: Tokens + Embeddings // Case 4: Tokens + Embeddings
let Some(mut card) = card else {
anyhow::bail!("Missing model deployment card for embedding model");
};
// Download tokenizer files to local disk
let _cache_dir = Some(card.move_from_nats(self.drt.nats_client()).await?);
// Create preprocessing pipeline similar to Backend // Create preprocessing pipeline similar to Backend
let frontend = SegmentSource::< let frontend = SegmentSource::<
......
...@@ -252,8 +252,10 @@ impl LocalModelBuilder { ...@@ -252,8 +252,10 @@ impl LocalModelBuilder {
// --model-config takes precedence over --model-path // --model-config takes precedence over --model-path
let model_config_path = self.model_config.as_ref().unwrap_or(&full_path); let model_config_path = self.model_config.as_ref().unwrap_or(&full_path);
let mut card = let mut card = ModelDeploymentCard::load_from_disk(
ModelDeploymentCard::load(model_config_path, self.custom_template_path.as_deref())?; model_config_path,
self.custom_template_path.as_deref(),
)?;
// Usually we infer from the path, self.model_name is user override // Usually we infer from the path, self.model_name is user override
let model_name = self.model_name.take().unwrap_or_else(|| { let model_name = self.model_name.take().unwrap_or_else(|| {
......
...@@ -23,6 +23,8 @@ use crate::common::checked_file::CheckedFile; ...@@ -23,6 +23,8 @@ use crate::common::checked_file::CheckedFile;
use crate::local_model::runtime_config::ModelRuntimeConfig; use crate::local_model::runtime_config::ModelRuntimeConfig;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::DistributedRuntime;
use dynamo_runtime::storage::key_value_store::{EtcdStorage, KeyValueStore, KeyValueStoreManager};
use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats}; use dynamo_runtime::{slug::Slug, storage::key_value_store::Versioned, transports::nats};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
...@@ -141,6 +143,9 @@ pub struct ModelDeploymentCard { ...@@ -141,6 +143,9 @@ pub struct ModelDeploymentCard {
#[serde(default)] #[serde(default)]
pub runtime_config: ModelRuntimeConfig, pub runtime_config: ModelRuntimeConfig,
#[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>,
} }
impl ModelDeploymentCard { impl ModelDeploymentCard {
...@@ -228,9 +233,9 @@ impl ModelDeploymentCard { ...@@ -228,9 +233,9 @@ impl ModelDeploymentCard {
pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> { pub fn tokenizer_hf(&self) -> anyhow::Result<HfTokenizer> {
match &self.tokenizer { match &self.tokenizer {
Some(TokenizerKind::HfTokenizerJson(checked_file)) => { Some(TokenizerKind::HfTokenizerJson(checked_file)) => {
let p = checked_file.path().ok_or_else(|| let p = checked_file.path().ok_or_else(|| {
anyhow::anyhow!("Tokenizer is URL-backed ({:?}); call move_from_nats() before tokenizer_hf()", checked_file.url()) anyhow::anyhow!("Tokenizer is URL-backed ({:?})", checked_file.url())
)?; })?;
HfTokenizer::from_file(p) HfTokenizer::from_file(p)
.inspect_err(|err| { .inspect_err(|err| {
if let Some(serde_err) = err.downcast_ref::<serde_json::Error>() if let Some(serde_err) = err.downcast_ref::<serde_json::Error>()
...@@ -240,6 +245,7 @@ impl ModelDeploymentCard { ...@@ -240,6 +245,7 @@ impl ModelDeploymentCard {
} }
}) })
.map_err(anyhow::Error::msg) .map_err(anyhow::Error::msg)
.with_context(|| p.display().to_string())
} }
Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()), Some(TokenizerKind::GGUF(t)) => Ok(*t.clone()),
None => { None => {
...@@ -308,7 +314,7 @@ impl ModelDeploymentCard { ...@@ -308,7 +314,7 @@ impl ModelDeploymentCard {
/// Updates the URI's to point to the created files. /// Updates the URI's to point to the created files.
/// ///
/// The returned TempDir must be kept alive, it cleans up on drop. /// The returned TempDir must be kept alive, it cleans up on drop.
pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> { async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<tempfile::TempDir> {
let nats_addr = nats_client.addr(); let nats_addr = nats_client.addr();
let bucket_name = self.slug(); let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?; let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
...@@ -388,7 +394,7 @@ impl ModelDeploymentCard { ...@@ -388,7 +394,7 @@ impl ModelDeploymentCard {
/// - a folder containing config.json, tokenizer.json and token_config.json /// - a folder containing config.json, tokenizer.json and token_config.json
/// - a GGUF file /// - a GGUF file
/// With an optional custom template /// With an optional custom template
pub fn load( pub fn load_from_disk(
config_path: impl AsRef<Path>, config_path: impl AsRef<Path>,
custom_template_path: Option<&Path>, custom_template_path: Option<&Path>,
) -> anyhow::Result<ModelDeploymentCard> { ) -> anyhow::Result<ModelDeploymentCard> {
...@@ -404,6 +410,29 @@ impl ModelDeploymentCard { ...@@ -404,6 +410,29 @@ impl ModelDeploymentCard {
} }
} }
/// Fetch the ModelDeploymentCard stored under `model_slug` and make it locally usable.
///
/// The card is read from `ROOT_PATH` in an etcd-backed key-value store built from the
/// runtime's etcd client. If a card is found, `move_from_nats` is run on it with the
/// runtime's NATS client, and the temporary directory it returns is stored on the card.
///
/// Returns `Ok(None)` when the store holds no card under `model_slug`.
///
/// # Errors
///
/// Fails if the runtime has no etcd client, if the store lookup fails, or if
/// `move_from_nats` fails.
pub async fn load_from_store(
    model_slug: &Slug,
    drt: &DistributedRuntime,
) -> anyhow::Result<Option<Self>> {
    let etcd_client = match drt.etcd_client() {
        Some(client) => client,
        // Should be impossible because we only get here on an etcd event
        None => anyhow::bail!("Missing etcd_client"),
    };
    let backend: Box<dyn KeyValueStore> = Box::new(EtcdStorage::new(etcd_client));
    let manager = Arc::new(KeyValueStoreManager::new(backend));

    let maybe_card = manager
        .load::<ModelDeploymentCard>(ROOT_PATH, model_slug)
        .await?;
    let mut card = match maybe_card {
        Some(card) => card,
        None => return Ok(None),
    };

    // The TempDir returned by move_from_nats cleans up on drop, so store it on the
    // card to keep it alive for as long as the card is.
    let local_files = card.move_from_nats(drt.nats_client()).await?;
    card.cache_dir = Some(Arc::new(local_files));
    Ok(Some(card))
}
/// Creates a ModelDeploymentCard from a local directory path. /// Creates a ModelDeploymentCard from a local directory path.
/// ///
/// Currently HuggingFace format is supported and following files are expected: /// Currently HuggingFace format is supported and following files are expected:
...@@ -474,6 +503,7 @@ impl ModelDeploymentCard { ...@@ -474,6 +503,7 @@ impl ModelDeploymentCard {
migration_limit: 0, migration_limit: 0,
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
}) })
} }
...@@ -537,6 +567,7 @@ impl ModelDeploymentCard { ...@@ -537,6 +567,7 @@ impl ModelDeploymentCard {
migration_limit: 0, migration_limit: 0,
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
}) })
} }
} }
......
...@@ -6,7 +6,8 @@ use dynamo_llm::model_card::ModelDeploymentCard; ...@@ -6,7 +6,8 @@ use dynamo_llm::model_card::ModelDeploymentCard;
#[test] #[test]
fn test_sequence_factory() { fn test_sequence_factory() {
let mdc = ModelDeploymentCard::load("tests/data/sample-models/TinyLlama_v1.1", None).unwrap(); let mdc = ModelDeploymentCard::load_from_disk("tests/data/sample-models/TinyLlama_v1.1", None)
.unwrap();
let operator = Backend::from_mdc(&mdc); let operator = Backend::from_mdc(&mdc);
......
...@@ -8,7 +8,7 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1"; ...@@ -8,7 +8,7 @@ const HF_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1";
#[tokio::test] #[tokio::test]
async fn test_model_info_from_hf_like_local_repo() { async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap(); let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap();
let info = mdc.model_info.unwrap().get_model_info().unwrap(); let info = mdc.model_info.unwrap().get_model_info().unwrap();
assert_eq!(info.model_type(), "llama"); assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1); assert_eq!(info.bos_token_id(), 1);
...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -20,13 +20,13 @@ async fn test_model_info_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_model_info_from_non_existent_local_repo() { async fn test_model_info_from_non_existent_local_repo() {
let path = "tests/data/sample-models/this-model-does-not-exist"; let path = "tests/data/sample-models/this-model-does-not-exist";
let result = ModelDeploymentCard::load(path, None); let result = ModelDeploymentCard::load_from_disk(path, None);
assert!(result.is_err()); assert!(result.is_err());
} }
#[tokio::test] #[tokio::test]
async fn test_tokenizer_from_hf_like_local_repo() { async fn test_tokenizer_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap(); let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap();
// Verify tokenizer file was found // Verify tokenizer file was found
match mdc.tokenizer.unwrap() { match mdc.tokenizer.unwrap() {
TokenizerKind::HfTokenizerJson(_) => (), TokenizerKind::HfTokenizerJson(_) => (),
...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() { ...@@ -36,7 +36,7 @@ async fn test_tokenizer_from_hf_like_local_repo() {
#[tokio::test] #[tokio::test]
async fn test_prompt_formatter_from_hf_like_local_repo() { async fn test_prompt_formatter_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load(HF_PATH, None).unwrap(); let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap();
// Verify prompt formatter was found // Verify prompt formatter was found
match mdc.prompt_formatter { match mdc.prompt_formatter {
Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (), Some(PromptFormatterArtifact::HfTokenizerConfigJson(_)) => (),
...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() { ...@@ -48,7 +48,7 @@ async fn test_prompt_formatter_from_hf_like_local_repo() {
async fn test_missing_required_files() { async fn test_missing_required_files() {
// Create empty temp directory // Create empty temp directory
let temp_dir = tempdir().unwrap(); let temp_dir = tempdir().unwrap();
let result = ModelDeploymentCard::load(temp_dir.path(), None); let result = ModelDeploymentCard::load_from_disk(temp_dir.path(), None);
assert!(result.is_err()); assert!(result.is_err());
let err = result.unwrap_err().to_string(); let err = result.unwrap_err().to_string();
// Should fail because config.json is missing // Should fail because config.json is missing
......
...@@ -57,7 +57,7 @@ async fn make_mdc_from_repo( ...@@ -57,7 +57,7 @@ async fn make_mdc_from_repo(
//TODO: remove this once we have nim-hub support. See the NOTE above. //TODO: remove this once we have nim-hub support. See the NOTE above.
let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await; let downloaded_path = maybe_download_model(local_path, hf_repo, hf_revision).await;
let display_name = format!("{}--{}", hf_repo, hf_revision); let display_name = format!("{}--{}", hf_repo, hf_revision);
let mut mdc = ModelDeploymentCard::load(downloaded_path, None).unwrap(); let mut mdc = ModelDeploymentCard::load_from_disk(downloaded_path, None).unwrap();
mdc.set_name(&display_name); mdc.set_name(&display_name);
mdc.prompt_context = mixins; mdc.prompt_context = mixins;
mdc mdc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment