Unverified Commit 64582caa authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

[router][grpc] Refactor chat template content format detection (#11288)

parent 2fcd56ea
......@@ -4,7 +4,8 @@
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
use minijinja::{context, machinery, Environment, Value};
use minijinja::machinery::ast::{Expr, Stmt};
use minijinja::{context, Environment, Value};
use serde_json;
use std::collections::HashMap;
......@@ -50,243 +51,293 @@ pub fn detect_chat_template_content_format(template: &str) -> ChatTemplateConten
ChatTemplateContentFormat::String
}
/// AST-based detection using minijinja's unstable machinery
/// This implements the exact same logic as SGLang's _is_var_or_elems_access functions
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig};
use minijinja::syntax::SyntaxConfig;
// Parse the template into AST
let ast = match parse(
template,
"template",
SyntaxConfig {},
WhitespaceConfig::default(),
) {
Ok(ast) => ast,
Err(_) => return Some(ChatTemplateContentFormat::String),
};
// Traverse AST looking for patterns that indicate OpenAI format
let has_iteration = find_content_iteration_in_ast(&ast);
let has_structure_checks = find_content_structure_checks_in_ast(&ast);
let has_assignment_patterns = find_variable_assignment_patterns_in_ast(&ast);
/// Flags tracking which OpenAI-style patterns we've seen.
///
/// All flags start as `false` (via `Default`). The template is classified as
/// OpenAI-format as soon as any one of them is set.
#[derive(Default, Debug, Clone, Copy)]
struct Flags {
/// A `for` loop iterates over `<message var>.content` or over a variable named `content`.
saw_iteration: bool,
/// An expression inspects the shape of content: numeric indexing, `|length`,
/// or an `is sequence` / `is iterable` / `is string` test.
saw_structure: bool,
/// A `set` statement assigns `<message var>.content` to a variable named `content`.
saw_assignment: bool,
/// A macro body contains both an `if` whose condition is an `is` test and a `for` loop.
saw_macro: bool,
}
if has_iteration || has_structure_checks || has_assignment_patterns {
Some(ChatTemplateContentFormat::OpenAI)
} else {
Some(ChatTemplateContentFormat::String)
impl Flags {
    /// Returns `true` when at least one OpenAI-style pattern has been recorded.
    fn any(self) -> bool {
        let seen = [
            self.saw_iteration,
            self.saw_structure,
            self.saw_assignment,
            self.saw_macro,
        ];
        seen.iter().any(|&flag| flag)
    }
}
/// Find content iteration patterns in AST
/// Implements the same logic as SGLang's AST traversal
fn find_content_iteration_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
/// Single-pass AST detector with scope tracking.
///
/// Holds the parsed template root, the loop variables currently bound by a
/// `for <var> in messages` loop, and the patterns detected so far.
struct Detector<'a> {
/// Root statement of the parsed template; this is what `run` walks.
ast: &'a Stmt<'a>,
/// Message loop vars currently in scope (e.g., `message`, `m`, `msg`),
/// kept in push order so the most recently pushed one is popped first.
scope: std::collections::VecDeque<String>,
/// The same names as `scope`, kept as a set for membership checks.
scope_set: std::collections::HashSet<String>,
/// Patterns detected so far.
flags: Flags,
}
match ast {
Stmt::Template(template) => {
// Recursively check all children
template
.children
.iter()
.any(|child| find_content_iteration_in_ast(child))
}
Stmt::ForLoop(for_loop) => {
// Check if this for-loop iterates over message content
is_var_or_elems_access(&for_loop.iter, "message", "content") ||
is_var_or_elems_access(&for_loop.iter, "msg", "content") ||
is_var_or_elems_access(&for_loop.iter, "m", "content") ||
// Also check the body for nested loops
for_loop.body.iter().any(|stmt| find_content_iteration_in_ast(stmt))
impl<'a> Detector<'a> {
fn new(ast: &'a Stmt<'a>) -> Self {
Self {
ast,
scope: std::collections::VecDeque::new(),
scope_set: std::collections::HashSet::new(),
flags: Flags::default(),
}
Stmt::IfCond(if_cond) => {
// Check true and false branches
if_cond
.true_body
.iter()
.any(|stmt| find_content_iteration_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_content_iteration_in_ast(stmt))
}
_ => false, // Other statement types don't contain loops
}
}
/// Check if expression accesses varname['key'] or varname.key
/// Implements SGLang's _is_var_or_elems_access logic using actual AST nodes
fn is_var_or_elems_access(expr: &machinery::ast::Expr, varname: &str, key: &str) -> bool {
use machinery::ast::Expr;
match expr {
// Check for attribute access: varname.key
Expr::GetAttr(getattr) => is_var_access(&getattr.expr, varname) && getattr.name == key,
// Check for item access: varname['key'] or varname["key"]
Expr::GetItem(getitem) => {
is_var_access(&getitem.expr, varname) && is_const_string(&getitem.subscript_expr, key)
}
// Handle filters and tests that might wrap the access
Expr::Filter(filter) => {
if let Some(ref expr) = filter.expr {
is_var_or_elems_access(expr, varname, key)
} else {
false
}
/// Walks the template from its root and returns the collected flags.
///
/// Takes `self` by value, so a detector can only be run once.
fn run(mut self) -> Flags {
    let root = self.ast;
    self.walk_stmt(root);
    self.flags
}
/// Registers `var` as a message loop variable that is now in scope.
///
/// The name is recorded both in the lookup set (`scope_set`) and on the
/// ordered stack (`scope`).
// NOTE(review): `scope_set` does not count duplicates. If a nested loop pushes
// a name that is already in scope, popping the inner entry also removes the
// name for the outer loop — confirm whether shadowing needs to be supported.
fn push_scope(&mut self, var: String) {
    self.scope_set.insert(var.clone());
    self.scope.push_back(var);
}
fn pop_scope(&mut self) {
if let Some(v) = self.scope.pop_back() {
self.scope_set.remove(&v);
}
Expr::Test(test) => is_var_or_elems_access(&test.expr, varname, key),
_ => false,
}
}
/// Returns `true` when `expr` is a bare variable reference named `varname`
/// (e.g. the `message` in `{{ message }}`).
/// Mirrors SGLang's `_is_var_access` helper.
fn is_var_access(expr: &machinery::ast::Expr, varname: &str) -> bool {
    match expr {
        machinery::ast::Expr::Var(var) => var.id == varname,
        _ => false,
    }
}
/// Returns `true` when `expr` is a bare variable reference whose identifier
/// equals `varname`.
fn is_var_access(expr: &Expr, varname: &str) -> bool {
    if let Expr::Var(var) = expr {
        var.id == varname
    } else {
        false
    }
}
/// Returns `true` when `expr` is a constant whose value is exactly the
/// string `value`.
fn is_const_string(expr: &machinery::ast::Expr, value: &str) -> bool {
    match expr {
        machinery::ast::Expr::Const(constant) => constant.value.as_str() == Some(value),
        _ => false,
    }
}
/// Returns `true` when `expr` is a constant string equal to `value`.
fn is_const_str(expr: &Expr, value: &str) -> bool {
    let Expr::Const(constant) = expr else {
        return false;
    };
    constant.value.as_str() == Some(value)
}
/// Find content structure checks in AST (like content[0], content|length)
fn find_content_structure_checks_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match ast {
Stmt::Template(template) => template
.children
.iter()
.any(|child| find_content_structure_checks_in_ast(child)),
Stmt::ForLoop(for_loop) => for_loop
.body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt)),
Stmt::IfCond(if_cond) => {
// Check if condition has content structure checks
has_content_structure_check_expr(&if_cond.expr)
|| if_cond
.true_body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_content_structure_checks_in_ast(stmt))
}
Stmt::EmitExpr(expr) => has_content_structure_check_expr(&expr.expr),
_ => false,
/// Returns `true` when `expr` is a constant holding a number, as in the `0`
/// of `content[0]`.
fn is_numeric_const(expr: &Expr) -> bool {
    match expr {
        Expr::Const(constant) => constant.value.is_number(),
        _ => false,
    }
}
}
/// Find variable assignment patterns like set content = message['content']
fn find_variable_assignment_patterns_in_ast(ast: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
match ast {
Stmt::Template(template) => template
.children
.iter()
.any(|child| find_variable_assignment_patterns_in_ast(child)),
Stmt::ForLoop(for_loop) => {
// Check if this for-loop body contains both assignment and iteration
let has_assignment = for_loop
.body
.iter()
.any(|stmt| is_content_assignment_stmt(stmt));
let has_iteration = for_loop.body.iter().any(|stmt| {
is_content_variable_iteration(stmt)
|| matches!(stmt, Stmt::IfCond(if_cond) if
if_cond.true_body.iter().any(|s| is_content_variable_iteration(s)) ||
if_cond.false_body.iter().any(|s| is_content_variable_iteration(s))
)
});
(has_assignment && has_iteration)
|| for_loop
.body
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
/// Check if expr is varname.content or varname["content"]
fn is_var_dot_content(expr: &Expr, varname: &str) -> bool {
match expr {
Expr::GetAttr(g) => Self::is_var_access(&g.expr, varname) && g.name == "content",
Expr::GetItem(g) => {
Self::is_var_access(&g.expr, varname)
&& Self::is_const_str(&g.subscript_expr, "content")
}
// Unwrap filters/tests that just wrap the same expr
Expr::Filter(f) => f
.expr
.as_ref()
.is_some_and(|e| Self::is_var_dot_content(e, varname)),
Expr::Test(t) => Self::is_var_dot_content(&t.expr, varname),
_ => false,
}
Stmt::IfCond(if_cond) => {
if_cond
.true_body
}
/// Check if expr accesses .content on any variable in our scope, or any descendant of it.
fn is_any_scope_var_content(&self, expr: &Expr) -> bool {
let mut current_expr = expr;
loop {
// Check if current level matches <scopeVar>.content
if self
.scope_set
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
|| if_cond
.false_body
.iter()
.any(|stmt| find_variable_assignment_patterns_in_ast(stmt))
.any(|v| Self::is_var_dot_content(current_expr, v))
{
return true;
}
// Walk up the expression tree
match current_expr {
Expr::GetAttr(g) => current_expr = &g.expr,
Expr::GetItem(g) => current_expr = &g.expr,
_ => return false,
}
}
_ => false,
}
}
/// Check if expression has content structure checks (index access, length, etc.)
fn has_content_structure_check_expr(expr: &machinery::ast::Expr) -> bool {
use machinery::ast::Expr;
match expr {
// Check for content[0] - index access
Expr::GetItem(getitem) => {
is_content_access(&getitem.expr) && is_numeric_constant(&getitem.subscript_expr)
fn walk_stmt(&mut self, stmt: &Stmt) {
// Early exit if we've already detected an OpenAI pattern
if self.flags.any() {
return;
}
// Check for content|length - filter with length
Expr::Filter(filter) => {
if let Some(ref filter_expr) = filter.expr {
is_content_access(filter_expr) && filter.name == "length"
} else {
false
match stmt {
Stmt::Template(t) => {
for ch in &t.children {
self.walk_stmt(ch);
}
}
// {% for message in messages %}
Stmt::ForLoop(fl) => {
// Detect "for X in messages" → push X into scope
if let Expr::Var(iter) = &fl.iter {
if iter.id == "messages" {
if let Expr::Var(target) = &fl.target {
self.push_scope(target.id.to_string());
}
}
}
// Also detect "for ... in message.content" or "for ... in content"
// - Iterating directly over <scopeVar>.content => OpenAI style
if self.is_any_scope_var_content(&fl.iter) {
self.flags.saw_iteration = true;
}
// - Iterating over a local var named "content"
if matches!(&fl.iter, Expr::Var(v) if v.id == "content") {
self.flags.saw_iteration = true;
}
for b in &fl.body {
self.walk_stmt(b);
}
// Pop scope if we pushed it
if let Expr::Var(iter) = &fl.iter {
if iter.id == "messages" && matches!(&fl.target, Expr::Var(_)) {
self.pop_scope();
}
}
}
Stmt::IfCond(ic) => {
self.inspect_expr_for_structure(&ic.expr);
for b in &ic.true_body {
self.walk_stmt(b);
}
for b in &ic.false_body {
self.walk_stmt(b);
}
}
Stmt::EmitExpr(e) => {
self.inspect_expr_for_structure(&e.expr);
}
// {% set content = message.content %}
Stmt::Set(s) => {
if Self::is_var_access(&s.target, "content")
&& self.is_any_scope_var_content(&s.expr)
{
self.flags.saw_assignment = true;
}
}
Stmt::Macro(m) => {
// Heuristic: macro that checks type (via `is` test) and also has any loop
let mut has_type_check = false;
let mut has_loop = false;
Self::scan_macro_body(&m.body, &mut has_type_check, &mut has_loop);
if has_type_check && has_loop {
self.flags.saw_macro = true;
}
}
_ => {}
}
// Check for content is sequence/iterable
Expr::Test(test) => {
is_content_access(&test.expr) && (test.name == "sequence" || test.name == "iterable")
}
_ => false,
}
}
/// Check if statement assigns message content to a variable
fn is_content_assignment_stmt(stmt: &machinery::ast::Stmt) -> bool {
use machinery::ast::Stmt;
fn inspect_expr_for_structure(&mut self, expr: &Expr) {
if self.flags.saw_structure {
return;
}
match stmt {
Stmt::Set(set_stmt) => {
// Check if this is setting content = message['content']
is_var_access(&set_stmt.target, "content")
&& is_var_or_elems_access(&set_stmt.expr, "message", "content")
match expr {
// content[0] or message.content[0]
Expr::GetItem(gi) => {
if (matches!(&gi.expr, Expr::Var(v) if v.id == "content")
|| self.is_any_scope_var_content(&gi.expr))
&& Self::is_numeric_const(&gi.subscript_expr)
{
self.flags.saw_structure = true;
}
}
// content|length or message.content|length
Expr::Filter(f) => {
if f.name == "length" {
if let Some(inner) = &f.expr {
// Box derefs automatically, so `&**inner` is `&Expr`
let inner_ref: &Expr = inner;
let is_content_var = matches!(inner_ref, Expr::Var(v) if v.id == "content");
if is_content_var || self.is_any_scope_var_content(inner_ref) {
self.flags.saw_structure = true;
}
}
} else if let Some(inner) = &f.expr {
let inner_ref: &Expr = inner;
self.inspect_expr_for_structure(inner_ref);
}
}
// content is sequence/iterable OR message.content is sequence/iterable
Expr::Test(t) => {
if t.name == "sequence" || t.name == "iterable" || t.name == "string" {
if matches!(&t.expr, Expr::Var(v) if v.id == "content")
|| self.is_any_scope_var_content(&t.expr)
{
self.flags.saw_structure = true;
}
} else {
self.inspect_expr_for_structure(&t.expr);
}
}
Expr::GetAttr(g) => {
// Keep walking; nested expressions can hide structure checks
self.inspect_expr_for_structure(&g.expr);
}
// Handle binary operations like: if (message.content is string) and other_cond
Expr::BinOp(op) => {
self.inspect_expr_for_structure(&op.left);
self.inspect_expr_for_structure(&op.right);
}
// Handle unary operations like: if not (message.content is string)
Expr::UnaryOp(op) => {
self.inspect_expr_for_structure(&op.expr);
}
_ => {}
}
_ => false,
}
}
/// Check if statement iterates over content variable
fn is_content_variable_iteration(stmt: &machinery::ast::Stmt) -> bool {
use machinery::ast::{Expr, Stmt};
fn scan_macro_body(body: &[Stmt], has_type_check: &mut bool, has_loop: &mut bool) {
for s in body {
if *has_type_check && *has_loop {
return;
}
match stmt {
Stmt::ForLoop(for_loop) => {
// Check if iterating over a variable named "content"
matches!(for_loop.iter, Expr::Var(ref var) if var.id == "content")
match s {
Stmt::IfCond(ic) => {
if matches!(&ic.expr, Expr::Test(_)) {
*has_type_check = true;
}
Self::scan_macro_body(&ic.true_body, has_type_check, has_loop);
Self::scan_macro_body(&ic.false_body, has_type_check, has_loop);
}
Stmt::ForLoop(fl) => {
*has_loop = true;
Self::scan_macro_body(&fl.body, has_type_check, has_loop);
}
Stmt::Template(t) => {
Self::scan_macro_body(&t.children, has_type_check, has_loop);
}
_ => {}
}
}
_ => false,
}
}
/// Returns `true` when `expr` reads the `content` key/attribute of one of the
/// conventional message variables: `message`, `msg` or `m`
/// (e.g. `message.content`, `msg['content']`).
fn is_content_access(expr: &machinery::ast::Expr) -> bool {
    // Checked in the same order as before; `any` stops at the first match.
    ["message", "msg", "m"]
        .iter()
        .any(|&name| is_var_or_elems_access(expr, name, "content"))
}
/// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::machinery::{parse, WhitespaceConfig};
use minijinja::syntax::SyntaxConfig;
/// Check if expression is a numeric constant (for index access)
fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
let ast = match parse(
template,
"template",
SyntaxConfig {},
WhitespaceConfig::default(),
) {
Ok(ast) => ast,
Err(_) => return Some(ChatTemplateContentFormat::String),
};
let flags = Detector::new(&ast).run();
Some(if flags.any() {
ChatTemplateContentFormat::OpenAI
} else {
ChatTemplateContentFormat::String
})
}
/// Parameters for chat template application
......
......@@ -249,3 +249,65 @@ fn test_chat_template_with_tokens_unit_test() {
assert!(result.contains("<s>"));
assert!(result.contains("</s>"));
}
#[test]
fn test_detect_openai_format_qwen3vl_macro_style() {
// Qwen3-VL style template using macros to handle multimodal content.
// The `render_content` macro branches on `content is string` and otherwise
// loops over the content items; the message loop only passes
// `message.content` to that macro. This exercises the macro-based detection
// pattern (a type test plus a loop inside a macro body), and the template is
// expected to be classified as OpenAI format.
let template = r#"{%- set image_count = namespace(value=0) %}
{%- set video_count = namespace(value=0) %}
{%- macro render_content(content, do_vision_count) %}
{%- if content is string %}
{{- content }}
{%- else %}
{%- for item in content %}
{%- if 'image' in item or 'image_url' in item or item.type == 'image' %}
{%- if do_vision_count %}
{%- set image_count.value = image_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}Picture {{ image_count.value }}: {% endif -%}
<|vision_start|><|image_pad|><|vision_end|>
{%- elif 'video' in item or item.type == 'video' %}
{%- if do_vision_count %}
{%- set video_count.value = video_count.value + 1 %}
{%- endif %}
{%- if add_vision_id %}Video {{ video_count.value }}: {% endif -%}
<|vision_start|><|video_pad|><|vision_end|>
{%- elif 'text' in item %}
{{- item.text }}
{%- endif %}
{%- endfor %}
{%- endif %}
{%- endmacro %}
{%- for message in messages %}
{%- set content = render_content(message.content, True) %}
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}"#;
assert_eq!(
detect_chat_template_content_format(template),
ChatTemplateContentFormat::OpenAI
);
}
#[test]
fn test_detect_openai_format_arbitrary_variable_names() {
// Detection must work with any loop variable name, not just the conventional
// "message", "msg" or "m". Here the outer loop binds `chat_msg` to each
// element of `messages`, and the inner loop iterates over `chat_msg.content`
// using `x`; iterating over a message's content is expected to classify the
// template as OpenAI format.
let template = r#"
{%- for chat_msg in messages %}
{%- for x in chat_msg.content %}
{%- if x.type == 'text' %}{{ x.text }}{%- endif %}
{%- if x.type == 'image' %}<image>{%- endif %}
{%- endfor %}
{%- endfor %}
"#;
assert_eq!(
detect_chat_template_content_format(template),
ChatTemplateContentFormat::OpenAI
);
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment