Unverified commit 06a24503, authored by Neal Vaidya, committed by GitHub
Browse files

feat: support chat_template.json as a prompt formatter artifact (#7785)

Closes https://github.com/ai-dynamo/dynamo/issues/7737
parent 52a3ca94
...@@ -104,14 +104,25 @@ impl TokenizerKind { ...@@ -104,14 +104,25 @@ impl TokenizerKind {
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
pub enum PromptFormatterArtifact { pub enum PromptFormatterArtifact {
HfTokenizerConfigJson(CheckedFile), HfTokenizerConfigJson(CheckedFile),
HfChatTemplate { is_custom: bool, file: CheckedFile }, #[serde(rename = "hf_chat_template", alias = "hf_chat_template_jinja")]
HfChatTemplateJinja {
is_custom: bool,
file: CheckedFile,
},
HfChatTemplateJson {
is_custom: bool,
file: CheckedFile,
},
} }
impl PromptFormatterArtifact { impl PromptFormatterArtifact {
pub fn checksum(&self) -> String { pub fn checksum(&self) -> String {
match self { match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(), PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.checksum().to_string(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.checksum().to_string(), PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => {
c.checksum().to_string()
}
} }
} }
...@@ -119,21 +130,24 @@ impl PromptFormatterArtifact { ...@@ -119,21 +130,24 @@ impl PromptFormatterArtifact {
pub fn is_local(&self) -> bool { pub fn is_local(&self) -> bool {
match self { match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.is_local(), PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.is_local(),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.is_local(), PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => c.is_local(),
} }
} }
pub fn update_dir(&mut self, dir: &Path) { pub fn update_dir(&mut self, dir: &Path) {
match self { match self {
PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.update_dir(dir), PromptFormatterArtifact::HfTokenizerConfigJson(c) => c.update_dir(dir),
PromptFormatterArtifact::HfChatTemplate { file: c, .. } => c.update_dir(dir), PromptFormatterArtifact::HfChatTemplateJinja { file: c, .. }
| PromptFormatterArtifact::HfChatTemplateJson { file: c, .. } => c.update_dir(dir),
} }
} }
pub fn is_custom(&self) -> bool { pub fn is_custom(&self) -> bool {
match self { match self {
PromptFormatterArtifact::HfTokenizerConfigJson(_) => false, PromptFormatterArtifact::HfTokenizerConfigJson(_) => false,
PromptFormatterArtifact::HfChatTemplate { is_custom, .. } => *is_custom, PromptFormatterArtifact::HfChatTemplateJinja { is_custom, .. }
| PromptFormatterArtifact::HfChatTemplateJson { is_custom, .. } => *is_custom,
} }
} }
} }
...@@ -553,10 +567,16 @@ impl ModelDeploymentCard { ...@@ -553,10 +567,16 @@ impl ModelDeploymentCard {
// We only "move" the chat template if it came from the repo. If we have a custom template // We only "move" the chat template if it came from the repo. If we have a custom template
// file we cannot download that from HF. // file we cannot download that from HF.
if let Some(PromptFormatterArtifact::HfChatTemplate { if let Some(
PromptFormatterArtifact::HfChatTemplateJinja {
file: src_file, file: src_file,
is_custom, is_custom,
}) = self.chat_template_file.as_mut() }
| PromptFormatterArtifact::HfChatTemplateJson {
file: src_file,
is_custom,
},
) = self.chat_template_file.as_mut()
{ {
if *is_custom { if *is_custom {
tracing::info!( tracing::info!(
...@@ -708,7 +728,7 @@ impl ModelDeploymentCard { ...@@ -708,7 +728,7 @@ impl ModelDeploymentCard {
) )
})?; })?;
Some(PromptFormatterArtifact::HfChatTemplate { Some(PromptFormatterArtifact::HfChatTemplateJinja {
is_custom: custom_template_path.is_some(), is_custom: custom_template_path.is_some(),
file: CheckedFile::from_disk(template_path)?, file: CheckedFile::from_disk(template_path)?,
}) })
...@@ -1001,13 +1021,29 @@ impl PromptFormatterArtifact { ...@@ -1001,13 +1021,29 @@ impl PromptFormatterArtifact {
} }
pub fn chat_template_from_disk(directory: &Path) -> Result<Option<Self>> { pub fn chat_template_from_disk(directory: &Path) -> Result<Option<Self>> {
match CheckedFile::from_disk(directory.join("chat_template.jinja")) { // Try chat_template.jinja first (raw Jinja template)
Ok(f) => Ok(Some(Self::HfChatTemplate { let jinja_path = directory.join("chat_template.jinja");
if jinja_path.exists() {
let f = CheckedFile::from_disk(&jinja_path)
.with_context(|| format!("Failed to load {}", jinja_path.display()))?;
return Ok(Some(Self::HfChatTemplateJinja {
file: f, file: f,
is_custom: false, is_custom: false,
})), }));
Err(_) => Ok(None),
} }
// Try chat_template.json (JSON with "chat_template" key, e.g. Qwen3-Omni)
let json_path = directory.join("chat_template.json");
if json_path.exists() {
let f = CheckedFile::from_disk(&json_path)
.with_context(|| format!("Failed to load {}", json_path.display()))?;
return Ok(Some(Self::HfChatTemplateJson {
file: f,
is_custom: false,
}));
}
Ok(None)
} }
} }
......
...@@ -58,22 +58,54 @@ impl PromptFormatter { ...@@ -58,22 +58,54 @@ impl PromptFormatter {
// stores the chat template in a separate file, we check if the file exists and // stores the chat template in a separate file, we check if the file exists and
// put the chat template into config as normalization. // put the chat template into config as normalization.
// This may also be a custom template provided via CLI flag. // This may also be a custom template provided via CLI flag.
if let Some(PromptFormatterArtifact::HfChatTemplate { match mdc.chat_template_file.as_ref() {
file: checked_file, .. Some(PromptFormatterArtifact::HfChatTemplateJinja {
}) = mdc.chat_template_file.as_ref() file: checked_file,
{ ..
let Some(chat_template_file) = checked_file.path() else { }) => {
let Some(path) = checked_file.path() else {
anyhow::bail!( anyhow::bail!(
"HfChatTemplate for {} is a URL, cannot load", "HfChatTemplateJinja for {} is a URL, cannot load",
mdc.display_name mdc.display_name
); );
}; };
let chat_template = let chat_template = std::fs::read_to_string(path)
std::fs::read_to_string(chat_template_file).with_context(|| { .with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
format!("fs:read_to_string '{}'", chat_template_file.display())
})?;
config.chat_template = Some(ChatTemplateValue(either::Left(chat_template))); config.chat_template = Some(ChatTemplateValue(either::Left(chat_template)));
} }
Some(PromptFormatterArtifact::HfChatTemplateJson {
file: checked_file,
..
}) => {
let Some(path) = checked_file.path() else {
anyhow::bail!(
"HfChatTemplateJson for {} is a URL, cannot load",
mdc.display_name
);
};
let raw = std::fs::read_to_string(path)
.with_context(|| format!("fs:read_to_string '{}'", path.display()))?;
let wrapper: serde_json::Value =
serde_json::from_str(&raw).with_context(|| {
format!("Failed to parse '{}' as JSON", path.display())
})?;
let field = wrapper.get("chat_template").ok_or_else(|| {
anyhow::anyhow!(
"'{}' does not contain a 'chat_template' field",
path.display()
)
})?;
let value = serde_json::from_value::<ChatTemplateValue>(field.clone())
.with_context(|| {
format!(
"Failed to deserialize 'chat_template' in '{}'",
path.display()
)
})?;
config.chat_template = Some(value);
}
_ => {}
}
Self::from_parts( Self::from_parts(
config, config,
mdc.prompt_context mdc.prompt_context
...@@ -82,8 +114,9 @@ impl PromptFormatter { ...@@ -82,8 +114,9 @@ impl PromptFormatter {
mdc.runtime_config.exclude_tools_when_tool_choice_none, mdc.runtime_config.exclude_tools_when_tool_choice_none,
) )
} }
PromptFormatterArtifact::HfChatTemplate { .. } => Err(anyhow::anyhow!( PromptFormatterArtifact::HfChatTemplateJinja { .. }
"prompt_formatter should not have type HfChatTemplate" | PromptFormatterArtifact::HfChatTemplateJson { .. } => Err(anyhow::anyhow!(
"prompt_formatter should not have type HfChatTemplate*"
)), )),
} }
} }
......
{"chat_template": "{%- for message in messages %}{%- if message.role == 'user' %}{{ '<|im_start|>user\n' + message.content + '<|im_end|>\n' }}{%- elif message.role == 'assistant' %}{{ '<|im_start|>assistant\n' + message.content + '<|im_end|>\n' }}{%- endif %}{%- endfor %}{%- if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{%- endif %}"}
...@@ -2,6 +2,5 @@ ...@@ -2,6 +2,5 @@
"bos_token": "<|endoftext|>", "bos_token": "<|endoftext|>",
"eos_token": "<|im_end|>", "eos_token": "<|im_end|>",
"model_max_length": 32768, "model_max_length": 32768,
"tokenizer_class": "Qwen2Tokenizer", "tokenizer_class": "Qwen2Tokenizer"
"chat_template": "{% for message in messages %}{{ message.content }}{% endfor %}"
} }
...@@ -70,3 +70,24 @@ async fn test_model_loads_without_tokenizer_json() { ...@@ -70,3 +70,24 @@ async fn test_model_loads_without_tokenizer_json() {
// Model info should still be loaded // Model info should still be loaded
assert!(mdc.model_info.is_some()); assert!(mdc.model_info.is_some());
} }
/// chat_template.json should be picked up as a fallback when chat_template.jinja
/// does not exist (e.g. Qwen3-Omni). The fixture's tokenizer_config.json has no
/// inline chat_template, so this is the only template source.
///
/// Loads the `mock-no-tokenizer-json` sample model from disk and checks that the
/// resulting card's `chat_template_file` is the `HfChatTemplateJson` variant, is
/// not flagged as custom, and resolves to a local path ending in
/// `chat_template.json`.
#[tokio::test]
async fn test_chat_template_json_fallback() {
let path = "tests/data/sample-models/mock-no-tokenizer-json";
// NOTE(review): the second argument is presumably the optional custom template
// path; `None` means no override is supplied — confirm against `load_from_disk`.
let mdc = ModelDeploymentCard::load_from_disk(path, None).unwrap();
match &mdc.chat_template_file {
Some(PromptFormatterArtifact::HfChatTemplateJson { file, is_custom }) => {
// The artifact must not be flagged as a custom template.
assert!(!is_custom, "Should not be marked as custom template");
// `path()` must yield a local filesystem path for this artifact.
let p = file.path().expect("Should be a local path");
assert!(
p.ends_with("chat_template.json"),
"Expected chat_template.json, got {:?}",
p
);
}
// Any other variant, or `None`, fails the test and prints what was found.
other => panic!("Expected HfChatTemplateJson, got {:?}", other),
}
}
...@@ -76,7 +76,11 @@ async fn maybe_download_model(local_path: &str, model: &str, revision: &str) -> ...@@ -76,7 +76,11 @@ async fn maybe_download_model(local_path: &str, model: &str, revision: &str) ->
let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision)); let repo = Repo::with_revision(String::from(model), RepoType::Model, String::from(revision));
let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"]; let files_to_download = vec!["config.json", "tokenizer.json", "tokenizer_config.json"];
let optional_files = vec!["generation_config.json", "chat_template.jinja"]; let optional_files = vec![
"generation_config.json",
"chat_template.jinja",
"chat_template.json",
];
let repo_builder = api.repo(repo); let repo_builder = api.repo(repo);
let mut downloaded_path = PathBuf::new(); let mut downloaded_path = PathBuf::new();
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment