Unverified Commit 9d03b8dc authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(frontend): Get model config files (`tokenizer.json` et al.) from MX (#3659)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent e93170d6
...@@ -100,6 +100,29 @@ impl CheckedFile { ...@@ -100,6 +100,29 @@ impl CheckedFile {
} }
} }
} }
/// Is the CheckedFile a path on disk that exists?
pub fn is_local(&self) -> bool {
match self.path.as_ref() {
Either::Left(path) => path.exists(),
Either::Right(_) => false, // is a Url
}
}
/// Keep the filename but change it's containing directory to `dir`.
/// This is used to point at a model file (e.g. `tokenizer.json`) in the HF cache dir.
pub fn update_dir(&mut self, dir: &Path) {
match self.path.as_mut() {
Either::Left(path) => {
if let Some(file_name) = path.file_name() {
let mut new_path = PathBuf::from(dir);
new_path.push(file_name);
*path = new_path;
}
}
Either::Right(_) => tracing::warn!("Cannot update directory on URL"),
}
}
} }
impl Display for CheckedFile { impl Display for CheckedFile {
......
...@@ -303,7 +303,8 @@ impl ModelWatcher { ...@@ -303,7 +303,8 @@ impl ModelWatcher {
endpoint_id: &EndpointId, endpoint_id: &EndpointId,
card: &mut ModelDeploymentCard, card: &mut ModelDeploymentCard,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
card.move_from_nats(self.drt.nats_client()).await?; card.download_config().await?;
let component = self let component = self
.drt .drt
.namespace(&endpoint_id.namespace)? .namespace(&endpoint_id.namespace)?
......
...@@ -421,10 +421,6 @@ impl LocalModel { ...@@ -421,10 +421,6 @@ impl LocalModel {
self.card.model_type = model_type; self.card.model_type = model_type;
self.card.model_input = model_input; self.card.model_input = model_input;
// Store model config files in NATS object store
let nats_client = endpoint.drt().nats_client();
self.card.move_to_nats(nats_client.clone()).await?;
// Publish the Model Deployment Card to KV store // Publish the Model Deployment Card to KV store
let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client.clone())); let kvstore: Box<dyn KeyValueStore> = Box::new(EtcdStore::new(etcd_client.clone()));
let card_store = Arc::new(KeyValueStoreManager::new(kvstore)); let card_store = Arc::new(KeyValueStoreManager::new(kvstore));
......
...@@ -46,6 +46,18 @@ impl ModelInfoType { ...@@ -46,6 +46,18 @@ impl ModelInfoType {
ModelInfoType::HfConfigJson(c) => c.checksum().to_string(), ModelInfoType::HfConfigJson(c) => c.checksum().to_string(),
} }
} }
pub fn is_local(&self) -> bool {
match self {
ModelInfoType::HfConfigJson(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
ModelInfoType::HfConfigJson(c) => c.update_dir(dir),
}
}
} }
#[derive(Serialize, Deserialize, Clone, Debug)] #[derive(Serialize, Deserialize, Clone, Debug)]
...@@ -60,6 +72,18 @@ impl TokenizerKind { ...@@ -60,6 +72,18 @@ impl TokenizerKind {
TokenizerKind::HfTokenizerJson(c) => c.checksum().to_string(), TokenizerKind::HfTokenizerJson(c) => c.checksum().to_string(),
} }
} }
pub fn is_local(&self) -> bool {
match self {
TokenizerKind::HfTokenizerJson(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
TokenizerKind::HfTokenizerJson(c) => c.update_dir(dir),
}
}
} }
/// Supported types of prompt formatters. /// Supported types of prompt formatters.
...@@ -88,6 +112,21 @@ impl PromptFormatterArtifact { ...@@ -88,6 +112,21 @@ impl PromptFormatterArtifact {
PromptFormatterArtifact::HfChatTemplate(c) => c.checksum().to_string(), PromptFormatterArtifact::HfChatTemplate(c) => c.checksum().to_string(),
} }
} }
/// Is this file available locally
pub fn is_local(&self) -> bool {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.is_local(),
PromptFormatterArtifact::HfChatTemplate(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.update_dir(dir),
PromptFormatterArtifact::HfChatTemplate(c) => c.update_dir(dir),
}
}
} }
#[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)] #[derive(Serialize, Deserialize, Clone, Debug, PartialEq, Eq, Hash)]
...@@ -112,6 +151,18 @@ impl GenerationConfig { ...@@ -112,6 +151,18 @@ impl GenerationConfig {
GenerationConfig::HfGenerationConfigJson(c) => c.checksum().to_string(), GenerationConfig::HfGenerationConfigJson(c) => c.checksum().to_string(),
} }
} }
pub fn is_local(&self) -> bool {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.is_local(),
}
}
pub fn update_dir(&mut self, dir: &Path) {
match self {
GenerationConfig::HfGenerationConfigJson(c) => c.update_dir(dir),
}
}
} }
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)] #[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
...@@ -170,9 +221,6 @@ pub struct ModelDeploymentCard { ...@@ -170,9 +221,6 @@ pub struct ModelDeploymentCard {
#[serde(default)] #[serde(default)]
pub runtime_config: ModelRuntimeConfig, pub runtime_config: ModelRuntimeConfig,
#[serde(skip)]
cache_dir: Option<Arc<tempfile::TempDir>>,
#[serde(skip, default)] #[serde(skip, default)]
checksum: OnceLock<String>, checksum: OnceLock<String>,
} }
...@@ -304,114 +352,6 @@ impl ModelDeploymentCard { ...@@ -304,114 +352,6 @@ impl ModelDeploymentCard {
} }
} }
/// Move the files this MDC uses into the NATS object store.
/// Updates the URI's to point to NATS.
pub async fn move_to_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug().clone();
tracing::debug!(
nats_addr,
%bucket_name,
"Uploading model deployment card fields to NATS"
);
macro_rules! nats_upload {
($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_file)) = $field.as_mut()
&& let Some(path) = src_file.path()
{
let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
let dest = url::Url::parse(&target)?;
nats_client.object_store_upload(path, &dest).await?;
src_file.move_to_url(dest);
}
};
}
nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_upload!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
nats_upload!(
self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_upload!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_upload!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
"tokenizer.json"
);
Ok(())
}
/// Move the files this MDC uses from the NATS object store to local disk.
/// Updates the URI's to point to the created files.
pub async fn move_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr();
let bucket_name = self.slug();
let target_dir = tempfile::TempDir::with_prefix(bucket_name.to_string())?;
tracing::debug!(
nats_addr,
%bucket_name,
target_dir = %target_dir.path().display(),
"Downloading model deployment card fields from NATS"
);
macro_rules! nats_download {
($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_file)) = $field.as_mut()
&& let Some(src_url) = src_file.url()
{
let target = target_dir.path().join($filename);
nats_client.object_store_download(src_url, &target).await?;
if !src_file.checksum_matches(&target) {
anyhow::bail!(
"Invalid {} in NATS for {}, checksum does not match.",
$filename,
self.display_name
);
}
src_file.move_to_disk(target);
}
};
}
nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_download!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
nats_download!(
self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_download!(
self.chat_template_file,
PromptFormatterArtifact::HfChatTemplate,
"chat_template.jinja"
);
nats_download!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
"tokenizer.json"
);
// This cache_dir is a tempfile::TempDir will be deleted on drop, so keep it alive.
self.cache_dir = Some(Arc::new(target_dir));
Ok(())
}
/// Delete this card from the key-value store and it's URLs from the object store /// Delete this card from the key-value store and it's URLs from the object store
pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> { pub async fn delete_from_nats(&mut self, nats_client: nats::Client) -> Result<()> {
let nats_addr = nats_client.addr(); let nats_addr = nats_client.addr();
...@@ -465,10 +405,80 @@ impl ModelDeploymentCard { ...@@ -465,10 +405,80 @@ impl ModelDeploymentCard {
else { else {
return Ok(None); return Ok(None);
}; };
card.move_from_nats(drt.nats_client()).await?;
card.download_config().await?;
Ok(Some(card)) Ok(Some(card))
} }
/// Download the files this card needs to work: config.json, tokenizer.json, etc.
pub async fn download_config(&mut self) -> anyhow::Result<()> {
if self.has_local_files() {
tracing::trace!("All model config is local, not downloading");
return Ok(());
}
let ignore_weights = true;
let local_path = crate::hub::from_hf(&self.display_name, ignore_weights).await?;
self.update_dir(&local_path);
Ok(())
}
/// Are all the files we need (tokenizer.json, etc) available locally?
fn has_local_files(&self) -> bool {
let has_model_info = self
.model_info
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_tokenizer = self
.tokenizer
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_prompt_formatter = self
.prompt_formatter
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_chat_template_file = self
.chat_template_file
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
let has_gen_config = self
.gen_config
.as_ref()
.map(|p| p.is_local())
.unwrap_or(true);
has_model_info
&& has_tokenizer
&& has_prompt_formatter
&& has_chat_template_file
&& has_gen_config
}
/// Update the directory for files like tokenizer.json be in here.
fn update_dir(&mut self, dir: &Path) {
if let Some(model_info) = self.model_info.as_mut() {
model_info.update_dir(dir);
}
if let Some(tk) = self.tokenizer.as_mut() {
tk.update_dir(dir);
}
if let Some(pf) = self.prompt_formatter.as_mut() {
pf.update_dir(dir);
}
if let Some(ct) = self.chat_template_file.as_mut() {
ct.update_dir(dir);
}
if let Some(gc) = self.gen_config.as_mut() {
gc.update_dir(dir);
}
}
/// Creates a ModelDeploymentCard from a local directory path. /// Creates a ModelDeploymentCard from a local directory path.
/// ///
/// Currently HuggingFace format is supported and following files are expected: /// Currently HuggingFace format is supported and following files are expected:
...@@ -552,7 +562,6 @@ impl ModelDeploymentCard { ...@@ -552,7 +562,6 @@ impl ModelDeploymentCard {
model_input: Default::default(), // set later model_input: Default::default(), // set later
user_data: None, user_data: None,
runtime_config: ModelRuntimeConfig::default(), runtime_config: ModelRuntimeConfig::default(),
cache_dir: None,
checksum: OnceLock::new(), checksum: OnceLock::new(),
}) })
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment