Unverified Commit cfd12d7f authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat: Support larger Gemma 3 models (#1359)

Publish `generation_config.json` from worker to ingress, as part of Model Deployment Card. That allows ingress to read key fields out of it. Gemma 3 4B+ has some important information that's only in there.
parent 9e8731e5
......@@ -106,8 +106,12 @@ impl ModelWatcher {
tracing::info!(model_name = model_entry.name, "added model");
self.notify_on_model.notify_waiters();
}
Err(e) => {
tracing::error!(%e, "error adding model {}", model_entry.name);
Err(err) => {
tracing::error!(
error = format!("{err:#}"),
"error adding model {}",
model_entry.name
);
}
}
}
......
......@@ -6,6 +6,10 @@
//! The `dynamo.llm` crate is a Rust library that provides a set of traits and types for building
//! distributed LLM inference solutions.
use std::{fs::File, io::BufReader, path::Path};
use anyhow::Context as _;
pub mod backend;
pub mod common;
pub mod disagg_router;
......@@ -30,3 +34,233 @@ pub mod types;
#[cfg(feature = "block-manager")]
pub mod block_manager;
/// Reads a JSON file, extracts one top-level field, and deserializes it into type `T`.
///
/// # Arguments
///
/// * `json_file_path`: Path to the JSON file.
/// * `field_name`: The name of the field to extract from the top-level JSON object.
///
/// # Returns
///
/// The value of `field_name` deserialized as `T`.
///
/// # Errors
///
/// Returns an `anyhow::Error` (with the file path in the message) if:
/// * the file cannot be opened,
/// * the file is not valid JSON,
/// * the JSON root is not an object,
/// * the field is not present in the root object,
/// * the field's value cannot be deserialized into `T`.
///
/// # Type Parameters
///
/// * `T`: The expected type of the field's value. `T` must implement `serde::de::DeserializeOwned`.
pub fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    // 1. Open the file
    let file = File::open(json_file_path)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;
    let reader = BufReader::new(file);

    // 2. Parse the whole document into a generic `serde_json::Value`.
    // Deserializing straight into `T` would require `T` to describe the entire file,
    // whereas we only want a single field out of it.
    let json_data: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    // 3. Ensure the root of the JSON is an object (map)
    let map = json_data.as_object().ok_or_else(|| {
        anyhow::anyhow!("JSON root is not an object in file: {:?}", json_file_path)
    })?;

    // 4. Get the specific field's value
    let field_value = map.get(field_name).ok_or_else(|| {
        anyhow::anyhow!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        )
    })?;

    // 5. Deserialize the field's value into the target type T.
    // `&serde_json::Value` implements `Deserializer`, so we can deserialize from the
    // borrowed value directly; no deep clone is needed and the value remains
    // available for the error message below.
    T::deserialize(field_value).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, field_value, json_file_path
        )
    })
}
#[cfg(test)]
mod file_json_field_tests {
    use super::file_json_field;
    use serde::Deserialize;
    use std::fs::File;
    use std::io::Write;
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    /// Writes `content` into `dir/file_name` and returns the path of the new file.
    /// Panics if the file cannot be created or written.
    fn create_temp_json_file(dir: &Path, file_name: &str, content: &str) -> PathBuf {
        let file_path = dir.join(file_name);
        let mut handle = match File::create(&file_path) {
            Ok(handle) => handle,
            Err(_) => panic!("Failed to create test file: {:?}", file_path),
        };
        if handle.write_all(content.as_bytes()).is_err() {
            panic!("Failed to write to test file: {:?}", file_path);
        }
        file_path
    }

    /// Asserts that `result` is an error whose top-level message contains `expected`.
    fn assert_err_contains<T: std::fmt::Debug>(result: anyhow::Result<T>, expected: &str) {
        assert!(result.is_err());
        let message = result.unwrap_err().to_string();
        assert!(message.contains(expected));
    }

    /// Struct used to check that a nested JSON object can be extracted as a typed value.
    #[derive(Debug, PartialEq, Deserialize)]
    struct MyConfig {
        version: String,
        enabled: bool,
        count: u32,
    }

    #[test]
    fn test_success_basic() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "test_basic.json",
            r#"{ "name": "Rust", "age": 30, "is_active": true }"#,
        );

        assert_eq!(
            file_json_field::<String>(&json_path, "name").unwrap(),
            "Rust"
        );
        assert_eq!(file_json_field::<i32>(&json_path, "age").unwrap(), 30);
        assert!(file_json_field::<bool>(&json_path, "is_active").unwrap());
    }

    #[test]
    fn test_success_custom_struct_field() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "test_struct.json",
            r#"{
"config": {
"version": "1.0.0",
"enabled": true,
"count": 123
},
"other_field": "value"
}"#,
        );

        let expected = MyConfig {
            version: "1.0.0".to_string(),
            enabled: true,
            count: 123,
        };
        let actual: MyConfig = file_json_field(&json_path, "config").unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn test_file_not_found() {
        let scratch = tempdir().unwrap();
        let missing = scratch.path().join("non_existent.json");

        assert_err_contains(
            file_json_field::<String>(&missing, "field"),
            "Failed to open file",
        );
    }

    #[test]
    fn test_invalid_json_syntax() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "invalid.json",
            r#"{ "key": "value", "bad_syntax": }"#, // Malformed JSON
        );

        assert_err_contains(
            file_json_field::<String>(&json_path, "key"),
            "Failed to parse JSON from file",
        );
    }

    #[test]
    fn test_json_root_not_object_array() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "root_array.json",
            r#"[ { "item": 1 }, { "item": 2 } ]"#, // Root is an array
        );

        assert_err_contains(
            file_json_field::<String>(&json_path, "item"),
            "JSON root is not an object",
        );
    }

    #[test]
    fn test_json_root_not_object_primitive() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "root_primitive.json",
            r#""just_a_string""#, // Root is a string
        );

        assert_err_contains(
            file_json_field::<String>(&json_path, "field"),
            "JSON root is not an object",
        );
    }

    #[test]
    fn test_field_not_found() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "missing_field.json",
            r#"{ "existing_field": "hello" }"#,
        );

        assert_err_contains(
            file_json_field::<String>(&json_path, "non_existent_field"),
            "Field 'non_existent_field' not found",
        );
    }

    #[test]
    fn test_field_type_mismatch() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(
            scratch.path(),
            "type_mismatch.json",
            r#"{ "count": "not_an_integer" }"#,
        );

        assert_err_contains(
            file_json_field::<u32>(&json_path, "count"),
            "Failed to deserialize field 'count'",
        );
    }

    #[test]
    fn test_empty_file() {
        let scratch = tempdir().unwrap();
        let json_path = create_temp_json_file(scratch.path(), "empty.json", "");

        assert_err_contains(
            file_json_field::<String>(&json_path, "field"),
            "Failed to parse JSON from file",
        );
    }
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::time::Duration;
pub mod create;
pub mod model;
pub use model::ModelDeploymentCard;
// TODO: Do these network/publish related model deployment card values belong here or in a
// network module?
/// Identify model deployment cards in the key-value store.
/// NOTE(review): presumably used as the root key / bucket name for published cards — confirm
/// against the publishing code.
pub const ROOT_PATH: &str = "mdc";
/// Delete model deployment cards that haven't been re-published after this long
/// (currently five minutes). Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashMap;
use crate::model_card::model::ModelDeploymentCard;
use anyhow::{Context, Result};
use std::fs::{self, File};
use std::io::BufReader;
use std::path::{Path, PathBuf};
use crate::model_card::model::{ModelInfoType, PromptFormatterArtifact, TokenizerKind};
use super::model::GenerationConfig;
impl ModelDeploymentCard {
/// Allow user to override the name we register this model under.
/// Corresponds to vllm's `--served-model-name`.
......@@ -98,6 +84,7 @@ impl ModelDeploymentCard {
service_name: model_name.to_string(),
model_info: Some(ModelInfoType::GGUF(gguf_file.to_path_buf())),
tokenizer: Some(TokenizerKind::from_gguf(gguf_file)?),
gen_config: None, // AFAICT there is no equivalent in a GGUF
prompt_formatter: Some(PromptFormatterArtifact::GGUF(gguf_file.to_path_buf())),
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
......@@ -116,14 +103,14 @@ impl ModelDeploymentCard {
async fn from_repo(repo_id: &str, model_name: &str) -> anyhow::Result<Self> {
// This is usually the right choice
let context_length = file_json_field(
&Path::join(&PathBuf::from(repo_id), "config.json"),
let context_length = crate::file_json_field(
&PathBuf::from(repo_id).join("config.json"),
"max_position_embeddings",
)
// But sometimes this is
.or_else(|_| {
file_json_field(
&Path::join(&PathBuf::from(repo_id), "tokenizer_config.json"),
crate::file_json_field(
&PathBuf::from(repo_id).join("tokenizer_config.json"),
"model_max_length",
)
})
......@@ -135,6 +122,7 @@ impl ModelDeploymentCard {
service_name: model_name.to_string(),
model_info: Some(ModelInfoType::from_repo(repo_id).await?),
tokenizer: Some(TokenizerKind::from_repo(repo_id).await?),
gen_config: GenerationConfig::from_repo(repo_id).await.ok(), // optional
prompt_formatter: PromptFormatterArtifact::from_repo(repo_id).await?,
prompt_context: None, // TODO - auto-detect prompt context
revision: 0,
......@@ -190,37 +178,28 @@ impl TokenizerKind {
}
}
/// Checks if the provided path contains the expected file.
///
/// Delegates to `check_for_files` with a single-element list and returns the value
/// it recorded for `file`.
///
/// # Errors
///
/// Propagates any error from `check_for_files`, and fails with
/// `file <name> not found` when `file` is not among the results.
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
    let mut files = check_for_files(repo_id, vec![file.to_string()]).await?;
    // Build the error lazily: `ok_or` would format and allocate it even on success.
    files
        .remove(file)
        .ok_or_else(|| anyhow::anyhow!("file {} not found", file))
}
impl GenerationConfig {
/// Builds a `GenerationConfig` from the repo at `repo_id`.
///
/// Any failure from `try_is_hf_repo` is wrapped with context naming the repo.
pub async fn from_repo(repo_id: &str) -> Result<Self> {
    let loaded = Self::try_is_hf_repo(repo_id).await;
    loaded.with_context(|| format!("unable to extract generation config from repo {repo_id}"))
}
async fn check_for_files(repo_id: &str, files: Vec<String>) -> Result<HashMap<String, String>> {
let dir_entries =
fs::read_dir(repo_id).with_context(|| format!("Failed to read directory: {}", repo_id))?;
let mut found_files = HashMap::new();
for entry in dir_entries {
let entry =
entry.with_context(|| format!("Failed to read directory entry in {}", repo_id))?;
let path = entry.path();
let file_name = path
.file_name()
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Invalid file name in {}", repo_id))?;
if files.contains(&file_name.to_string()) {
found_files.insert(
file_name.to_string(),
path.to_str()
.ok_or_else(|| anyhow::anyhow!("Invalid path"))?
.to_string(),
);
/// Looks for `generation_config.json` in `repo` and wraps the value returned by
/// `check_for_file` in the `HfGenerationConfigJson` variant.
async fn try_is_hf_repo(repo: &str) -> anyhow::Result<Self> {
    let generation_config = check_for_file(repo, "generation_config.json").await?;
    Ok(Self::HfGenerationConfigJson(generation_config))
}
}
/// Checks if the provided path contains the expected file.
async fn check_for_file(repo_id: &str, file: &str) -> anyhow::Result<String> {
let p = PathBuf::from(repo_id).join(file);
let name = p.display().to_string();
if !p.exists() {
anyhow::bail!("File not found: {name}")
}
Ok(found_files)
Ok(name)
}
/// Checks if the provided path is a valid local repository path.
......@@ -247,58 +226,3 @@ fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
}
Ok(())
}
/// Reads a JSON file, extracts one top-level field, and deserializes it into type `T`.
///
/// # Arguments
///
/// * `json_file_path`: Path to the JSON file.
/// * `field_name`: The name of the field to extract from the top-level JSON object.
///
/// # Returns
///
/// The value of `field_name` deserialized as `T`.
///
/// # Errors
///
/// Returns an `anyhow::Error` (with the file path in the message) if the file cannot be
/// opened, is not valid JSON, does not have an object at its root, does not contain
/// `field_name`, or the field's value cannot be deserialized into `T`.
///
/// # Type Parameters
///
/// * `T`: The expected type of the field's value. `T` must implement `serde::de::DeserializeOwned`.
fn file_json_field<T: serde::de::DeserializeOwned>(
    json_file_path: &Path,
    field_name: &str,
) -> anyhow::Result<T> {
    // 1. Open the file
    let file = File::open(json_file_path)
        .with_context(|| format!("Failed to open file: {:?}", json_file_path))?;
    let reader = BufReader::new(file);

    // 2. Parse the whole document into a generic `serde_json::Value`.
    // Deserializing straight into `T` would require `T` to describe the entire file,
    // whereas we only want a single field out of it.
    let json_data: serde_json::Value = serde_json::from_reader(reader)
        .with_context(|| format!("Failed to parse JSON from file: {:?}", json_file_path))?;

    // 3. Ensure the root of the JSON is an object (map)
    let map = json_data.as_object().ok_or_else(|| {
        anyhow::anyhow!("JSON root is not an object in file: {:?}", json_file_path)
    })?;

    // 4. Get the specific field's value
    let field_value = map.get(field_name).ok_or_else(|| {
        anyhow::anyhow!(
            "Field '{}' not found in JSON file: {:?}",
            field_name,
            json_file_path
        )
    })?;

    // 5. Deserialize the field's value into the target type T.
    // `&serde_json::Value` implements `Deserializer`, so we can deserialize from the
    // borrowed value directly; no deep clone is needed and the value remains
    // available for the error message below.
    T::deserialize(field_value).with_context(|| {
        format!(
            "Failed to deserialize field '{}' (value: {:?}) to the expected type from file: {:?}",
            field_name, field_value, json_file_path
        )
    })
}
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! # Model Deployment Card
//!
......@@ -35,7 +23,6 @@ use anyhow::{Context, Result};
use derive_builder::Builder;
use dynamo_runtime::slug::Slug;
use dynamo_runtime::transports::nats;
use either::Either;
use serde::{Deserialize, Serialize};
use tokenizers::Tokenizer as HfTokenizer;
use url::Url;
......@@ -44,10 +31,6 @@ use crate::gguf::{Content, ContentConfig, ModelConfigLike};
use crate::key_value_store::Versioned;
use crate::protocols::TokenIdType;
/// Delete model deployment cards that haven't been re-published after this long.
/// Cleans up if the worker stopped.
pub const BUCKET_TTL: Duration = Duration::from_secs(5 * 60);
/// If a model deployment card hasn't been refreshed in this much time the worker is likely gone
const CARD_MAX_AGE: chrono::TimeDelta = chrono::TimeDelta::minutes(5);
......@@ -94,6 +77,13 @@ pub enum PromptContextMixin {
Llama3DateTime,
}
/// Where a model's generation config comes from.
///
/// Serialized with snake_case variant names.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")]
pub enum GenerationConfig {
    /// Location of a Hugging Face `generation_config.json`, kept as a string.
    /// NOTE(review): appears to hold either a local file path or a `nats://` URL
    /// depending on whether the card has been uploaded — confirm against the
    /// upload/download code.
    HfGenerationConfigJson(String),
    /// NOTE(review): presumably the path of a GGUF file; no code visible here
    /// constructs this variant — confirm it is used.
    GGUF(PathBuf),
}
#[derive(Serialize, Deserialize, Clone, Debug, Builder, Default)]
pub struct ModelDeploymentCard {
/// Human readable model name, e.g. "Meta Llama 3.1 8B Instruct"
......@@ -113,6 +103,10 @@ pub struct ModelDeploymentCard {
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_formatter: Option<PromptFormatterArtifact>,
/// Generation config - default sampling params
#[serde(default, skip_serializing_if = "Option::is_none")]
pub gen_config: Option<GenerationConfig>,
/// Prompt Formatter Config
#[serde(default, skip_serializing_if = "Option::is_none")]
pub prompt_context: Option<Vec<PromptContextMixin>>,
......@@ -244,38 +238,39 @@ impl ModelDeploymentCard {
"Uploading model deployment card fields to NATS"
);
if let Some(ModelInfoType::HfConfigJson(ref src_file)) = self.model_info {
if !nats::is_nats_url(src_file) {
let target = format!("nats://{nats_addr}/{bucket_name}/config.json");
macro_rules! nats_upload {
($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_file)) = $field.take() {
if !nats::is_nats_url(&src_file) {
let target = format!("nats://{nats_addr}/{bucket_name}/{}", $filename);
nats_client
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
.object_store_upload(
&std::path::PathBuf::from(&src_file),
url::Url::parse(&target)?,
)
.await?;
self.model_info = Some(ModelInfoType::HfConfigJson(target));
}
$field = Some($enum_variant(target));
}
if let Some(PromptFormatterArtifact::HfTokenizerConfigJson(ref src_file)) =
self.prompt_formatter
{
if !nats::is_nats_url(src_file) {
let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer_config.json");
nats_client
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
.await?;
self.prompt_formatter =
Some(PromptFormatterArtifact::HfTokenizerConfigJson(target));
}
};
}
if let Some(TokenizerKind::HfTokenizerJson(ref src_file)) = self.tokenizer {
if !nats::is_nats_url(src_file) {
let target = format!("nats://{nats_addr}/{bucket_name}/tokenizer.json");
nats_client
.object_store_upload(&PathBuf::from(src_file), Url::parse(&target)?)
.await?;
self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target));
}
}
nats_upload!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_upload!(
self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_upload!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
"tokenizer.json"
);
nats_upload!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
Ok(())
}
......@@ -295,39 +290,36 @@ impl ModelDeploymentCard {
"Downloading model deployment card fields from NATS"
);
if let Some(ModelInfoType::HfConfigJson(ref src_url)) = self.model_info {
if nats::is_nats_url(src_url) {
let target = target_dir.path().join("config.json");
macro_rules! nats_download {
($field:expr, $enum_variant:path, $filename:literal) => {
if let Some($enum_variant(src_url)) = $field.take() {
if nats::is_nats_url(&src_url) {
let target = target_dir.path().join($filename);
nats_client
.object_store_download(Url::parse(src_url)?, &target)
.object_store_download(Url::parse(&src_url)?, &target)
.await?;
self.model_info = Some(ModelInfoType::HfConfigJson(target.display().to_string()));
}
$field = Some($enum_variant(target.display().to_string()));
}
if let Some(PromptFormatterArtifact::HfTokenizerConfigJson(ref src_url)) =
self.prompt_formatter
{
if nats::is_nats_url(src_url) {
let target = target_dir.path().join("tokenizer_config.json");
nats_client
.object_store_download(Url::parse(src_url)?, &target)
.await?;
self.prompt_formatter = Some(PromptFormatterArtifact::HfTokenizerConfigJson(
target.display().to_string(),
));
}
};
}
if let Some(TokenizerKind::HfTokenizerJson(ref src_url)) = self.tokenizer {
if nats::is_nats_url(src_url) {
let target = target_dir.path().join("tokenizer.json");
nats_client
.object_store_download(Url::parse(src_url)?, &target)
.await?;
self.tokenizer = Some(TokenizerKind::HfTokenizerJson(target.display().to_string()));
}
}
nats_download!(self.model_info, ModelInfoType::HfConfigJson, "config.json");
nats_download!(
self.prompt_formatter,
PromptFormatterArtifact::HfTokenizerConfigJson,
"tokenizer_config.json"
);
nats_download!(
self.tokenizer,
TokenizerKind::HfTokenizerJson,
"tokenizer.json"
);
nats_download!(
self.gen_config,
GenerationConfig::HfGenerationConfigJson,
"generation_config.json"
);
Ok(target_dir)
}
......@@ -374,10 +366,12 @@ pub trait ModelInfo: Send + Sync {
fn eos_token_ids(&self) -> Vec<TokenIdType>;
/// Maximum position embeddings / max sequence length
fn max_position_embeddings(&self) -> usize;
/// TODO: This is only used in a single test, no other code. Remove?
fn max_position_embeddings(&self) -> Option<usize>;
/// Vocabulary size
fn vocab_size(&self) -> usize;
/// TODO: This is only used in a single test, no other code. Remove?
fn vocab_size(&self) -> Option<usize>;
}
impl ModelInfoType {
......@@ -402,36 +396,123 @@ struct HFConfig {
model_type: String,
text_config: Option<HFTextConfig>,
// Sometimes it's inside HFTextConfig, sometimes it's here
eos_token_id: Option<serde_json::Value>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig {
bos_token_id: TokenIdType,
// It can take multiple attempts to load this, so Option
bos_token_id: Option<TokenIdType>,
// We set this once bos_token_id is loaded so we don't have to deal with Option
#[serde(default)]
final_bos_token_id: TokenIdType,
eos_token_id: Option<serde_json::Value>,
#[serde(with = "either::serde_untagged")]
eos_token_id: Either<TokenIdType, Vec<TokenIdType>>,
#[serde(default)]
final_eos_token_ids: Vec<TokenIdType>,
/// max sequence length
max_position_embeddings: usize,
max_position_embeddings: Option<usize>,
/// number of layers in the model
num_hidden_layers: usize,
/// number of attention heads in the model
num_attention_heads: usize,
num_attention_heads: Option<usize>,
/// Vocabulary size
vocab_size: usize,
vocab_size: Option<usize>,
}
impl HFConfig {
async fn from_json_file(file: &str) -> Result<Arc<dyn ModelInfo>> {
let file_pathbuf = PathBuf::from(file);
let contents = std::fs::read_to_string(file)?;
let mut config: Self = serde_json::from_str(&contents)?;
if config.text_config.is_none() {
let text_config: HFTextConfig = serde_json::from_str(&contents)?;
config.text_config = Some(text_config);
}
// Sometimes bos_token_id is in generation_config.json not config.json
let Some(text_config) = config.text_config.as_mut() else {
anyhow::bail!(
"Missing text config fields (model_type, eos_token_ids, etc) in config.json"
);
};
if text_config.bos_token_id.is_none() {
let bos_token_id = crate::file_json_field::<TokenIdType>(
&Path::join(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
"generation_config.json",
),
"bos_token_id",
)
.context(
"missing bos_token_id in generation_config.json and config.json, cannot load",
)?;
text_config.bos_token_id = Some(bos_token_id);
}
// Now that we have it for sure, set it in the non-Option field
let final_bos_token_id = text_config.bos_token_id.take().unwrap();
text_config.final_bos_token_id = final_bos_token_id;
// TODO: refactor this when we switch to per-architecture tokenization
let final_eos_token_ids: Vec<TokenIdType> = config
.eos_token_id
.as_ref()
.or(text_config.eos_token_id.as_ref())
.and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
.map(|n| vec![n as TokenIdType])
} else if v.is_array() {
let arr = v.as_array().unwrap(); // Safety: We just checked
Some(
arr.iter()
.filter_map(|inner_v| {
inner_v
.as_number()
.and_then(|n| n.as_u64())
.map(|n| n as TokenIdType)
})
.collect(),
)
} else {
tracing::error!(
?v,
file,
"eos_token_id is not a number or an array, cannot use"
);
None
}
})
.or_else(|| {
// Maybe it's in generation_config.json
crate::file_json_field(
&Path::join(
file_pathbuf.parent().unwrap_or(&PathBuf::from("")),
"generation_config.json",
),
"eos_token_id",
)
.inspect_err(
|err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
)
.ok()
})
.ok_or_else(|| {
anyhow::anyhow!(
"missing eos_token_id in config.json and generation_config.json, cannot load"
)
})?;
text_config.final_eos_token_ids = final_eos_token_ids;
Ok(Arc::new(config))
}
fn from_gguf(gguf_file: &Path) -> Result<Arc<dyn ModelInfo>> {
......@@ -454,17 +535,22 @@ impl HFConfig {
// "general.architecture"
model_type: arch,
text_config: Some(HFTextConfig {
bos_token_id,
eos_token_id: Either::Left(eos_token_id),
bos_token_id: None,
final_bos_token_id: bos_token_id,
eos_token_id: None,
final_eos_token_ids: vec![eos_token_id],
// "llama.context_length"
max_position_embeddings: model_config_metadata.max_seq_len(),
max_position_embeddings: Some(model_config_metadata.max_seq_len()),
// "llama.block_count"
num_hidden_layers,
// "llama.attention.head_count"
num_attention_heads: model_config_metadata.num_attn_heads(),
num_attention_heads: Some(model_config_metadata.num_attn_heads()),
// "tokenizer.ggml.tokens".len()
vocab_size,
vocab_size: Some(vocab_size),
}),
eos_token_id: None,
}))
}
}
......@@ -475,21 +561,22 @@ impl ModelInfo for HFConfig {
}
fn bos_token_id(&self) -> TokenIdType {
self.text_config.as_ref().unwrap().bos_token_id
self.text_config.as_ref().unwrap().final_bos_token_id
}
fn eos_token_ids(&self) -> Vec<TokenIdType> {
match &self.text_config.as_ref().unwrap().eos_token_id {
Either::Left(eos_token_id) => vec![*eos_token_id],
Either::Right(eos_token_ids) => eos_token_ids.clone(),
}
self.text_config
.as_ref()
.unwrap()
.final_eos_token_ids
.clone()
}
fn max_position_embeddings(&self) -> usize {
fn max_position_embeddings(&self) -> Option<usize> {
self.text_config.as_ref().unwrap().max_position_embeddings
}
fn vocab_size(&self) -> usize {
fn vocab_size(&self) -> Option<usize> {
self.text_config.as_ref().unwrap().vocab_size
}
}
......
......@@ -25,8 +25,8 @@ async fn test_model_info_from_hf_like_local_repo() {
assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1);
assert_eq!(info.eos_token_ids(), vec![2]);
assert_eq!(info.max_position_embeddings(), 2048);
assert_eq!(info.vocab_size(), 32000);
assert_eq!(info.max_position_embeddings(), Some(2048));
assert_eq!(info.vocab_size(), Some(32000));
}
#[tokio::test]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment