Unverified commit e69094df, authored by Chang Su and committed by GitHub
Browse files

[router][grpc] Remove `continue_final_message` in `ChatTemplateParams` and add...

[router][grpc] Remove `continue_final_message` in `ChatTemplateParams` and add `minijinja-contrib` (#11882)
parent 43ad0590
...@@ -64,6 +64,7 @@ anyhow = "1.0" ...@@ -64,6 +64,7 @@ anyhow = "1.0"
tokenizers = { version = "0.22.0" } tokenizers = { version = "0.22.0" }
tiktoken-rs = { version = "0.7.0" } tiktoken-rs = { version = "0.7.0" }
minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins"] } minijinja = { version = "2.0", features = ["unstable_machinery", "json", "builtins"] }
minijinja-contrib = { version = "2.0", features = ["pycompat"] }
rustls = { version = "0.23", default-features = false, features = ["ring", "std"] } rustls = { version = "0.23", default-features = false, features = ["ring", "std"] }
hf-hub = { version = "0.4.3", features = ["tokio"] } hf-hub = { version = "0.4.3", features = ["tokio"] }
rmcp = { version = "0.6.3", features = ["client", "server", rmcp = { version = "0.6.3", features = ["client", "server",
......
...@@ -382,7 +382,6 @@ pub fn process_chat_messages( ...@@ -382,7 +382,6 @@ pub fn process_chat_messages(
let params = ChatTemplateParams { let params = ChatTemplateParams {
add_generation_prompt: true, add_generation_prompt: true,
continue_final_message: request.continue_final_message,
tools: tools_json.as_deref(), tools: tools_json.as_deref(),
template_kwargs: final_template_kwargs, template_kwargs: final_template_kwargs,
..Default::default() ..Default::default()
......
...@@ -3,12 +3,16 @@ ...@@ -3,12 +3,16 @@
//! This module provides functionality to apply chat templates to messages, //! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method. //! similar to HuggingFace transformers' apply_chat_template method.
use std::collections::HashMap; use std::{collections::HashMap, fs};
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use minijinja::{ use minijinja::{
context, context,
machinery::ast::{Expr, Stmt}, machinery::{
ast::{Expr, Stmt},
parse, WhitespaceConfig,
},
syntax::SyntaxConfig,
Environment, Value, Environment, Value,
}; };
use serde_json; use serde_json;
...@@ -323,11 +327,6 @@ impl<'a> Detector<'a> { ...@@ -323,11 +327,6 @@ impl<'a> Detector<'a> {
/// AST-based detection using minijinja's unstable machinery /// AST-based detection using minijinja's unstable machinery
/// Single-pass detector with scope tracking /// Single-pass detector with scope tracking
fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> { fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
use minijinja::{
machinery::{parse, WhitespaceConfig},
syntax::SyntaxConfig,
};
let ast = match parse( let ast = match parse(
template, template,
"template", "template",
...@@ -350,7 +349,6 @@ fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> { ...@@ -350,7 +349,6 @@ fn detect_format_with_ast(template: &str) -> Option<ChatTemplateContentFormat> {
#[derive(Default)] #[derive(Default)]
pub struct ChatTemplateParams<'a> { pub struct ChatTemplateParams<'a> {
pub add_generation_prompt: bool, pub add_generation_prompt: bool,
pub continue_final_message: bool,
pub tools: Option<&'a [serde_json::Value]>, pub tools: Option<&'a [serde_json::Value]>,
pub documents: Option<&'a [serde_json::Value]>, pub documents: Option<&'a [serde_json::Value]>,
pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>, pub template_kwargs: Option<&'a HashMap<String, serde_json::Value>>,
...@@ -377,16 +375,15 @@ impl ChatTemplateProcessor { ...@@ -377,16 +375,15 @@ impl ChatTemplateProcessor {
messages: &[serde_json::Value], messages: &[serde_json::Value],
params: ChatTemplateParams, params: ChatTemplateParams,
) -> Result<String> { ) -> Result<String> {
// Validate incompatible options
if params.continue_final_message && params.add_generation_prompt {
return Err(anyhow!("continue_final_message and add_generation_prompt are not compatible. Use continue_final_message when you want the model to continue the final message, and add_generation_prompt when you want to add a header that will prompt it to start a new assistant message instead."));
}
let mut env = Environment::new(); let mut env = Environment::new();
// Register the template // Register the template
env.add_template("chat", &self.template) env.add_template("chat", &self.template)
.map_err(|e| anyhow!("Failed to add template: {}", e))?; .map_err(|e| anyhow!("Failed to add template: {}", e))?;
// Enable Python method compatibility (e.g., str.startswith, str.endswith)
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
// Get the template // Get the template
let tmpl = env let tmpl = env
.get_template("chat") .get_template("chat")
...@@ -423,8 +420,6 @@ impl ChatTemplateProcessor { ...@@ -423,8 +420,6 @@ impl ChatTemplateProcessor {
/// Load chat template from tokenizer config JSON /// Load chat template from tokenizer config JSON
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> { pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(config_path)?; let content = fs::read_to_string(config_path)?;
let config: serde_json::Value = serde_json::from_str(&content)?; let config: serde_json::Value = serde_json::from_str(&content)?;
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment