Unverified Commit f395b641 authored by Keiven C's avatar Keiven C Committed by GitHub
Browse files

fix: support multimodal models with non-standard Jinja2 tags (#4379)


Signed-off-by: default avatarKeiven Chang <keivenchang@users.noreply.github.com>
parent 4215b59a
......@@ -621,7 +621,8 @@ struct HFTextConfig {
max_position_embeddings: Option<usize>,
/// number of layers in the model
num_hidden_layers: usize,
/// Optional because some multimodal models (e.g., LLaVA) don't include this in text_config
num_hidden_layers: Option<usize>,
/// number of attention heads in the model
num_attention_heads: Option<usize>,
......@@ -701,11 +702,32 @@ impl HFConfig {
})
.or_else(|| {
// Maybe it's in generation_config.json
crate::file_json_field(&gencfg_path, "eos_token_id")
crate::file_json_field::<serde_json::Value>(&gencfg_path, "eos_token_id")
.inspect_err(
|err| tracing::warn!(%err, "Missing eos_token_id in generation_config.json"),
)
.ok()
.and_then(|v| {
if v.is_number() {
v.as_number()
.and_then(|n| n.as_u64())
.map(|n| vec![n as TokenIdType])
} else if v.is_array() {
let arr = v.as_array().unwrap();
Some(
arr.iter()
.filter_map(|inner_v| {
inner_v
.as_number()
.and_then(|n| n.as_u64())
.map(|n| n as TokenIdType)
})
.collect(),
)
} else {
None
}
})
})
.ok_or_else(|| {
anyhow::anyhow!(
......
......@@ -9,6 +9,24 @@ use either::Either;
use minijinja::{Environment, Value};
use tracing;
/// Remove known non-standard Jinja2 tags from chat templates
///
/// Some models use custom Jinja2 extensions that minijinja doesn't recognize. These tags
/// are typically metadata markers that don't affect the rendered output. For example:
/// - {% generation %} / {% endgeneration %}: Used by vLLM's AssistantTracker to mark
/// assistant-generated content. The tags themselves don't produce output.
///
/// By removing these tags before validation, we allow templates with backend-specific
/// extensions to work with minijinja while maintaining correct output semantics.
///
/// Note: This follows the same approach as Mistral.rs, which also strips these tags
/// for compatibility: https://github.com/EricLBuehler/mistral.rs/blob/2bcf0e9/mistralrs-core/src/pipeline/chat_template.rs#L318-L322
fn remove_known_non_jinja2_tags(template: &str) -> String {
template
.replace("{% generation %}", "")
.replace("{% endgeneration %}", "")
}
impl JinjaEnvironment {
fn env(self) -> Environment<'static> {
self.env
......@@ -64,8 +82,10 @@ impl HfTokenizerConfigJsonFormatter {
);
supports_add_generation_prompt = Some(true);
}
env.add_template_owned("default", x.to_string())?;
env.add_template_owned("tool_use", x.to_string())?;
// Remove known non-standard tags before validation (they don't affect output)
let template_cleaned = remove_known_non_jinja2_tags(x);
env.add_template_owned("default", template_cleaned.clone())?;
env.add_template_owned("tool_use", template_cleaned)?;
}
Either::Right(map) => {
for t in map {
......@@ -87,7 +107,9 @@ impl HfTokenizerConfigJsonFormatter {
} else {
supports_add_generation_prompt = Some(false);
}
env.add_template_owned(k.to_string(), v.to_string())?;
// Remove known non-standard tags before validation (they don't affect output)
let template_cleaned = remove_known_non_jinja2_tags(v);
env.add_template_owned(k.to_string(), template_cleaned)?;
}
}
if env.templates().count() == 0 {
......@@ -117,3 +139,30 @@ impl HfTokenizerConfigJsonFormatter {
// // fn apply_tool_template()
// }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_remove_known_non_jinja2_tags() {
let template =
"USER: {{ message }} ASSISTANT: {% generation %}Reply here{% endgeneration %}";
let result = remove_known_non_jinja2_tags(template);
assert_eq!(result, "USER: {{ message }} ASSISTANT: Reply here");
}
#[test]
fn test_remove_known_non_jinja2_tags_preserves_standard_tags() {
let template = "{% for item in items %}{{ item }}{% endfor %}";
let result = remove_known_non_jinja2_tags(template);
assert_eq!(result, template);
}
#[test]
fn test_remove_known_non_jinja2_tags_multiple() {
let template = "Start {% generation %}Part 1{% endgeneration %} middle {% generation %}Part 2{% endgeneration %}";
let result = remove_known_non_jinja2_tags(template);
assert_eq!(result, "Start Part 1 middle Part 2");
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment