Unverified Commit dad42f42 authored by Graham King's avatar Graham King Committed by GitHub
Browse files

feat(model_card): Make bos_token_id optional (#5394)


Signed-off-by: default avatarGraham King <grahamk@nvidia.com>
parent 8c129ed4
...@@ -684,8 +684,8 @@ pub trait ModelInfo: Send + Sync { ...@@ -684,8 +684,8 @@ pub trait ModelInfo: Send + Sync {
/// Model type /// Model type
fn model_type(&self) -> String; fn model_type(&self) -> String;
/// Token ID for the beginning of sequence /// Token ID for the beginning of sequence (optional - not all models have it)
fn bos_token_id(&self) -> TokenIdType; fn bos_token_id(&self) -> Option<TokenIdType>;
/// Token ID for the end of sequence /// Token ID for the end of sequence
fn eos_token_ids(&self) -> Vec<TokenIdType>; fn eos_token_ids(&self) -> Vec<TokenIdType>;
...@@ -729,13 +729,9 @@ struct HFConfig { ...@@ -729,13 +729,9 @@ struct HFConfig {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
struct HFTextConfig { struct HFTextConfig {
// It can take multiple attempts to load this, so Option // Optional - not all models have a bos_token_id
bos_token_id: Option<TokenIdType>, bos_token_id: Option<TokenIdType>,
// We set this once bos_token_id is loaded so we don't have to deal with Option
#[serde(default)]
final_bos_token_id: TokenIdType,
eos_token_id: Option<serde_json::Value>, eos_token_id: Option<serde_json::Value>,
#[serde(default)] #[serde(default)]
...@@ -771,7 +767,6 @@ impl HFConfig { ...@@ -771,7 +767,6 @@ impl HFConfig {
config.text_config = Some(text_config); config.text_config = Some(text_config);
} }
// Sometimes bos_token_id is in generation_config.json not config.json
let Some(text_config) = config.text_config.as_mut() else { let Some(text_config) = config.text_config.as_mut() else {
anyhow::bail!( anyhow::bail!(
"Missing text config fields (model_type, eos_token_ids, etc) in config.json" "Missing text config fields (model_type, eos_token_ids, etc) in config.json"
...@@ -782,16 +777,13 @@ impl HFConfig { ...@@ -782,16 +777,13 @@ impl HFConfig {
.parent() .parent()
.unwrap_or_else(|| Path::new("")) .unwrap_or_else(|| Path::new(""))
.join("generation_config.json"); .join("generation_config.json");
// bos_token_id is optional - not all models have it
// Try to load from generation_config.json if not in config.json
if text_config.bos_token_id.is_none() { if text_config.bos_token_id.is_none() {
let bos_token_id = crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id") text_config.bos_token_id =
.context( crate::file_json_field::<TokenIdType>(&gencfg_path, "bos_token_id").ok();
"missing bos_token_id in generation_config.json and config.json, cannot load",
)?;
text_config.bos_token_id = Some(bos_token_id);
} }
// Now that we have it for sure, set it in the non-Option field
let final_bos_token_id = text_config.bos_token_id.take().unwrap();
text_config.final_bos_token_id = final_bos_token_id;
// TODO: refactor this when we switch to per-architecture tokenization // TODO: refactor this when we switch to per-architecture tokenization
// eos_token_id can appear in multiple places, and as suggested by HuggingFace // eos_token_id can appear in multiple places, and as suggested by HuggingFace
...@@ -867,8 +859,8 @@ impl ModelInfo for HFConfig { ...@@ -867,8 +859,8 @@ impl ModelInfo for HFConfig {
self.model_type.clone() self.model_type.clone()
} }
fn bos_token_id(&self) -> TokenIdType { fn bos_token_id(&self) -> Option<TokenIdType> {
self.text_config.as_ref().unwrap().final_bos_token_id self.text_config.as_ref().and_then(|tc| tc.bos_token_id)
} }
fn eos_token_ids(&self) -> Vec<TokenIdType> { fn eos_token_ids(&self) -> Vec<TokenIdType> {
...@@ -983,7 +975,7 @@ mod tests { ...@@ -983,7 +975,7 @@ mod tests {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json"); .join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
let config = HFConfig::from_json_file(&config_file)?; let config = HFConfig::from_json_file(&config_file)?;
assert_eq!(config.bos_token_id(), 128000); assert_eq!(config.bos_token_id(), Some(128000));
// eos_token_ids can be in any order as long as the set is correct // eos_token_ids can be in any order as long as the set is correct
let eos_token_id_set: HashSet<_> = config.eos_token_ids().iter().cloned().collect(); let eos_token_id_set: HashSet<_> = config.eos_token_ids().iter().cloned().collect();
assert_eq!(eos_token_id_set, vec![128001, 128009].into_iter().collect()); assert_eq!(eos_token_id_set, vec![128001, 128009].into_iter().collect());
...@@ -995,7 +987,7 @@ mod tests { ...@@ -995,7 +987,7 @@ mod tests {
let config_file = Path::new(env!("CARGO_MANIFEST_DIR")) let config_file = Path::new(env!("CARGO_MANIFEST_DIR"))
.join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json"); .join("tests/data/sample-models/Llama-4-Scout-17B-16E-Instruct/config.json");
let config = HFConfig::from_json_file(&config_file)?; let config = HFConfig::from_json_file(&config_file)?;
assert_eq!(config.bos_token_id(), 200000); assert_eq!(config.bos_token_id(), Some(200000));
Ok(()) Ok(())
} }
......
...@@ -11,7 +11,7 @@ async fn test_model_info_from_hf_like_local_repo() { ...@@ -11,7 +11,7 @@ async fn test_model_info_from_hf_like_local_repo() {
let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap(); let mdc = ModelDeploymentCard::load_from_disk(HF_PATH, None).unwrap();
let info = mdc.model_info.unwrap().get_model_info().unwrap(); let info = mdc.model_info.unwrap().get_model_info().unwrap();
assert_eq!(info.model_type(), "llama"); assert_eq!(info.model_type(), "llama");
assert_eq!(info.bos_token_id(), 1); assert_eq!(info.bos_token_id(), Some(1));
assert_eq!(info.eos_token_ids(), vec![2]); assert_eq!(info.eos_token_ids(), vec![2]);
assert_eq!(info.max_position_embeddings(), Some(2048)); assert_eq!(info.max_position_embeddings(), Some(2048));
assert_eq!(info.vocab_size(), Some(32000)); assert_eq!(info.vocab_size(), Some(32000));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment