Unverified Commit 4a87ba21 authored by Chang Su's avatar Chang Su Committed by GitHub
Browse files

router-grpc: Add tools processing and other parameters for apply_chat_template (#10877)

parent d7b20dd6
...@@ -27,7 +27,7 @@ use crate::tokenizer::traits::Tokenizer; ...@@ -27,7 +27,7 @@ use crate::tokenizer::traits::Tokenizer;
use crate::tool_parser::ParserRegistry; use crate::tool_parser::ParserRegistry;
use uuid::Uuid; use uuid::Uuid;
use crate::tokenizer::chat_template::ChatTemplateContentFormat; use crate::tokenizer::chat_template::{ChatTemplateContentFormat, ChatTemplateParams};
use serde_json::Value; use serde_json::Value;
// Data structures for processing // Data structures for processing
...@@ -300,12 +300,87 @@ impl GrpcRouter { ...@@ -300,12 +300,87 @@ impl GrpcRouter {
{ {
// Get content format and transform messages accordingly // Get content format and transform messages accordingly
let content_format = hf_tokenizer.chat_template_content_format(); let content_format = hf_tokenizer.chat_template_content_format();
let transformed_messages = let mut transformed_messages =
Self::transform_messages_for_content_format(&request.messages, content_format)?; Self::process_content_format(&request.messages, content_format)?;
hf_tokenizer // Process tool call arguments in assistant messages
.apply_chat_template(&transformed_messages, true) Self::process_tool_call_arguments(&mut transformed_messages)?;
.map_err(|e| format!("Failed to apply chat template: {}", e))?
// Convert tools to JSON values for template processing
let tools_json: Option<Vec<serde_json::Value>> = request
.tools
.as_ref()
.map(|tools| {
tools
.iter()
.map(serde_json::to_value)
.collect::<Result<Vec<_>, _>>()
})
.transpose()
.map_err(|e| format!("Failed to serialize tools: {}", e))?;
// Build template kwargs, merging reasoning_effort if present
let mut combined_template_kwargs = std::collections::HashMap::new();
// Add reasoning_effort if present (like Python does)
if let Some(reasoning_effort) = &request.reasoning_effort {
combined_template_kwargs.insert(
"reasoning_effort".to_string(),
serde_json::Value::String(reasoning_effort.clone()),
);
}
// Add any additional template kwargs from request
if let Some(template_kwargs) = &request.chat_template_kwargs {
for (key, value) in template_kwargs {
combined_template_kwargs.insert(key.clone(), value.clone());
}
}
let final_template_kwargs = if combined_template_kwargs.is_empty() {
None
} else {
Some(&combined_template_kwargs)
};
let params = ChatTemplateParams {
add_generation_prompt: true,
continue_final_message: request.continue_final_message,
tools: tools_json.as_deref(),
template_kwargs: final_template_kwargs,
..Default::default()
};
// Handle assistant prefix for continue_final_message
let assistant_prefix = if request.continue_final_message
&& !transformed_messages.is_empty()
&& transformed_messages
.last()
.and_then(|msg| msg.get("role"))
.and_then(|v| v.as_str())
== Some("assistant")
{
// Pop the last message to handle it separately
let last_msg = transformed_messages.pop().unwrap();
last_msg
.get("content")
.and_then(|v| v.as_str())
.map(|s| s.to_string())
} else {
None
};
// Apply chat template with the (now possibly shorter) list of messages
let rendered = hf_tokenizer
.apply_chat_template(&transformed_messages, params)
.map_err(|e| format!("Failed to apply chat template: {}", e))?;
// Append assistant prefix if we have one
if let Some(prefix) = assistant_prefix {
format!("{}{}", rendered, prefix)
} else {
rendered
}
} else { } else {
return Err( return Err(
"gRPC router requires HuggingFace tokenizer with chat template support".to_string(), "gRPC router requires HuggingFace tokenizer with chat template support".to_string(),
...@@ -322,8 +397,8 @@ impl GrpcRouter { ...@@ -322,8 +397,8 @@ impl GrpcRouter {
}) })
} }
/// Transform messages based on content format for ANY message type /// Process messages based on content format for ANY message type
fn transform_messages_for_content_format( fn process_content_format(
messages: &[crate::protocols::spec::ChatMessage], messages: &[crate::protocols::spec::ChatMessage],
content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat, content_format: crate::tokenizer::chat_template::ChatTemplateContentFormat,
) -> Result<Vec<serde_json::Value>, String> { ) -> Result<Vec<serde_json::Value>, String> {
...@@ -394,6 +469,49 @@ impl GrpcRouter { ...@@ -394,6 +469,49 @@ impl GrpcRouter {
} }
} }
/// Normalize tool-call arguments in assistant messages, in place.
///
/// Per the Transformers chat-template docs, `tool_calls[].function.arguments`
/// in assistant messages must be a JSON object (dict), but OpenAI-style
/// requests carry it as a JSON-encoded *string*. This parses any string-typed
/// `arguments` into a JSON value (the Rust equivalent of Python's
/// `json.loads`) so the chat template sees structured data.
///
/// Messages that are not assistant messages, have no `tool_calls`, or whose
/// `arguments` is already a non-string value are left untouched.
///
/// # Errors
/// Returns an error if a string-typed `arguments` field is not valid JSON.
fn process_tool_call_arguments(messages: &mut [serde_json::Value]) -> Result<(), String> {
for msg in messages {
// Skip anything that is not an assistant message
let role = msg.get("role").and_then(|v| v.as_str());
if role != Some("assistant") {
continue;
}
// Skip messages without a tool_calls array
let Some(tool_calls) = msg.get_mut("tool_calls").and_then(|tc| tc.as_array_mut())
else {
continue;
};
// Process each tool call's arguments
for call in tool_calls {
let Some(function) = call.get_mut("function") else {
continue;
};
let Some(args) = function.get_mut("arguments") else {
continue;
};
// Only string-typed arguments need parsing; dicts pass through as-is
let Some(args_str) = args.as_str() else {
continue;
};
// Parse JSON string to object (like Python json.loads)
match serde_json::from_str::<serde_json::Value>(args_str) {
Ok(parsed) => *args = parsed,
Err(e) => {
return Err(format!(
"Failed to parse tool call arguments as JSON: '{}'. Error: {}",
args_str, e
))
}
}
}
}
Ok(())
}
/// Build gRPC SamplingParams from OpenAI request /// Build gRPC SamplingParams from OpenAI request
fn build_grpc_sampling_params( fn build_grpc_sampling_params(
&self, &self,
...@@ -410,6 +528,19 @@ impl GrpcRouter { ...@@ -410,6 +528,19 @@ impl GrpcRouter {
.or(request.max_tokens) .or(request.max_tokens)
.map(|v| v as i32); .map(|v| v as i32);
// Handle skip_special_tokens: set to false if tools are present and tool_choice is not "none"
let skip_special_tokens = if request.tools.is_some() {
match &request.tool_choice {
Some(crate::protocols::spec::ToolChoice::Value(
crate::protocols::spec::ToolChoiceValue::None,
)) => request.skip_special_tokens,
Some(_) => false, // tool_choice is not "none"
None => false, // TODO: this assumes tool_choice defaults to "auto" when tools present
}
} else {
request.skip_special_tokens
};
#[allow(deprecated)] #[allow(deprecated)]
Ok(proto::SamplingParams { Ok(proto::SamplingParams {
temperature: request.temperature.unwrap_or(1.0), temperature: request.temperature.unwrap_or(1.0),
...@@ -422,7 +553,7 @@ impl GrpcRouter { ...@@ -422,7 +553,7 @@ impl GrpcRouter {
max_new_tokens, max_new_tokens,
stop: stop_sequences, stop: stop_sequences,
stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(), stop_token_ids: request.stop_token_ids.clone().unwrap_or_default(),
skip_special_tokens: request.skip_special_tokens, skip_special_tokens,
n: request.n.unwrap_or(1) as i32, n: request.n.unwrap_or(1) as i32,
structural_tag: structural_tag.unwrap_or_default(), structural_tag: structural_tag.unwrap_or_default(),
constraint: self.build_constraint(request)?, constraint: self.build_constraint(request)?,
...@@ -700,10 +831,8 @@ mod tests { ...@@ -700,10 +831,8 @@ mod tests {
name: None, name: None,
}]; }];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -735,10 +864,8 @@ mod tests { ...@@ -735,10 +864,8 @@ mod tests {
name: None, name: None,
}]; }];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
ChatTemplateContentFormat::OpenAI,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -764,10 +891,8 @@ mod tests { ...@@ -764,10 +891,8 @@ mod tests {
name: None, name: None,
}]; }];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -791,10 +916,8 @@ mod tests { ...@@ -791,10 +916,8 @@ mod tests {
reasoning_content: None, reasoning_content: None,
}]; }];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -832,10 +955,8 @@ mod tests { ...@@ -832,10 +955,8 @@ mod tests {
}, },
]; ];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 2); assert_eq!(result.len(), 2);
...@@ -862,10 +983,8 @@ mod tests { ...@@ -862,10 +983,8 @@ mod tests {
name: None, name: None,
}]; }];
let result = GrpcRouter::transform_messages_for_content_format( let result =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result.len(), 1); assert_eq!(result.len(), 1);
...@@ -902,10 +1021,8 @@ mod tests { ...@@ -902,10 +1021,8 @@ mod tests {
]; ];
// Test String format // Test String format
let result_string = GrpcRouter::transform_messages_for_content_format( let result_string =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::String)
ChatTemplateContentFormat::String,
)
.unwrap(); .unwrap();
assert_eq!(result_string.len(), 2); assert_eq!(result_string.len(), 2);
...@@ -913,10 +1030,8 @@ mod tests { ...@@ -913,10 +1030,8 @@ mod tests {
assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image"); assert_eq!(result_string[1]["content"].as_str().unwrap(), "With image");
// Test OpenAI format // Test OpenAI format
let result_openai = GrpcRouter::transform_messages_for_content_format( let result_openai =
&messages, GrpcRouter::process_content_format(&messages, ChatTemplateContentFormat::OpenAI)
ChatTemplateContentFormat::OpenAI,
)
.unwrap(); .unwrap();
assert_eq!(result_openai.len(), 2); assert_eq!(result_openai.len(), 2);
......
...@@ -6,6 +6,7 @@ ...@@ -6,6 +6,7 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use minijinja::{context, machinery, Environment, Value}; use minijinja::{context, machinery, Environment, Value};
use serde_json; use serde_json;
use std::collections::HashMap;
/// Chat template content format /// Chat template content format
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
...@@ -288,21 +289,25 @@ fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool { ...@@ -288,21 +289,25 @@ fn is_numeric_constant(expr: &machinery::ast::Expr) -> bool {
matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number()) matches!(expr, machinery::ast::Expr::Const(const_expr) if const_expr.value.is_number())
} }
/// Parameters for chat template application.
///
/// Borrowed views keep this cheap to construct per request; use
/// `..Default::default()` for the fields you don't need. Note that
/// `add_generation_prompt` and `continue_final_message` are mutually
/// exclusive (validated in `apply_chat_template`).
#[derive(Default)]
pub struct ChatTemplateParams<'a> {
/// Append the assistant-turn header so the model starts a new reply.
pub add_generation_prompt: bool,
/// Continue the final (assistant) message instead of starting a new turn.
pub continue_final_message: bool,
/// Tool definitions exposed to the template as the `tools` variable.
pub tools: Option<&'a [serde_json::Value]>,
/// Documents exposed to the template as the `documents` variable.
pub documents: Option<&'a [serde_json::Value]>,
/// Extra variables merged into the template render context
/// (e.g. `reasoning_effort`); can override the base context keys.
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
}
/// Chat template processor using Jinja2 - simple wrapper like HuggingFace /// Chat template processor using Jinja2 - simple wrapper like HuggingFace
pub struct ChatTemplateProcessor { pub struct ChatTemplateProcessor {
template: String, template: String,
bos_token: Option<String>,
eos_token: Option<String>,
} }
impl ChatTemplateProcessor { impl ChatTemplateProcessor {
/// Create a new chat template processor /// Create a new chat template processor
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self { pub fn new(template: String) -> Self {
ChatTemplateProcessor { ChatTemplateProcessor { template }
template,
bos_token,
eos_token,
}
} }
/// Apply the chat template to a list of messages /// Apply the chat template to a list of messages
...@@ -313,8 +318,12 @@ impl ChatTemplateProcessor { ...@@ -313,8 +318,12 @@ impl ChatTemplateProcessor {
pub fn apply_chat_template( pub fn apply_chat_template(
&self, &self,
messages: &[serde_json::Value], messages: &[serde_json::Value],
add_generation_prompt: bool, params: ChatTemplateParams,
) -> Result<String> { ) -> Result<String> {
// Validate incompatible options
if params.continue_final_message && params.add_generation_prompt {
return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."));
}
let mut env = Environment::new(); let mut env = Environment::new();
// Register the template // Register the template
...@@ -326,17 +335,29 @@ impl ChatTemplateProcessor { ...@@ -326,17 +335,29 @@ impl ChatTemplateProcessor {
.get_template("chat") .get_template("chat")
.map_err(|e| anyhow!("Failed to get template: {}", e))?; .map_err(|e| anyhow!("Failed to get template: {}", e))?;
// Convert ChatMessage to minijinja::Value for rendering using serde like pydantic.model_dump() // Convert messages to minijinja::Value (messages already processed by router)
let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect(); let minijinja_messages: Vec<Value> = messages.iter().map(Value::from_serialize).collect();
// Render the template directly with the provided values let base_context = context! {
messages => &minijinja_messages,
add_generation_prompt => params.add_generation_prompt,
tools => params.tools,
documents => params.documents,
};
// Merge with template_kwargs if provided
let ctx = if let Some(kwargs) = params.template_kwargs {
context! {
..base_context,
..Value::from_serialize(kwargs)
}
} else {
base_context
};
// Render the template
let rendered = tmpl let rendered = tmpl
.render(context! { .render(&ctx)
messages => minijinja_messages,
add_generation_prompt => add_generation_prompt,
bos_token => self.bos_token.clone().unwrap_or_default(),
eos_token => self.eos_token.clone().unwrap_or_default()
})
.map_err(|e| anyhow!("Failed to render template: {}", e))?; .map_err(|e| anyhow!("Failed to render template: {}", e))?;
Ok(rendered) Ok(rendered)
......
...@@ -4,7 +4,8 @@ use anyhow::{Error, Result}; ...@@ -4,7 +4,8 @@ use anyhow::{Error, Result};
use tokenizers::tokenizer::Tokenizer as HfTokenizer; use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::chat_template::{ use super::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
}; };
use super::traits::{ use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait, Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
...@@ -165,16 +166,11 @@ impl HuggingFaceTokenizer { ...@@ -165,16 +166,11 @@ impl HuggingFaceTokenizer {
pub fn apply_chat_template( pub fn apply_chat_template(
&self, &self,
messages: &[serde_json::Value], messages: &[serde_json::Value],
add_generation_prompt: bool, params: ChatTemplateParams,
) -> Result<String> { ) -> Result<String> {
if let Some(ref template) = self.chat_template { if let Some(ref template) = self.chat_template {
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.clone());
template.clone(), processor.apply_chat_template(messages, params)
self.special_tokens.bos_token.clone(),
self.special_tokens.eos_token.clone(),
);
processor.apply_chat_template(messages, add_generation_prompt)
} else { } else {
Err(Error::msg( Err(Error::msg(
"Cannot use chat template functions because tokenizer.chat_template is not set and no template \ "Cannot use chat template functions because tokenizer.chat_template is not set and no template \
......
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::spec;
use sglang_router_rs::tokenizer::chat_template::{ use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
}; };
#[test] #[test]
...@@ -169,11 +170,7 @@ assistant: ...@@ -169,11 +170,7 @@ assistant:
{%- endif %} {%- endif %}
"#; "#;
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.to_string());
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ let messages = vec![
spec::ChatMessage::System { spec::ChatMessage::System {
...@@ -194,8 +191,12 @@ assistant: ...@@ -194,8 +191,12 @@ assistant:
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
let params = ChatTemplateParams {
add_generation_prompt: true,
..Default::default()
};
let result = processor let result = processor
.apply_chat_template(&message_values, true) .apply_chat_template(&message_values, params)
.unwrap(); .unwrap();
assert!(result.contains("system: You are helpful")); assert!(result.contains("system: You are helpful"));
assert!(result.contains("user: Hello")); assert!(result.contains("user: Hello"));
...@@ -204,19 +205,15 @@ assistant: ...@@ -204,19 +205,15 @@ assistant:
#[test] #[test]
fn test_chat_template_with_tokens_unit_test() { fn test_chat_template_with_tokens_unit_test() {
// Template that uses special tokens // Template that uses template kwargs for tokens (more realistic)
let template = r#" let template = r#"
{{ bos_token }} {%- if start_token -%}{{ start_token }}{%- endif -%}
{%- for message in messages -%} {%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }} {{ message.role }}: {{ message.content }}{%- if end_token -%}{{ end_token }}{%- endif -%}
{% endfor -%} {% endfor -%}
"#; "#;
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.to_string());
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = [spec::ChatMessage::User { let messages = [spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -230,8 +227,24 @@ fn test_chat_template_with_tokens_unit_test() { ...@@ -230,8 +227,24 @@ fn test_chat_template_with_tokens_unit_test() {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
// Use template_kwargs to pass tokens
let mut template_kwargs = std::collections::HashMap::new();
template_kwargs.insert(
"start_token".to_string(),
serde_json::Value::String("<s>".to_string()),
);
template_kwargs.insert(
"end_token".to_string(),
serde_json::Value::String("</s>".to_string()),
);
let params = ChatTemplateParams {
template_kwargs: Some(&template_kwargs),
..Default::default()
};
let result = processor let result = processor
.apply_chat_template(&message_values, false) .apply_chat_template(&message_values, params)
.unwrap(); .unwrap();
assert!(result.contains("<s>")); assert!(result.contains("<s>"));
assert!(result.contains("</s>")); assert!(result.contains("</s>"));
......
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::spec;
use sglang_router_rs::tokenizer::chat_template::{ use sglang_router_rs::tokenizer::chat_template::{
detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateProcessor, detect_chat_template_content_format, ChatTemplateContentFormat, ChatTemplateParams,
ChatTemplateProcessor,
}; };
#[test] #[test]
...@@ -14,11 +15,7 @@ fn test_simple_chat_template() { ...@@ -14,11 +15,7 @@ fn test_simple_chat_template() {
{%- endif %} {%- endif %}
"#; "#;
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.to_string());
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = [spec::ChatMessage::User { let messages = [spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -32,8 +29,12 @@ fn test_simple_chat_template() { ...@@ -32,8 +29,12 @@ fn test_simple_chat_template() {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
let params = ChatTemplateParams {
add_generation_prompt: true,
..Default::default()
};
let result = processor let result = processor
.apply_chat_template(&message_values, true) .apply_chat_template(&message_values, params)
.unwrap(); .unwrap();
assert!(result.contains("<|user|>Test<|end|>")); assert!(result.contains("<|user|>Test<|end|>"));
assert!(result.contains("<|assistant|>")); assert!(result.contains("<|assistant|>"));
...@@ -41,19 +42,15 @@ fn test_simple_chat_template() { ...@@ -41,19 +42,15 @@ fn test_simple_chat_template() {
#[test] #[test]
fn test_chat_template_with_tokens() { fn test_chat_template_with_tokens() {
// Template that uses special tokens // Template that uses template kwargs for tokens
let template = r#" let template = r#"
{{ bos_token }} {%- if bos_token -%}{{ bos_token }}{%- endif -%}
{%- for message in messages -%} {%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }} {{ message.role }}: {{ message.content }}{%- if eos_token -%}{{ eos_token }}{%- endif -%}
{% endfor -%} {% endfor -%}
"#; "#;
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.to_string());
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = [spec::ChatMessage::User { let messages = [spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -67,8 +64,24 @@ fn test_chat_template_with_tokens() { ...@@ -67,8 +64,24 @@ fn test_chat_template_with_tokens() {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
// Use template_kwargs to pass tokens
let mut template_kwargs = std::collections::HashMap::new();
template_kwargs.insert(
"bos_token".to_string(),
serde_json::Value::String("<s>".to_string()),
);
template_kwargs.insert(
"eos_token".to_string(),
serde_json::Value::String("</s>".to_string()),
);
let params = ChatTemplateParams {
template_kwargs: Some(&template_kwargs),
..Default::default()
};
let result = processor let result = processor
.apply_chat_template(&message_values, false) .apply_chat_template(&message_values, params)
.unwrap(); .unwrap();
assert!(result.contains("<s>")); assert!(result.contains("<s>"));
assert!(result.contains("</s>")); assert!(result.contains("</s>"));
...@@ -85,7 +98,7 @@ fn test_llama_style_template() { ...@@ -85,7 +98,7 @@ fn test_llama_style_template() {
{%- set system_message = '' -%} {%- set system_message = '' -%}
{%- endif -%} {%- endif -%}
{{- bos_token }} {{- bos_token if bos_token else '<|begin_of_text|>' }}
{%- if system_message %} {%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }} {{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %} {%- endif %}
...@@ -99,11 +112,7 @@ fn test_llama_style_template() { ...@@ -99,11 +112,7 @@ fn test_llama_style_template() {
{%- endif %} {%- endif %}
"#; "#;
let processor = ChatTemplateProcessor::new( let processor = ChatTemplateProcessor::new(template.to_string());
template.to_string(),
Some("<|begin_of_text|>".to_string()),
Some("<|end_of_text|>".to_string()),
);
let messages = vec![ let messages = vec![
spec::ChatMessage::System { spec::ChatMessage::System {
...@@ -124,7 +133,21 @@ fn test_llama_style_template() { ...@@ -124,7 +133,21 @@ fn test_llama_style_template() {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
let result = processor.apply_chat_template(&json_messages, true).unwrap(); // Use template_kwargs to pass the token
let mut template_kwargs = std::collections::HashMap::new();
template_kwargs.insert(
"bos_token".to_string(),
serde_json::Value::String("<|begin_of_text|>".to_string()),
);
let params = ChatTemplateParams {
add_generation_prompt: true,
template_kwargs: Some(&template_kwargs),
..Default::default()
};
let result = processor
.apply_chat_template(&json_messages, params)
.unwrap();
// Check that the result contains expected markers // Check that the result contains expected markers
assert!(result.contains("<|begin_of_text|>")); assert!(result.contains("<|begin_of_text|>"));
...@@ -147,7 +170,7 @@ fn test_chatml_template() { ...@@ -147,7 +170,7 @@ fn test_chatml_template() {
{%- endif %} {%- endif %}
"#; "#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = vec![ let messages = vec![
spec::ChatMessage::User { spec::ChatMessage::User {
...@@ -176,7 +199,15 @@ fn test_chatml_template() { ...@@ -176,7 +199,15 @@ fn test_chatml_template() {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
let result = processor.apply_chat_template(&json_messages, true).unwrap(); let result = processor
.apply_chat_template(
&json_messages,
ChatTemplateParams {
add_generation_prompt: true,
..Default::default()
},
)
.unwrap();
// Check ChatML format // Check ChatML format
assert!(result.contains("<|im_start|>user\nHello<|im_end|>")); assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
...@@ -196,7 +227,7 @@ assistant: ...@@ -196,7 +227,7 @@ assistant:
{%- endif -%} {%- endif -%}
"#; "#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -212,12 +243,20 @@ assistant: ...@@ -212,12 +243,20 @@ assistant:
// Test without generation prompt // Test without generation prompt
let result = processor let result = processor
.apply_chat_template(&json_messages, false) .apply_chat_template(&json_messages, ChatTemplateParams::default())
.unwrap(); .unwrap();
assert_eq!(result.trim(), "user: Test"); assert_eq!(result.trim(), "user: Test");
// Test with generation prompt // Test with generation prompt
let result_with_prompt = processor.apply_chat_template(&json_messages, true).unwrap(); let result_with_prompt = processor
.apply_chat_template(
&json_messages,
ChatTemplateParams {
add_generation_prompt: true,
..Default::default()
},
)
.unwrap();
assert!(result_with_prompt.contains("assistant:")); assert!(result_with_prompt.contains("assistant:"));
} }
...@@ -225,10 +264,12 @@ assistant: ...@@ -225,10 +264,12 @@ assistant:
fn test_empty_messages_template() { fn test_empty_messages_template() {
let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#; let template = r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None); let processor = ChatTemplateProcessor::new(template.to_string());
let messages: Vec<serde_json::Value> = vec![]; let messages: Vec<serde_json::Value> = vec![];
let result = processor.apply_chat_template(&messages, false).unwrap(); let result = processor
.apply_chat_template(&messages, ChatTemplateParams::default())
.unwrap();
assert_eq!(result, ""); assert_eq!(result, "");
} }
...@@ -279,7 +320,7 @@ fn test_template_with_multimodal_content() { ...@@ -279,7 +320,7 @@ fn test_template_with_multimodal_content() {
{% endfor %} {% endfor %}
"#; "#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None); let processor = ChatTemplateProcessor::new(template.to_string());
let messages = [spec::ChatMessage::User { let messages = [spec::ChatMessage::User {
role: "user".to_string(), role: "user".to_string(),
...@@ -304,7 +345,7 @@ fn test_template_with_multimodal_content() { ...@@ -304,7 +345,7 @@ fn test_template_with_multimodal_content() {
.collect(); .collect();
let result = processor let result = processor
.apply_chat_template(&json_messages, false) .apply_chat_template(&json_messages, ChatTemplateParams::default())
.unwrap(); .unwrap();
// Should contain both text and image parts // Should contain both text and image parts
......
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use sglang_router_rs::protocols::spec; use sglang_router_rs::protocols::spec;
use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer; use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
use std::fs; use std::fs;
use tempfile::TempDir; use tempfile::TempDir;
...@@ -79,7 +80,14 @@ mod tests { ...@@ -79,7 +80,14 @@ mod tests {
.map(|msg| serde_json::to_value(msg).unwrap()) .map(|msg| serde_json::to_value(msg).unwrap())
.collect(); .collect();
let result = tokenizer.apply_chat_template(&json_messages, true).unwrap(); use sglang_router_rs::tokenizer::chat_template::ChatTemplateParams;
let params = ChatTemplateParams {
add_generation_prompt: true,
..Default::default()
};
let result = tokenizer
.apply_chat_template(&json_messages, params)
.unwrap();
// Verify the custom template format // Verify the custom template format
assert!(result.contains("<|user|>Hello")); assert!(result.contains("<|user|>Hello"));
...@@ -150,7 +158,7 @@ mod tests { ...@@ -150,7 +158,7 @@ mod tests {
.collect(); .collect();
let result = tokenizer let result = tokenizer
.apply_chat_template(&json_messages, false) .apply_chat_template(&json_messages, ChatTemplateParams::default())
.unwrap(); .unwrap();
// Should use CUSTOM template, not built-in // Should use CUSTOM template, not built-in
...@@ -219,7 +227,7 @@ mod tests { ...@@ -219,7 +227,7 @@ mod tests {
.collect(); .collect();
let result = tokenizer let result = tokenizer
.apply_chat_template(&json_messages, false) .apply_chat_template(&json_messages, ChatTemplateParams::default())
.unwrap(); .unwrap();
assert!(result.starts_with("NEW:")); assert!(result.starts_with("NEW:"));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment