Unverified Commit 06a24503 authored by Neal Vaidya's avatar Neal Vaidya Committed by GitHub
Browse files

feat: support chat_template.json as a prompt formatter artifact (#7785)

Closes https://github.com/ai-dynamo/dynamo/issues/7737
parent 52a3ca94
......@@ -104,14 +104,25 @@ impl TokenizerKind {
#[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact {
HfTokenizerConfigJson(CheckedFile),
HfChatTemplate { is_custom: bool, file: CheckedFile },
#[serde(rename = "hf_chat_template", alias = "hf_chat_template_jinja")]
HfChatTemplateJinja {
is_custom: bool,
file: CheckedFile,
},
HfChatTemplateJson {
is_custom: bool,
file: CheckedFile,
},
}
impl PromptFormatterArtifact {
/// Returns the checksum of the file backing this artifact, rendered as a `String`.
///
/// Every variant listed here wraps a `CheckedFile`; each arm forwards to that
/// file's `checksum()` and converts the result with `to_string()`.
pub fn checksum(&self) -> String {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => {
c.checksum().to_string()
}
}
}
......@@ -119,21 +130,24 @@ impl PromptFormatterArtifact {
/// Reports whether the file backing this artifact is local.
///
/// The answer is delegated to `is_local()` on the wrapped `CheckedFile`,
/// regardless of which variant holds it.
pub fn is_local(&self) -> bool {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.is_local(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.is_local(),
PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => c.is_local(),
}
}
/// Forwards `dir` to `update_dir` on the wrapped `CheckedFile`, mutating it in place.
///
/// NOTE(review): presumably this re-points the file at a new parent directory;
/// confirm against `CheckedFile::update_dir`, which is not visible here.
pub fn update_dir(&mut self, dir: &Path) {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.update_dir(dir),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.update_dir(dir),
PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => c.update_dir(dir),
}
}
/// Returns the `is_custom` flag carried by the chat-template variants.
///
/// `HfTokenizerConfigJson` has no such flag and always reports `false`.
pub fn is_custom(&self) -> bool {
match self {
PromptFormatterArtifact::HfTokenizerConfigJson(_) => false,
PromptFormatterArtifact::HfChatTemplate { is_custom, .. } => *is_custom,
PromptFormatterArtifact::HfChatTemplateJinja { is_custom, .. }
| PromptFormatterArtifact::HfChatTemplateJson { is_custom, .. } => *is_custom,
}
}
}
......@@ -553,10 +567,16 @@ impl ModelDeploymentCard {
// We only "move" the chat template if it came from the repo. If we have a custom template
// file we cannot download that from HF.
if let Some(PromptFormatterArtifact::HfChatTemplate {
if let Some(
PromptFormatterArtifact::HfChatTemplateJinja {
file: src_file,
is_custom,
}) = self.chat_template_file.as_mut()
}
| PromptFormatterArtifact::HfChatTemplateJson {
file: src_file,
is_custom,
},
) = self.chat_template_file.as_mut()
{
if *is_custom {
tracing::info!(
......@@ -708,7 +728,7 @@ impl ModelDeploymentCard {
)
})?;
Some(PromptFormatterArtifact::HfChatTemplate {
Some(PromptFormatterArtifact::HfChatTemplateJinja {
is_custom: custom_template_path.is_some(),
file: CheckedFile::from_disk(template_path)?,
})
......@@ -1001,13 +1021,29 @@ impl PromptFormatterArtifact {
}
pub fn chat_template_from_disk(directory: &Path) -> Result<Option<Self>> {
match CheckedFile::from_disk(directory.join("chat_template.jinja")) {
Ok(f) => Ok(Some(Self::HfChatTemplate {
// Try chat_template.jinja first (raw Jinja template)
let jinja_path = directory.join("chat_template.jinja");
if jinja_path.exists() {
let f = CheckedFile::from_disk(&jinja_path)
.with_context(|| format!("Failed to load {}", jinja_path.display()))?;
return Ok(Some(Self::HfChatTemplateJinja {
file: f,
is_custom: false,
})),
Err(_) => Ok(None),
}));
}
// Try chat_template.json (JSON with "chat_template" key, e.g. Qwen3-Omni)
let json_path = directory.join("chat_template.json");
if json_path.exists() {
let f = CheckedFile::from_disk(&json_path)
.with_context(|| format!("Failed to load {}", json_path.display()))?;
return Ok(Some(Self::HfChatTemplateJson {
file: f,
is_custom: false,
}));
}
Ok(None)
}
}
......
......@@ -58,22 +58,54 @@ impl PromptFormatter {
// stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate {
file: checked_file, ..
}) = mdc.chat_template_file.as_ref()
{
let Some(chat_template_file) = checked_file.path() else {
match mdc.chat_template_file.as_ref() {
Some(PromptFormatterArtifact::HfChatTemplateJinja {
file: checked_file,
..
}) => {
let Some(path) = checked_file.path() else {
anyhow::bail!(
"HfChatTemplate for {} is a URL, cannot load",
"HfChatTemplateJinja for {} is a URL, cannot load",
mdc.display_name
);
};
let chat_template =
std::fs::read_to_string(chat_template_file).with_context(|| {
format!("fs:read_to_string '{}'", chat_template_file.display())
})?;
let chat_template = std::fs::read_to_string(path)
.with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
}
Some(PromptFormatterArtifact::HfChatTemplateJson {
file: checked_file,
..
}) => {
let Some(path) = checked_file.path() else {
anyhow::bail!(
"HfChatTemplateJson for {} is a URL, cannot load",
mdc.display_name
);
};
let raw = std::fs::read_to_string(path)
.with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
let wrapper: serde_json::Value =
serde_json::from_str(&raw).with_context(|| {
format!("Failed to parse '{}' as JSON", path.display())
})?;
let field = wrapper.get("chat_template").ok_or_else(|| {
anyhow::anyhow!(
"'{}' does not contain a 'chat_template' field",
path.display()
)
})?;
let value = serde_json::from_value::<ChatTemplateValue>(field.clone())
.with_context(|| {
format!(
"Failed to deserialize 'chat_template' in '{}'",
path.display()
)
})?;
config.chat_template = Some(value);
}
_ => {}
}
Self::from_parts(
config,
mdc.prompt_context
......@@ -82,8 +114,9 @@ impl PromptFormatter {
mdc.runtime_config.exclude_tools_when_tool_choice_none,
)
}
PromptFormatterArtifact::HfChatTemplate { .. } => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate"
PromptFormatterArtifact::HfChatTemplateJinja { .. }
| PromptFormatterArtifact::HfChatTemplateJson { .. } => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate*"
)),
}
}
......
{"chat_template": "{%- for message in messages %}{%- if message.role == 'user' %}{{ '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}{%- elif message.role == 'assistant' %}{{ '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{%- endif %}"}
......@@ -2,6 +2,5 @@
"bos_token": "<|endoftext|>",
"eos_token": "<|im_end|>",
"model_max_length": 32768,
"tokenizer_class": "Qwen2Tokenizer",
"chat_template": "{% for message in messages %}{{ message.content }}{% endfor %}"
"tokenizer_class": "Qwen2Tokenizer"
}
......@@ -70,3 +70,24 @@ async fn test_model_loads_without_tokenizer_json() {
// Model info should still be loaded
assert!(mdc.model_info.is_some());
}
/// When a model directory has no `chat_template.jinja`, loading the card must
/// fall back to `chat_template.json` (the layout used by e.g. Qwen3-Omni).
/// The fixture's tokenizer_config.json carries no inline chat_template, so the
/// JSON file is the only template source available.
#[tokio::test]
async fn test_chat_template_json_fallback() {
    let model_dir = "tests/data/sample-models/mock-no-tokenizer-json";
    let card = ModelDeploymentCard::load_from_disk(model_dir, None).unwrap();

    // Anything other than the JSON-backed variant is a failure.
    let Some(PromptFormatterArtifact::HfChatTemplateJson { file, is_custom }) =
        &card.chat_template_file
    else {
        panic!(
            "Expected HfChatTemplateJson, got {:?}",
            &card.chat_template_file
        );
    };

    // A template discovered in the model directory is never a custom one.
    assert!(!is_custom, "Should not be marked as custom template");

    let template_path = file.path().expect("Should be a local path");
    assert!(
        template_path.ends_with("chat_template.json"),
        "Expected chat_template.json, got {:?}",
        template_path
    );
}
......@@ -76,7 +76,11 @@ async fn maybe_download_model(local_path: &str, model: &str, revision: &str) ->
let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision));
let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"];
let optional_files = vec!["generation_config.json", "chat_template.jinja"];
let optional_files = vec![
"generation_config.json",
"chat_template.jinja",
"chat_template.json",
];
let repo_builder = api.repo(repo);
let mut downloaded_path = PathBuf::new();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment