Unverified commit cfd12d7f, authored by Graham King and committed by GitHub
Browse files

feat: Support larger Gemma 3 models (#1359)

Publish `generation_config.json` from worker to ingress, as part of Model Deployment Card. That allows ingress to read key fields out of it. Gemma 3 4B+ has some important information that's only in there.
parent 9e8731e5
...@@ -106,8 +106,12 @@ impl ModelWatcher { ...@@ -106,8 +106,12 @@ impl ModelWatcher {
tracing::info!(model_name = model_entry.name, "added model"); tracing::info!(model_name = model_entry.name, "added model");
self.notify_on_model.notify_waiters(); self.notify_on_model.notify_waiters();
} }
Err(e) => { Err(err) => {
tracing::error!(%e, "error adding model {}", model_entry.name); tracing::error!(
error = format!("{err:#}"),
"error adding model {}",
model_entry.name
);
} }
} }
} }
......
...@@ -6,6 +6,10 @@ ...@@ -6,6 +6,10 @@
//! The `dynamo.llm` crate is a Rust library that provides a set of traits and types for building //! The `dynamo.llm` crate is a Rust library that provides a set of traits and types for building
//! distributed LLM inference solutions. //! distributed LLM inference solutions.
use std::{fs::File, io::BufReader, path::Path};
use anyhow::Context as _;
pub mod backend; pub mod backend;
pub mod common; pub mod common;
pub mod disagg_router; pub mod disagg_router;
...@@ -30,3 +34,233 @@ pub mod types; ...@@ -30,3 +34,233 @@ pub mod types;
#[cfg(feature = "block-manager")] #[cfg(feature = "block-manager")]
pub mod block_manager; pub mod block_manager;
/// Reads a JSON file, extracts one field of its top-level object, and deserializes
/// that field into type `T`.
///
/// # Arguments
///
/// * `json_file_path`: Path to the JSON file.
/// * `field_name`: The name of the field to extract from the top-level JSON object.
///
/// # Returns
///
/// The deserialized value of the field, as a `T`.
///
/// # Errors
///
/// Returns an `anyhow::Error` if the file cannot be opened, its contents are not valid
/// JSON, the JSON root is not an object, `field_name` is not a key of that object, or
/// the field's value cannot be deserialized into `T`.
///
/// # Type Parameters
///
/// * `T`: The expected type of the field's value. `T` must implement `serde::de::DeserializeOwned`.
pub fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    let file = File::open(json_file_path)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;
    let reader = BufReader::new(file);

    // Parse into a generic `serde_json::Value` because only one field is wanted;
    // deserializing the file directly into `T` would require `T` to describe the
    // whole document.
    let json_data: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    // Only an object at the root can be indexed by field name.
    let map = json_data.as_object().ok_or_else(|| {
        anyhow::anyhow!("JSON root is not an object in file: {:?}", json_file_path)
    })?;

    let field_value = map.get(field_name).ok_or_else(|| {
        anyhow::anyhow!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        )
    })?;

    // `&serde_json::Value` is itself a `Deserializer`, so deserialize straight from the
    // borrowed value instead of cloning it to feed `serde_json::from_value`.
    <T as serde::Deserialize>::deserialize(field_value).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, field_value, json_file_path
        )
    })
}
#[cfg(test)]
mod file_json_field_tests {
    use super::file_json_field;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    /// Writes `content` into a new file called `file_name` inside `dir` and returns
    /// the path of that file. Panics if the file cannot be created or written.
    fn create_temp_json_file(dir: &Path, file_name: &str, content: &str) -> PathBuf {
        let file_path = dir.join(file_name);
        let mut file = match File::create(&file_path) {
            Ok(handle) => handle,
            Err(_) => panic!("Failed to create test file: {:?}", file_path),
        };
        if file.write_all(content.as_bytes()).is_err() {
            panic!("Failed to write to test file: {:?}", file_path);
        }
        file_path
    }

    /// Asserts that `result` is an error whose outermost message contains `expected`.
    fn assert_err_contains<T>(result: anyhow::Result<T>, expected: &str) {
        let Err(err) = result else {
            panic!("expected an error containing {expected:?}");
        };
        assert!(err.to_string().contains(expected));
    }

    /// Target type for the nested-object deserialization test.
    #[derive(Debug, PartialEq, Deserialize)]
    struct MyConfig {
        version: String,
        enabled: bool,
        count: u32,
    }

    /// String, integer and boolean fields are each read back with their own type.
    #[test]
    fn test_success_basic() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "test_basic.json",
            r#"{ "name": "Rust", "age": 30, "is_active": true }"#,
        );

        assert_eq!(
            file_json_field::<String>(&file_path, "name").unwrap(),
            "Rust"
        );
        assert_eq!(file_json_field::<i32>(&file_path, "age").unwrap(), 30);
        assert!(file_json_field::<bool>(&file_path, "is_active").unwrap());
    }

    /// A field holding a JSON object deserializes into a user-defined struct.
    #[test]
    fn test_success_custom_struct_field() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "test_struct.json",
            r#"{
"config": {
"version": "1.0.0",
"enabled": true,
"count": 123
},
"other_field": "value"
}"#,
        );

        let expected = MyConfig {
            version: "1.0.0".to_string(),
            enabled: true,
            count: 123,
        };
        let config: MyConfig = file_json_field(&file_path, "config").unwrap();
        assert_eq!(config, expected);
    }

    /// A path that does not exist is reported as an open failure.
    #[test]
    fn test_file_not_found() {
        let tmp_dir = tempdir().unwrap();
        let non_existent_path = tmp_dir.path().join("non_existent.json");

        assert_err_contains(
            file_json_field::<String>(&non_existent_path, "field"),
            "Failed to open file",
        );
    }

    /// Syntactically broken JSON is reported as a parse failure.
    #[test]
    fn test_invalid_json_syntax() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "invalid.json",
            // The value after "bad_syntax" is missing.
            r#"{ "key": "value", "bad_syntax": }"#,
        );

        assert_err_contains(
            file_json_field::<String>(&file_path, "key"),
            "Failed to parse JSON from file",
        );
    }

    /// A document whose root is an array is rejected.
    #[test]
    fn test_json_root_not_object_array() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "root_array.json",
            r#"[ { "item": 1 }, { "item": 2 } ]"#,
        );

        assert_err_contains(
            file_json_field::<String>(&file_path, "item"),
            "JSON root is not an object",
        );
    }

    /// A document whose root is a bare string is rejected.
    #[test]
    fn test_json_root_not_object_primitive() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "root_primitive.json",
            r#""just_a_string""#,
        );

        assert_err_contains(
            file_json_field::<String>(&file_path, "field"),
            "JSON root is not an object",
        );
    }

    /// Asking for a key the object does not have names that key in the error.
    #[test]
    fn test_field_not_found() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "missing_field.json",
            r#"{ "existing_field": "hello" }"#,
        );

        assert_err_contains(
            file_json_field::<String>(&file_path, "non_existent_field"),
            "Field 'non_existent_field' not found",
        );
    }

    /// A value of the wrong JSON type is reported as a deserialization failure.
    #[test]
    fn test_field_type_mismatch() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(
            tmp_dir.path(),
            "type_mismatch.json",
            r#"{ "count": "not_an_integer" }"#,
        );

        assert_err_contains(
            file_json_field::<u32>(&file_path, "count"),
            "Failed to deserialize field 'count'",
        );
    }

    /// A zero-length file is reported as a parse failure.
    #[test]
    fn test_empty_file() {
        let tmp_dir = tempdir().unwrap();
        let file_path = create_temp_json_file(tmp_dir.path(), "empty.json", "");

        assert_err_contains(
            file_json_field::<String>(&file_path, "field"),
            "Failed to parse JSON from file",
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
pub mod create; pub mod create;
pub mod model; pub mod model;
pub use model::ModelDeploymentCard; pub use model::ModelDeploymentCard;
// TODO: Do these network/publish related model deployment card values belong here or in a
// network module?
/// Identify model deployment cards in the key-value store /// Identify model deployment cards in the key-value store
pub const ROOT_PATH: &str = "mdc"; pub const ROOT_PATH: &str = "mdc";
/// Time-to-live for published model deployment cards: five minutes (`5 * 60` seconds).
///
/// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use crate::model_card::model::ModelDeploymentCard; use crate::model_card::model::ModelDeploymentCard;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use std::fs::{self, File};
use std::io::BufReader;
use std::path::{Path, PathBuf}; use std::path::{Path, PathBuf};
use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind}; use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind};
use super::model::GenerationConfig;
impl ModelDeploymentCard { impl ModelDeploymentCard {
/// Allow user to override the name we register this model under. /// Allow user to override the name we register this model under.
/// Corresponds to vllm's `--served-model-name`. /// Corresponds to vllm's `--served-model-name`.
...@@ -98,6 +84,7 @@ impl ModelDeploymentCard { ...@@ -98,6 +84,7 @@ impl ModelDeploymentCard {
service_name: model_name.to_string(), service_name: model_name.to_string(),
model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())), model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())),
tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?), tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
gen_config: None, // AFAICT there is no equivalent in a GGUF
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())), prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
...@@ -116,14 +103,14 @@ impl ModelDeploymentCard { ...@@ -116,14 +103,14 @@ impl ModelDeploymentCard {
async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> { async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
// This is usually the right choice // This is usually the right choice
let context_length = file_json_field( let context_length = crate::file_json_field(
&Path::join(&PathBuf::from(repo_id), "config.json"), &PathBuf::from(repo_id).join("config.json"),
"max_position_embeddings", "max_position_embeddings",
) )
// But sometimes this is // But sometimes this is
.or_else(|_| { .or_else(|_| {
file_json_field( crate::file_json_field(
&Path::join(&PathBuf::from(repo_id), "tokenizer_config.json"), &PathBuf::from(repo_id).join("tokenizer_config.json"),
"model_max_length", "model_max_length",
) )
}) })
...@@ -135,6 +122,7 @@ impl ModelDeploymentCard { ...@@ -135,6 +122,7 @@ impl ModelDeploymentCard {
service_name: model_name.to_string(), service_name: model_name.to_string(),
model_info: Some(ModelInfoType::from_repo(repo_id).await?), model_info: Some(ModelInfoType::from_repo(repo_id).await?),
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?), tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?, prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
prompt_context: None, // TODO - auto-detect prompt context prompt_context: None, // TODO - auto-detect prompt context
revision: 0, revision: 0,
...@@ -190,37 +178,28 @@ impl TokenizerKind { ...@@ -190,37 +178,28 @@ impl TokenizerKind {
} }
} }
/// Checks if the provided path contains the expected file. impl GenerationConfig {
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> { pub async fn from_repo(repo_id: &str) -> Result<Self> {
let mut files = check_for_files(repo_id, vec![file.to_string()]).await?; Self::try_is_hf_repo(repo_id)
let file = files .await
.remove(file) .with_context(|| format!("unable to extract generation config from repo {repo_id}"))
.ok_or(anyhow::anyhow!("file {} not found", file))?; }
Ok(file)
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
Ok(Self::HfGenerationConfigJson(
check_for_file(repo, "generation_config.json").await?,
))
}
} }
async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<String, String>> { /// Checks if the provided path contains the expected file.
let dir_entries = async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
fs::read_dir(repo_id).with_context(|| format!("Failed to read directory: {}", repo_id))?; let p = PathBuf::from(repo_id).join(file);
let mut found_files = HashMap::new(); let name = p.display().to_string();
for entry in dir_entries { if !p.exists() {
let entry = anyhow::bail!("File not found: {name}")
entry.with_context(|| format!("Failed to read directory entry in {}", repo_id))?;
let path = entry.path();
let file_name = path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid file name in {}", repo_id))?;
if files.contains(&file_name.to_string()) {
found_files.insert(
file_name.to_string(),
path.to_str()
.ok_or_else(|| anyhow::anyhow!("Invalid path"))?
.to_string(),
);
}
} }
Ok(found_files) Ok(name)
} }
/// Checks if the provided path is a valid local repository path. /// Checks if the provided path is a valid local repository path.
...@@ -247,58 +226,3 @@ fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> { ...@@ -247,58 +226,3 @@ fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
} }
Ok(()) Ok(())
} }
/// Reads a JSON file, extracts one field of its top-level object, and deserializes
/// that field into type `T`.
///
/// # Arguments
///
/// * `json_file_path`: Path to the JSON file.
/// * `field_name`: The name of the field to extract from the top-level JSON object.
///
/// # Returns
///
/// The deserialized value of the field, as a `T`.
///
/// # Errors
///
/// Returns an `anyhow::Error` if the file cannot be opened, its contents are not valid
/// JSON, the JSON root is not an object, `field_name` is not a key of that object, or
/// the field's value cannot be deserialized into `T`.
///
/// # Type Parameters
///
/// * `T`: The expected type of the field's value. `T` must implement `serde::de::DeserializeOwned`.
fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    let file = File::open(json_file_path)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;
    let reader = BufReader::new(file);

    // Parse into a generic `serde_json::Value` because only one field is wanted;
    // deserializing the file directly into `T` would require `T` to describe the
    // whole document.
    let json_data: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    // Only an object at the root can be indexed by field name.
    let map = json_data.as_object().ok_or_else(|| {
        anyhow::anyhow!("JSON root is not an object in file: {:?}", json_file_path)
    })?;

    let field_value = map.get(field_name).ok_or_else(|| {
        anyhow::anyhow!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        )
    })?;

    // `&serde_json::Value` is itself a `Deserializer`, so deserialize straight from the
    // borrowed value instead of cloning it to feed `serde_json::from_value`.
    <T as serde::Deserialize>::deserialize(field_value).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, field_value, json_file_path
        )
    })
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0 // SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Model Deployment Card //! # Model Deployment Card
//! //!
...@@ -35,7 +23,6 @@ use anyhow::{Context, Result}; ...@@ -35,7 +23,6 @@ use anyhow::{Context, Result};
use derive_builder::Builder; use derive_builder::Builder;
use dynamo_runtime::slug::Slug; use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::nats; use dynamo_runtime::transports::nats;
use either::Either;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer; use tokenizers::Tokenizer as HfTokenizer;
use url::Url; use url::Url;
...@@ -44,10 +31,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike}; ...@@ -44,10 +31,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned; use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType; use crate::protocols::TokenIdType;
/// Time-to-live for published model deployment cards: five minutes (`5 * 60` seconds).
///
/// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
/// If a model deployment card hasn't been refreshed in this much time the worker is likely gone /// If a model deployment card hasn't been refreshed in this much time the worker is likely gone
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5); const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
...@@ -94,6 +77,13 @@ pub enum PromptContextMixin { ...@@ -94,6 +77,13 @@ pub enum PromptContextMixin {
Llama3DateTime, Llama3DateTime,
} }
/// Where a model's generation config can be found.
///
/// Stored in `ModelDeploymentCard::gen_config`. Variant names are serialized in
/// `snake_case`.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    /// Location of a Hugging Face style `generation_config.json`, held as a string.
    /// NOTE(review): appears to be either a local file path or a `nats://` URL,
    /// depending on whether the card has been uploaded - confirm against the
    /// upload/download code.
    HfGenerationConfigJson(String),
    /// Presumably the path of a GGUF file.
    /// NOTE(review): nothing visible here constructs this variant - confirm it is used.
    GGUF(PathBuf),
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)] #[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard { pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct" /// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
...@@ -113,6 +103,10 @@ pub struct ModelDeploymentCard { ...@@ -113,6 +103,10 @@ pub struct ModelDeploymentCard {
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>, pub prompt_formatter: Option<PromptFormatterArtifact>,
/// Generation config - default sampling params
#[serde(default, skip_serializing_if = "Option::is_none")]
pub gen_config: Option<GenerationConfig>,
/// Prompt Formatter Config /// Prompt Formatter Config
#[serde(default, skip_serializing_if = "Option::is_none")] #[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_context: Option<Vec<PromptContextMixin>>, pub prompt_context: Option<Vec<PromptContextMixin>>,
...@@ -244,38 +238,39 @@ impl ModelDeploymentCard { ...@@ -244,38 +238,39 @@ impl ModelDeploymentCard {
"Uploading model deployment card fields to NATS" "Uploading model deployment card fields to NATS"
); );
if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info { macro_rules! nats_upload {
if !nats::is_nats_url(src_file) { ($field:expr, $enum_variant:path, $filename:literal) => {
let target = format!("nats://{nats_addr}/{bucket_name}/config.json"); if let Some($enum_variant(src_file)) = $field.take() {
nats_client if !nats::is_nats_url(&src_file) {
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?) let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
.await?; nats_client
self.model_info = Some(ModelInfoType::HfConfigJson(target)); .object_store_upload(
} &std::path::PathBuf::from(&src_file),
url::Url::parse(&target)?,
)
.await?;
$field = Some($enum_variant(target));
}
}
};
} }
if let Some(PromptFormatterArtifact::HfTokenizerConfigJson(ref src_file)) = nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
self.prompt_formatter nats_upload!(
{ self.prompt_formatter,
if !nats::is_nats_url(src_file) { PromptFormatterArtifact::HfTokenizerConfigJson,
let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer_config.json"); "tokenizer_config.json"
nats_client );
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?) nats_upload!(
.await?; self.tokenizer,
self.prompt_formatter = TokenizerKind::HfTokenizerJson,
Some(PromptFormatterArtifact::HfTokenizerConfigJson(target)); "tokenizer.json"
} );
} nats_upload!(
self.gen_config,
if let Some(TokenizerKind::HfTokenizerJson(ref src_file)) = self.tokenizer { GenerationConfig::HfGenerationConfigJson,
if !nats::is_nats_url(src_file) { "generation_config.json"
let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer.json"); );
nats_client
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
.await?;
self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target));
}
}
Ok(()) Ok(())
} }
...@@ -295,39 +290,36 @@ impl ModelDeploymentCard { ...@@ -295,39 +290,36 @@ impl ModelDeploymentCard {
"Downloading model deployment card fields from NATS" "Downloading model deployment card fields from NATS"
); );
if let Some(ModelInfoType::HfConfigJson(ref src_url)) = self.model_info { macro_rules! nats_download {
if nats::is_nats_url(src_url) { ($field:expr, $enum_variant:path, $filename:literal) => {
let target = target_dir.path().join("config.json"); if let Some($enum_variant(src_url)) = $field.take() {
nats_client if nats::is_nats_url(&src_url) {
.object_store_download(Url::parse(src_url)?, &target) let target = target_dir.path().join($filename);
.await?; nats_client
self.model_info = Some(ModelInfoType::HfConfigJson(target.display().to_string())); .object_store_download(Url::parse(&src_url)?, &target)
} .await?;
$field = Some($enum_variant(target.display().to_string()));
}
}
};
} }
if let Some(PromptFormatterArtifact::HfTokenizerConfigJson(ref src_url)) = nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
self.prompt_formatter nats_download!(
{ self.prompt_formatter,
if nats::is_nats_url(src_url) { PromptFormatterArtifact::HfTokenizerConfigJson,
let target = target_dir.path().join("tokenizer_config.json"); "tokenizer_config.json"
nats_client );
.object_store_download(Url::parse(src_url)?, &target) nats_download!(
.await?; self.tokenizer,
self.prompt_formatter = Some(PromptFormatterArtifact::HfTokenizerConfigJson( TokenizerKind::HfTokenizerJson,
target.display().to_string(), "tokenizer.json"
)); );
} nats_download!(
} self.gen_config,
GenerationConfig::HfGenerationConfigJson,
if let Some(TokenizerKind::HfTokenizerJson(ref src_url)) = self.tokenizer { "generation_config.json"
if nats::is_nats_url(src_url) { );
let target = target_dir.path().join("tokenizer.json");
nats_client
.object_store_download(Url::parse(src_url)?, &target)
.await?;
self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target.display().to_string()));
}
}
Ok(target_dir) Ok(target_dir)
} }
...@@ -374,10 +366,12 @@ pub trait ModelInfo: Send + Sync { ...@@ -374,10 +366,12 @@ pub trait ModelInfo: Send + Sync {
fn eos_token_ids(&self) -> Vec<TokenIdType>; fn eos_token_ids(&self) -> Vec<TokenIdType>;
/// Maximum position embeddings / max sequence length /// Maximum position embeddings / max sequence length
fn max_position_embeddings(&self) -> usize; /// TODO: This is only used in a single test, no other code. Remove?
fn max_position_embeddings(&self) -> Option<usize>;
/// Vocabulary size /// Vocabulary size
fn vocab_size(&self) -> usize; /// TODO: This is only used in a single test, no other code. Remove?
fn vocab_size(&self) -> Option<usize>;
} }
impl ModelInfoType { impl ModelInfoType {
...@@ -402,36 +396,123 @@ struct HFConfig { ...@@ -402,36 +396,123 @@ struct HFConfig {
model_type: String, model_type: String,
text_config: Option<HFTextConfig>, text_config: Option<HFTextConfig>,
// Sometimes it's inside HFTextConfig, sometimes it's here
eos_token_id: Option<serde_json::Value>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig { struct HFTextConfig {
bos_token_id: TokenIdType, // It can take multiple attempts to load this, so Option
bos_token_id: Option<TokenIdType>,
// We set this once bos_token_id is loaded so we don't have to deal with Option
#[serde(default)]
final_bos_token_id: TokenIdType,
#[serde(with = "either::serde_untagged")] eos_token_id: Option<serde_json::Value>,
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
#[serde(default)]
final_eos_token_ids: Vec<TokenIdType>,
/// max sequence length /// max sequence length
max_position_embeddings: usize, max_position_embeddings: Option<usize>,
/// number of layers in the model /// number of layers in the model
num_hidden_layers: usize, num_hidden_layers: usize,
/// number of attention heads in the model /// number of attention heads in the model
num_attention_heads: usize, num_attention_heads: Option<usize>,
/// Vocabulary size /// Vocabulary size
vocab_size: usize, vocab_size: Option<usize>,
} }
impl HFConfig { impl HFConfig {
async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> { async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let file_pathbuf = PathBuf::from(file);
let contents = std::fs::read_to_string(file)?; let contents = std::fs::read_to_string(file)?;
let mut config: Self = serde_json::from_str(&contents)?; let mut config: Self = serde_json::from_str(&contents)?;
if config.text_config.is_none() { if config.text_config.is_none() {
let text_config: HFTextConfig = serde_json::from_str(&contents)?; let text_config: HFTextConfig = serde_json::from_str(&contents)?;
config.text_config = Some(text_config); config.text_config = Some(text_config);
} }
// Sometimes bos_token_id is in generation_config.json not config.json
let Some(text_config) = config.text_config.as_mut() else {
anyhow::bail!(
"Missing text config fields (model_type, eos_token_ids, etc) in config.json"
);
};
if text_config.bos_token_id.is_none() {
let bos_token_id = crate::file_json_field::<TokenIdType>(
&Path::join(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
"generation_config.json",
),
"bos_token_id",
)
.context(
"missing bos_token_id in generation_config.json and config.json, cannot load",
)?;
text_config.bos_token_id = Some(bos_token_id);
}
// Now that we have it for sure, set it in the non-Option field
let final_bos_token_id = text_config.bos_token_id.take().unwrap();
text_config.final_bos_token_id = final_bos_token_id;
// TODO: refactor this when we switch to per-architecture tokenization
let final_eos_token_ids: Vec<TokenIdType> = config
.eos_token_id
.as_ref()
.or(text_config.eos_token_id.as_ref())
.and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
.map(|n| vec![n as TokenIdType])
} else if v.is_array() {
let arr = v.as_array().unwrap(); // Safety: We just checked
Some(
arr.iter()
.filter_map(|inner_v| {
inner_v
.as_number()
.and_then(|n| n.as_u64())
.map(|n| n as TokenIdType)
})
.collect(),
)
} else {
tracing::error!(
?v,
file,
"eos_token_id is not a number or an array, cannot use"
);
None
}
})
.or_else(|| {
// Maybe it's in generation_config.json
crate::file_json_field(
&Path::join(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
"generation_config.json",
),
"eos_token_id",
)
.inspect_err(
|err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
)
.ok()
})
.ok_or_else(|| {
anyhow::anyhow!(
"missing eos_token_id in config.json and generation_config.json, cannot load"
)
})?;
text_config.final_eos_token_ids = final_eos_token_ids;
Ok(Arc::new(config)) Ok(Arc::new(config))
} }
fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> { fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
...@@ -454,17 +535,22 @@ impl HFConfig { ...@@ -454,17 +535,22 @@ impl HFConfig {
// "general.architecture" // "general.architecture"
model_type: arch, model_type: arch,
text_config: Some(HFTextConfig { text_config: Some(HFTextConfig {
bos_token_id, bos_token_id: None,
eos_token_id: Either::Left(eos_token_id), final_bos_token_id: bos_token_id,
eos_token_id: None,
final_eos_token_ids: vec![eos_token_id],
// "llama.context_length" // "llama.context_length"
max_position_embeddings: model_config_metadata.max_seq_len(), max_position_embeddings: Some(model_config_metadata.max_seq_len()),
// "llama.block_count" // "llama.block_count"
num_hidden_layers, num_hidden_layers,
// "llama.attention.head_count" // "llama.attention.head_count"
num_attention_heads: model_config_metadata.num_attn_heads(), num_attention_heads: Some(model_config_metadata.num_attn_heads()),
// "tokenizer.ggml.tokens".len() // "tokenizer.ggml.tokens".len()
vocab_size, vocab_size: Some(vocab_size),
}), }),
eos_token_id: None,
})) }))
} }
} }
...@@ -475,21 +561,22 @@ impl ModelInfo for HFConfig { ...@@ -475,21 +561,22 @@ impl ModelInfo for HFConfig {
} }
fn bos_token_id(&self) -> TokenIdType { fn bos_token_id(&self) -> TokenIdType {
self.text_config.as_ref().unwrap().bos_token_id self.text_config.as_ref().unwrap().final_bos_token_id
} }
fn eos_token_ids(&self) -> Vec<TokenIdType> { fn eos_token_ids(&self) -> Vec<TokenIdType> {
match &self.text_config.as_ref().unwrap().eos_token_id { self.text_config
Either::Left(eos_token_id) => vec![*eos_token_id], .as_ref()
Either::Right(eos_token_ids) => eos_token_ids.clone(), .unwrap()
} .final_eos_token_ids
.clone()
} }
fn max_position_embeddings(&self) -> usize { fn max_position_embeddings(&self) -> Option<usize> {
self.text_config.as_ref().unwrap().max_position_embeddings self.text_config.as_ref().unwrap().max_position_embeddings
} }
fn vocab_size(&self) -> usize { fn vocab_size(&self) -> Option<usize> {
self.text_config.as_ref().unwrap().vocab_size self.text_config.as_ref().unwrap().vocab_size
} }
} }
......
...@@ -25,8 +25,8 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -25,8 +25,8 @@ async fn test_model_info_from_hf_like_local_repo() {
assert_eq!(info.model_type(), "llama"); assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1); assert_eq!(info.bos_token_id(), 1);
assert_eq!(info.eos_token_ids(), vec![2]); assert_eq!(info.eos_token_ids(), vec![2]);
assert_eq!(info.max_position_embeddings(), 2048); assert_eq!(info.max_position_embeddings(), Some(2048));
assert_eq!(info.vocab_size(), 32000); assert_eq!(info.vocab_size(), Some(32000));
} }
#[tokio::test] #[tokio::test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment