Unverified commit 4ace4c85, authored by GuanLuo, committed by GitHub
Browse files

fix: ModelDeploymentCard obtains full set of eos_token_ids by taking union...


fix: ModelDeploymentCard obtains full set of eos_token_ids by taking union from different files (#3192)
Signed-off-by: Guan Luo <gluo@nvidia.com>
Signed-off-by: Guan Luo <41310872+GuanLuo@users.noreply.github.com>
parent af9ae793
......@@ -248,6 +248,11 @@ class HandlerBase:
if min_tokens:
sampling_params.min_tokens = min_tokens
stop_token_ids = request["stop_conditions"].get("stop_token_ids_hidden")
if stop_token_ids:
existing = sampling_params.stop_token_ids or []
sampling_params.stop_token_ids = list(set(existing).union(stop_token_ids))
# TODO: Instead of True, we should use streaming from the request.
# However, currently dynamo run does not send streaming in the request.
streaming = (
......
......@@ -94,6 +94,13 @@ def build_sampling_params(
if key == "stop":
continue
setattr(sampling_params, key, value)
if (
key == "stop_token_ids_hidden"
and value is not None
and hasattr(sampling_params, "stop_token_ids")
):
existing = sampling_params.stop_token_ids or []
sampling_params.stop_token_ids = list(set(existing).union(value))
# If max_tokens wasn't provided (None or missing), compute a dynamic default
provided_max_tokens = request.get("stop_conditions", {}).get("max_tokens", None)
......
......@@ -670,44 +670,18 @@ impl HFConfig {
text_config.final_bos_token_id = final_bos_token_id;
// TODO: refactor this when we switch to per-architecture tokenization
let final_eos_token_ids: Vec<TokenIdType> = config
.eos_token_id
.as_ref()
.or(text_config.eos_token_id.as_ref())
.and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
.map(|n| vec![n as TokenIdType])
} else if v.is_array() {
let arr = v.as_array().unwrap(); // Safety: We just checked
Some(
arr.iter()
.filter_map(|inner_v| {
inner_v
.as_number()
.and_then(|n| n.as_u64())
.map(|n| n as TokenIdType)
})
.collect(),
)
} else {
tracing::error!(
?v,
path = %file_path.display(),
"eos_token_id is not a number or an array, cannot use"
);
None
}
})
.or_else(|| {
// Maybe it's in generation_config.json
// eos_token_id can appear in multiple places; as suggested by the HuggingFace
// community, the priority should be:
// 1. generation_config.json;
// 2. config.json, or text_config field in config.json.
// https://github.com/huggingface/transformers/issues/25395#issuecomment-1671863257
let final_eos_token_ids: Vec<TokenIdType> = {
// Firstly check the generation_config.json
crate::file_json_field::<serde_json::Value>(&gencfg_path, "eos_token_id")
.inspect_err(
|err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
)
.ok()
.and_then(|v| {
.ok().and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
......@@ -728,6 +702,30 @@ impl HFConfig {
None
}
})
}.or_else(|| {
// Check config.json and text_config
config
.eos_token_id
.as_ref()
.or(text_config.eos_token_id.as_ref())
.and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
.map(|n| vec![n as TokenIdType])
} else {
serde_json::from_value(v.clone())
.map(Some)
.unwrap_or_else(|err| {
tracing::error!(
?v,
path = %file_path.display(),
"eos_token_id is not a number or an array, cannot deserialize: {err}",
);
None
})
}
})
})
.ok_or_else(|| {
anyhow::anyhow!(
......@@ -850,6 +848,7 @@ fn check_valid_local_repo_path(path: impl AsRef<Path>) -> Result<()> {
#[cfg(test)]
mod tests {
use super::HFConfig;
use std::collections::HashSet;
use std::path::Path;
#[test]
......@@ -858,6 +857,9 @@ mod tests {
.join("tests/data/sample-models/mock-llama-3.1-8b-instruct/config.json");
let config = HFConfig::from_json_file(&config_file)?;
assert_eq!(config.bos_token_id(), 128000);
// eos_token_ids can be in any order as long as the set is correct
let eos_token_id_set: HashSet<_> = config.eos_token_ids().iter().cloned().collect();
assert_eq!(eos_token_id_set, vec![128001, 128009].into_iter().collect());
Ok(())
}
......
......@@ -5,6 +5,7 @@ use anyhow::{Ok, Result};
use dynamo_runtime::config::environment_names::model::huggingface as env_hf;
use dynamo_llm::model_card::{ModelDeploymentCard, PromptContextMixin};
use dynamo_llm::preprocessor::OpenAIPreprocessor;
use dynamo_llm::preprocessor::prompt::PromptFormatter;
use dynamo_llm::protocols::openai::chat_completions::NvCreateChatCompletionRequest;
use serde::{Deserialize, Serialize};
......@@ -79,12 +80,21 @@ async fn maybe_download_model(local_path: &str, model: &str, revision: &str) ->
let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision));
let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"];
let optional_files = vec!["generation_config.json", "chat_template.jinja"];
let repo_builder = api.repo(repo);
let mut downloaded_path = PathBuf::new();
for file in &files_to_download {
downloaded_path = repo_builder.get(file).await.unwrap();
}
for file in &optional_files {
if let Err(e) = repo_builder.get(file).await {
println!(
"Failed to download optional file {} for model {}: {}",
file, model, e
);
}
}
downloaded_path.parent().unwrap().display().to_string()
}
......@@ -496,6 +506,53 @@ async fn test_multi_turn_with_continuation() {
});
}
pub mod openai_preprocessor_tests {
    // Re-export everything from the parent module so its items are in scope here.
    pub use super::*;
    use std::collections::HashSet;

    /// Preprocesses a single chat message and checks that the hidden stop token ids
    /// on the resulting request match the expected ids, compared as a set.
    ///
    /// The test prints a notice and returns early when `get_hf_token()` yields an error.
    #[tokio::test]
    async fn test_stop_condition() {
        if let Err(token_err) = get_hf_token() {
            println!("HF_TOKEN is not set, skipping test: {}", token_err);
            return;
        }

        let card = make_mdc_from_repo(
            "tests/data/sample-models",
            "openai/gpt-oss-120b",
            "b5c939de8f754692c1647ca79fbf85e8c1e70f8a",
            Some(vec![PromptContextMixin::OaiChat]),
        )
        .await;
        let preprocessor = OpenAIPreprocessor::new(card.clone()).unwrap();
        let chat_request = Request::from(SINGLE_CHAT_MESSAGE, None, None, card.slug().to_string());

        let preprocessed = preprocessor
            .preprocess_request(&chat_request)
            .await
            .unwrap()
            .0;

        // The hidden stop ids must be populated before their contents are compared.
        let hidden_stop_ids = preprocessed.stop_conditions.stop_token_ids_hidden;
        assert!(hidden_stop_ids.is_some());

        // Order is irrelevant: both sides are compared as sets.
        let actual: HashSet<_> = hidden_stop_ids.unwrap().iter().cloned().collect();
        let expected: HashSet<_> = vec![200002, 199999, 200012].into_iter().collect();
        assert_eq!(actual, expected);
    }
}
// Helper to build message with media chunks (single or mixed types)
fn build_message(text: &str, chunks: &[(&str, usize)]) -> String {
let mut content_parts = vec![format!(r#"{{"type": "text", "text": "{}"}}"#, text)];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment