Unverified Commit 5fbad308 authored by Simo Lin, committed by GitHub

[router] add tokenizer chat template support (#9370)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 7638f5e4
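
A quick usage sketch before the diff: the commit's new entry points let a caller load a HuggingFace tokenizer together with a Jinja chat template and render a prompt string. This is illustrative only; the paths are placeholders, and it assumes the default `huggingface` feature (which now pulls in `minijinja`):

use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

fn main() -> anyhow::Result<()> {
    // Load tokenizer.json plus an optional .jinja override; None falls back
    // to the chat_template field in a sibling tokenizer_config.json.
    let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
        "./tokenizer.json",
        Some("./chat_template.jinja"),
    )?;

    let messages = vec![
        ChatMessage::system("You are a helpful assistant"),
        ChatMessage::user("What is 2+2?"),
    ];
    // `true` appends the generation prompt for the assistant turn.
    let prompt = tokenizer.apply_chat_template(&messages, true)?;
    println!("{prompt}");
    Ok(())
}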
@@ -5,7 +5,7 @@ edition = "2021"
 [features]
 default = ["huggingface", "grpc-client"]
-huggingface = ["tokenizers"]
+huggingface = ["tokenizers", "minijinja"]
 tiktoken = ["tiktoken-rs"]
 grpc-client = []
 grpc-server = []
@@ -52,7 +52,8 @@ url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
 tokenizers = { version = "0.21.4", optional = true }
-tiktoken-rs = { version = "0.5", optional = true }
+tiktoken-rs = { version = "0.7.0", optional = true }
+minijinja = { version = "2.0", optional = true }

 # gRPC and Protobuf dependencies
 tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
@@ -71,6 +72,7 @@ criterion = { version = "0.5", features = ["html_reports"] }
 tower = { version = "0.5", features = ["util"] }
 http-body-util = "0.1"
 portpicker = "0.1"
+tempfile = "3.8"

 [[bench]]
 name = "request_processing"
...
//! Chat template support for tokenizers using Jinja2 templates
//!
//! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
#[cfg(feature = "huggingface")]
use minijinja::{context, Environment, Value};
use serde::{Deserialize, Serialize};
use serde_json;
/// Represents a chat message with role and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
/// Chat template processor using Jinja2
#[cfg(feature = "huggingface")]
pub struct ChatTemplateProcessor {
template: String,
bos_token: Option<String>,
eos_token: Option<String>,
}
#[cfg(feature = "huggingface")]
impl ChatTemplateProcessor {
/// Create a new chat template processor
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
ChatTemplateProcessor {
template,
bos_token,
eos_token,
}
}
/// Apply the chat template to a list of messages
///
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// but returns the formatted string instead of token IDs.
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
let mut env = Environment::new();
// Register the template
env.add_template("chat", &self.template)
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
// Get the template
let tmpl = env
.get_template("chat")
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
// Convert messages to a format Jinja can work with
let messages_value: Vec<Value> = messages
.iter()
.map(|msg| {
context! {
role => msg.role.clone(),
content => msg.content.clone()
}
})
.collect();
// Render the template
let rendered = tmpl
.render(context! {
messages => messages_value,
add_generation_prompt => add_generation_prompt,
bos_token => self.bos_token.clone().unwrap_or_default(),
eos_token => self.eos_token.clone().unwrap_or_default()
})
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
Ok(rendered)
}
}
/// Load chat template from tokenizer config JSON
#[cfg(feature = "huggingface")]
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(config_path)?;
let config: serde_json::Value = serde_json::from_str(&content)?;
// Look for chat_template in the config
if let Some(template) = config.get("chat_template") {
if let Some(template_str) = template.as_str() {
return Ok(Some(template_str.to_string()));
}
}
Ok(None)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
#[cfg(feature = "huggingface")]
#[test]
fn test_simple_chat_template() {
// Simple template that formats messages
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
assert!(result.contains("system: You are helpful"));
assert!(result.contains("user: Hello"));
assert!(result.contains("assistant:"));
}
#[cfg(feature = "huggingface")]
#[test]
fn test_chat_template_with_tokens() {
// Template that uses special tokens
let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Test")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert!(result.contains("<s>"));
assert!(result.contains("</s>"));
}
}
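
For completeness, `load_chat_template_from_config` can also be called directly; it only returns a template when the config carries a string-valued `chat_template` field (some configs store a list of named templates, which this helper ignores). A hedged sketch with a placeholder path:

use sglang_router_rs::tokenizer::chat_template::load_chat_template_from_config;

fn main() -> anyhow::Result<()> {
    // Ok(None) when the file parses but has no string "chat_template" key.
    match load_chat_template_from_config("./tokenizer_config.json")? {
        Some(template) => println!("found template ({} bytes)", template.len()),
        None => println!("no chat_template in config"),
    }
    Ok(())
}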
@@ -26,6 +26,14 @@ pub enum TokenizerType {
 /// - json: HuggingFace tokenizer
 /// - For testing: can return mock tokenizer
 pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
+    create_tokenizer_with_chat_template(file_path, None)
+}
+
+/// Create a tokenizer from a file path with an optional chat template
+pub fn create_tokenizer_with_chat_template(
+    file_path: &str,
+    chat_template_path: Option<&str>,
+) -> Result<Arc<dyn traits::Tokenizer>> {
     let start_time = Instant::now();

     // Special case for testing
@@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
         Some("json") => {
             #[cfg(feature = "huggingface")]
             {
-                let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
+                let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
+                    file_path,
+                    chat_template_path,
+                )?;
                 TokenizerMetrics::record_factory_load("json");
                 TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
...
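
Callers that previously used `create_tokenizer_from_file` keep working unchanged; passing `None` for the template reproduces the old behavior. A minimal sketch of the new factory function (placeholder paths):

use sglang_router_rs::tokenizer::create_tokenizer_with_chat_template;

fn main() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer_with_chat_template(
        "./tokenizer.json",
        Some("./chat_template.jinja"), // or None to auto-discover
    )?;
    println!("vocab size: {}", tokenizer.vocab_size());
    Ok(())
}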
-use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+use super::traits::{
+    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
+};
 use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::collections::HashMap;
 use std::time::Instant;
 use tokenizers::tokenizer::Tokenizer as HfTokenizer;

+#[cfg(feature = "minijinja")]
+use super::chat_template::{ChatMessage, ChatTemplateProcessor};
+
 /// HuggingFace tokenizer wrapper
 pub struct HuggingFaceTokenizer {
     tokenizer: HfTokenizer,
     special_tokens: SpecialTokens,
-    vocab: HashMap<String, u32>,
-    reverse_vocab: HashMap<u32, String>,
+    vocab: HashMap<String, TokenIdType>,
+    reverse_vocab: HashMap<TokenIdType, String>,
+    #[cfg(feature = "minijinja")]
+    chat_template: Option<String>,
 }

 impl HuggingFaceTokenizer {
     /// Create a tokenizer from a HuggingFace tokenizer JSON file
     pub fn from_file(file_path: &str) -> Result<Self> {
+        Self::from_file_with_chat_template(file_path, None)
+    }
+
+    /// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
+    pub fn from_file_with_chat_template(
+        file_path: &str,
+        chat_template_path: Option<&str>,
+    ) -> Result<Self> {
         let tokenizer = HfTokenizer::from_file(file_path)
             .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
         // Build vocab mappings
         let vocab = tokenizer.get_vocab(false);
-        let reverse_vocab: HashMap<u32, String> = vocab
+        let reverse_vocab: HashMap<TokenIdType, String> = vocab
             .iter()
             .map(|(token, &id)| (id, token.clone()))
             .collect();

+        // Load chat template
+        #[cfg(feature = "minijinja")]
+        let chat_template = if let Some(template_path) = chat_template_path {
+            // Load from specified .jinja file
+            Self::load_chat_template_from_file(template_path)?
+        } else {
+            // Try to load from tokenizer_config.json
+            Self::load_chat_template(file_path)
+        };
+
         Ok(HuggingFaceTokenizer {
             tokenizer,
             special_tokens,
             vocab,
             reverse_vocab,
+            #[cfg(feature = "minijinja")]
+            chat_template,
         })
     }
@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
     pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
         let special_tokens = Self::extract_special_tokens(&tokenizer);
         let vocab = tokenizer.get_vocab(false);
-        let reverse_vocab: HashMap<u32, String> = vocab
+        let reverse_vocab: HashMap<TokenIdType, String> = vocab
             .iter()
             .map(|(token, &id)| (id, token.clone()))
             .collect();
@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
             special_tokens,
             vocab,
             reverse_vocab,
+            #[cfg(feature = "minijinja")]
+            chat_template: None,
         }
     }
@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
         }
     }

+    /// Try to load chat template from tokenizer_config.json
+    #[cfg(feature = "minijinja")]
+    fn load_chat_template(tokenizer_path: &str) -> Option<String> {
+        // Try to find tokenizer_config.json in the same directory
+        let path = std::path::Path::new(tokenizer_path);
+        let dir = path.parent()?;
+        let config_path = dir.join("tokenizer_config.json");
+
+        if config_path.exists() {
+            if let Ok(template) =
+                super::chat_template::load_chat_template_from_config(config_path.to_str()?)
+            {
+                return template;
+            }
+        }
+        None
+    }
+
+    /// Load chat template from a .jinja file
+    #[cfg(feature = "minijinja")]
+    fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
+        use std::fs;
+        let content = fs::read_to_string(template_path)
+            .map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
+        // Clean up the template (similar to Python implementation)
+        let template = content.trim().replace("\\n", "\n");
+        Ok(Some(template))
+    }
+
+    /// Set or override the chat template
+    #[cfg(feature = "minijinja")]
+    pub fn set_chat_template(&mut self, template: String) {
+        self.chat_template = Some(template);
+    }
+
     /// Apply chat template if available
-    pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
-        // This is a placeholder - actual implementation would handle templates
-        let mut result = String::new();
-        for msg in messages {
-            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
-        }
-        Ok(result)
-    }
+    #[cfg(feature = "minijinja")]
+    pub fn apply_chat_template(
+        &self,
+        messages: &[ChatMessage],
+        add_generation_prompt: bool,
+    ) -> Result<String> {
+        if let Some(ref template) = self.chat_template {
+            let processor = ChatTemplateProcessor::new(
+                template.clone(),
+                self.special_tokens.bos_token.clone(),
+                self.special_tokens.eos_token.clone(),
+            );
+            processor.apply_chat_template(messages, add_generation_prompt)
+        } else {
+            // Fallback to simple formatting if no template is available
+            let mut result = String::new();
+            for msg in messages {
+                result.push_str(&format!("{}: {}\n", msg.role, msg.content));
+            }
+            if add_generation_prompt {
+                result.push_str("assistant: ");
+            }
+            Ok(result)
+        }
+    }
+
+    /// Apply chat template if available (without minijinja feature)
+    #[cfg(not(feature = "minijinja"))]
+    pub fn apply_chat_template(
+        &self,
+        messages: &[ChatMessage],
+        add_generation_prompt: bool,
+    ) -> Result<String> {
+        // Fallback to simple formatting
+        let mut result = String::new();
+        for msg in messages {
+            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
+        }
+        if add_generation_prompt {
+            result.push_str("assistant: ");
+        }
+        Ok(result)
+    }
 }
@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
 }

 impl Decoder for HuggingFaceTokenizer {
-    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
+    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
         let start = Instant::now();
         TokenizerMetrics::record_decode_request("huggingface");
@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
         &self.special_tokens
     }

-    fn token_to_id(&self, token: &str) -> Option<u32> {
+    fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
         self.vocab.get(token).copied()
     }

-    fn id_to_token(&self, id: u32) -> Option<String> {
+    fn id_to_token(&self, id: TokenIdType) -> Option<String> {
         self.reverse_vocab.get(&id).cloned()
     }
 }

-/// Represents a chat message for template application
-#[derive(Debug, Clone)]
-pub struct ChatMessage {
-    pub role: String,
-    pub content: String,
-}
-
-impl ChatMessage {
-    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
-        ChatMessage {
-            role: role.into(),
-            content: content.into(),
-        }
-    }
-
-    pub fn system(content: impl Into<String>) -> Self {
-        Self::new("system", content)
-    }
-
-    pub fn user(content: impl Into<String>) -> Self {
-        Self::new("user", content)
-    }
-
-    pub fn assistant(content: impl Into<String>) -> Self {
-        Self::new("assistant", content)
-    }
-}
-
 #[cfg(test)]
 mod tests {
-    use super::*;
+    #[cfg(feature = "minijinja")]
+    use super::ChatMessage;

+    #[cfg(feature = "minijinja")]
     #[test]
     fn test_chat_message_creation() {
         let msg = ChatMessage::system("You are a helpful assistant");
...
@@ -10,6 +10,9 @@ pub mod stream;
 pub mod traits;

 // Feature-gated modules
+#[cfg(feature = "huggingface")]
+pub mod chat_template;
+
 #[cfg(feature = "huggingface")]
 pub mod huggingface;
@@ -20,14 +23,20 @@ pub mod tiktoken;
 mod tests;

 // Re-exports
-pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use factory::{
+    create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
+    TokenizerType,
+};
 pub use sequence::Sequence;
 pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};

 #[cfg(feature = "huggingface")]
-pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
+pub use huggingface::HuggingFaceTokenizer;
+
+#[cfg(feature = "huggingface")]
+pub use chat_template::ChatMessage;

 #[cfg(feature = "tiktoken")]
 pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
@@ -42,6 +51,17 @@ impl Tokenizer {
         Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
     }

+    /// Create a tokenizer from a file path with an optional chat template
+    pub fn from_file_with_chat_template(
+        file_path: &str,
+        chat_template_path: Option<&str>,
+    ) -> Result<Tokenizer> {
+        Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
+            file_path,
+            chat_template_path,
+        )?))
+    }
+
     /// Create a tokenizer from an Arc<dyn Tokenizer>
     pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
         Tokenizer(tokenizer)
...
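
The high-level `Tokenizer` wrapper simply forwards to the factory, so existing call sites can adopt templates with one change. A sketch (placeholder path):

use sglang_router_rs::tokenizer::Tokenizer;

fn main() -> anyhow::Result<()> {
    // None: try tokenizer_config.json next to the tokenizer file.
    let tokenizer = Tokenizer::from_file_with_chat_template("./tokenizer.json", None)?;
    let _ = tokenizer; // use exactly like a Tokenizer::from_file result
    Ok(())
}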
-use super::traits::Tokenizer as TokenizerTrait;
+use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
 use anyhow::Result;
 use std::sync::Arc;
@@ -9,7 +9,7 @@ pub struct Sequence {
     tokenizer: Arc<dyn TokenizerTrait>,
     /// The current sequence of token ids
-    token_ids: Vec<u32>,
+    token_ids: Vec<TokenIdType>,
     /// The position in the current sequence the last decoded token completed
     prefix_offset: usize,
@@ -54,7 +54,7 @@ impl Sequence {
     }

     /// Create a sequence with initial tokens
-    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
+    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
         let len = token_ids.len();
         Self {
             tokenizer,
@@ -90,7 +90,7 @@ impl Sequence {
     /// Append a single token to the sequence and return newly decoded text
     /// Based on HuggingFace TGI incremental decoding
-    pub fn append_token(&mut self, token_id: u32) -> Result<String> {
+    pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
         // Store the old read offset before adding the new token
         let old_read_offset = self.read_offset;
@@ -145,7 +145,7 @@ impl Sequence {
     }

     /// Get the current token ids
-    pub fn token_ids(&self) -> &[u32] {
+    pub fn token_ids(&self) -> &[TokenIdType] {
         &self.token_ids
     }
...
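
To see why `TokenIdType` threads through `Sequence`, here is a hedged sketch of the incremental-decoding loop this type supports; the tokenizer path and token ids are made up for illustration:

use sglang_router_rs::tokenizer::{create_tokenizer_from_file, Sequence};

fn demo() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer_from_file("./tokenizer.json")?;
    // Seed with prompt tokens, then append generated ids one at a time.
    let mut seq = Sequence::with_tokens(tokenizer, vec![1, 15043]);
    for &id in &[736, 29991] {
        // append_token returns only the newly decoded text fragment.
        print!("{}", seq.append_token(id)?);
    }
    println!("\ntotal tokens so far: {}", seq.token_ids().len());
    Ok(())
}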
-use super::traits;
+use super::traits::{self, TokenIdType};
 use crate::metrics::TokenizerMetrics;
 use anyhow::Result;
 use std::collections::HashSet;
@@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput {
 #[derive(Debug, Clone, Default)]
 pub struct StopSequenceConfig {
     /// Token IDs that trigger a stop
-    pub stop_tokens: HashSet<u32>,
+    pub stop_tokens: HashSet<TokenIdType>,
     /// String sequences that trigger a stop
     pub stop_sequences: Vec<String>,
     /// Token IDs for visible stops (included in output)
-    pub visible_stop_tokens: HashSet<u32>,
+    pub visible_stop_tokens: HashSet<TokenIdType>,
     /// String sequences for visible stops (included in output)
     pub visible_stop_sequences: Vec<String>,
 }

 impl StopSequenceConfig {
     /// Builder pattern - add a stop token
-    pub fn with_stop_token(mut self, token_id: u32) -> Self {
+    pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
         self.stop_tokens.insert(token_id);
         self
     }
@@ -45,7 +45,7 @@ impl StopSequenceConfig {
     }

     /// Builder pattern - add a visible stop token
-    pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
+    pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
         self.visible_stop_tokens.insert(token_id);
         self
     }
@@ -64,7 +64,7 @@ pub struct StopSequenceDecoder {
     /// Buffer for partial matches (the "jail")
     jail_buffer: String,
     /// Accumulated tokens
-    token_buffer: Vec<u32>,
+    token_buffer: Vec<TokenIdType>,
     /// Offset where the prefix text starts (for context)
     prefix_offset: usize,
     /// Offset marking the end of previously decoded text
@@ -94,7 +94,7 @@ impl StopSequenceDecoder {
     }

     /// Process a single token
-    pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
+    pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
         let start = Instant::now();

         if self.stopped {
@@ -252,7 +252,10 @@ impl StopSequenceDecoder {
     }

     /// Process multiple tokens
-    pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
+    pub fn process_tokens(
+        &mut self,
+        token_ids: &[TokenIdType],
+    ) -> Result<Vec<SequenceDecoderOutput>> {
         let mut outputs = Vec::new();
         for &token_id in token_ids {
             outputs.push(self.process_token(token_id)?);
@@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder {
         }
     }

-    pub fn stop_token(mut self, token_id: u32) -> Self {
+    pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
         self.config.stop_tokens.insert(token_id);
         self
     }
@@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder {
         self
     }

-    pub fn visible_stop_token(mut self, token_id: u32) -> Self {
+    pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
         self.config.visible_stop_tokens.insert(token_id);
         self
     }
...
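
The builder methods above chain directly off `StopSequenceConfig::default()`. A small sketch (token ids are made up; in practice they would come from the tokenizer's special tokens):

use sglang_router_rs::tokenizer::StopSequenceConfig;

fn build_config() -> StopSequenceConfig {
    StopSequenceConfig::default()
        .with_stop_token(2) // e.g. an EOS id, stripped from output
        .with_visible_stop_token(128) // a stop id that remains in the output
}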
 // src/tokenizer/stream.rs
-use super::traits;
+use super::traits::{self, TokenIdType};
 use crate::metrics::TokenizerMetrics;
 use anyhow::Result;
 use std::sync::Arc;
@@ -18,7 +18,7 @@ pub struct DecodeStream {
     /// A temporary buffer of the necessary token_ids needed
     /// to produce valid string chunks
-    all_token_ids: Vec<u32>,
+    all_token_ids: Vec<TokenIdType>,
     prefix_offset: usize,
     read_offset: usize,
@@ -27,7 +27,7 @@ pub struct DecodeStream {
 impl DecodeStream {
     pub fn new(
         tokenizer: Arc<dyn traits::Tokenizer>,
-        prompt_token_ids: &[u32],
+        prompt_token_ids: &[TokenIdType],
         skip_special_tokens: bool,
     ) -> Self {
         let num_input_tokens = prompt_token_ids.len();
@@ -44,7 +44,7 @@ impl DecodeStream {
     /// Step appends a token_id to the internal state and tries to produce a text chunk.
     /// Returning `None` means the given id is not enough to produce a chunk.
-    pub fn step(&mut self, id: u32) -> Result<Option<String>> {
+    pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
         let start = Instant::now();
         self.all_token_ids.push(id);
...
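
A hedged sketch of driving `DecodeStream` with the new id type; the path and ids are placeholders:

use sglang_router_rs::tokenizer::{create_tokenizer_from_file, DecodeStream};

fn demo() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer_from_file("./tokenizer.json")?;
    let mut stream = DecodeStream::new(tokenizer, &[1, 15043], true);
    // step() buffers ids until a valid UTF-8 chunk can be emitted.
    for id in [736, 29991] {
        if let Some(chunk) = stream.step(id)? {
            print!("{chunk}");
        }
    }
    Ok(())
}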
-use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
+use super::traits::{
+    Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
+};
 use anyhow::{Error, Result};
 use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
@@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer {
 }

 impl Decoder for TiktokenTokenizer {
-    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
-        // Convert u32 to usize for tiktoken-rs
-        let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
+    fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
+        // tiktoken-rs 0.7.0 now uses u32 (Rank type)
         self.tokenizer
-            .decode(tokens)
+            .decode(token_ids.to_vec())
             .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
     }
 }
@@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer {
         &self.special_tokens
     }

-    fn token_to_id(&self, _token: &str) -> Option<u32> {
+    fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
         // Tiktoken doesn't provide direct token-to-id mapping
         // We'd need to encode the token and check if it produces a single ID
         None
     }

-    fn id_to_token(&self, _id: u32) -> Option<String> {
+    fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
         // Tiktoken doesn't provide direct id-to-token mapping
         // We can only decode IDs to text
         None
...
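
The tiktoken-rs 0.7 bump removes the u32-to-usize round-trip: per the hunk above, `CoreBPE::decode` now consumes the same u32 ranks the router stores. A hedged round-trip sketch, assuming 0.7's `cl100k_base` and `encode_with_special_tokens` (verify the exact signatures against the pinned version):

use tiktoken_rs::cl100k_base;

fn demo() -> anyhow::Result<()> {
    let bpe = cl100k_base()?;
    let ids = bpe.encode_with_special_tokens("hello world"); // u32 ranks in 0.7
    let text = bpe.decode(ids)?; // no usize conversion needed anymore
    assert_eq!(text, "hello world");
    Ok(())
}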
 use anyhow::Result;
+use std::collections::hash_map::DefaultHasher;
+use std::hash::{Hash, Hasher};
+
+/// Type alias for token IDs
+pub type TokenIdType = u32;

 /// Core encoding trait - separate from decoding for modularity
 pub trait Encoder: Send + Sync {
@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
 /// Core decoding trait - can be implemented independently
 pub trait Decoder: Send + Sync {
-    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
+    fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
 }

 /// Combined tokenizer trait
 pub trait Tokenizer: Encoder + Decoder {
     fn vocab_size(&self) -> usize;
     fn get_special_tokens(&self) -> &SpecialTokens;
-    fn token_to_id(&self, token: &str) -> Option<u32>;
-    fn id_to_token(&self, id: u32) -> Option<String>;
+    fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
+    fn id_to_token(&self, id: TokenIdType) -> Option<String>;
 }

 /// Contains the results of tokenizing text: token IDs, string tokens, and their spans
@@ -25,29 +30,45 @@ pub enum Encoding {
     /// Hugging Face
     Hf(Box<tokenizers::tokenizer::Encoding>),
     /// Sentence Piece
-    Sp(Vec<u32>),
-    /// Tiktoken (for GPT models)
-    Tiktoken(Vec<usize>),
+    Sp(Vec<TokenIdType>),
+    /// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
+    Tiktoken(Vec<TokenIdType>),
 }

 impl Encoding {
-    pub fn token_ids(&self) -> Vec<u32> {
+    /// Returns an owned copy of the token IDs
+    pub fn token_ids(&self) -> Vec<TokenIdType> {
         match self {
             Encoding::Hf(inner) => inner.get_ids().to_vec(),
             Encoding::Sp(inner) => inner.clone(),
-            Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
+            Encoding::Tiktoken(inner) => inner.clone(),
         }
     }

-    pub fn token_ids_ref(&self) -> &[u32] {
+    /// Returns a reference to the token IDs
+    pub fn token_ids_ref(&self) -> &[TokenIdType] {
         match self {
             Encoding::Hf(inner) => inner.get_ids(),
             Encoding::Sp(inner) => inner,
-            Encoding::Tiktoken(_) => {
-                // Tiktoken uses usize, we can't return a reference to u32
-                // This is a limitation - callers should use token_ids() for Tiktoken
-                &[]
-            }
+            Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
         }
     }
+
+    /// Get a hash of the token IDs for caching purposes
+    pub fn get_hash(&self) -> u64 {
+        let mut hasher = DefaultHasher::new();
+        self.hash(&mut hasher);
+        hasher.finish()
+    }
+}
+
+/// Hash implementation for Encoding
+impl Hash for Encoding {
+    fn hash<H: Hasher>(&self, state: &mut H) {
+        match self {
+            Encoding::Hf(inner) => inner.get_ids().hash(state),
+            Encoding::Sp(inner) => inner.hash(state),
+            Encoding::Tiktoken(inner) => inner.hash(state),
+        }
+    }
+}
...
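
Because `Encoding` now implements `Hash`, `get_hash()` can serve as a cheap cache key over token ids. A minimal sketch; the cache type is hypothetical and not part of this commit:

use std::collections::HashMap;

use sglang_router_rs::tokenizer::Encoding;

// Hypothetical prefill cache keyed by Encoding::get_hash().
struct PrefillCache {
    entries: HashMap<u64, String>,
}

impl PrefillCache {
    fn lookup(&self, encoding: &Encoding) -> Option<&String> {
        self.entries.get(&encoding.get_hash())
    }
}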
#[cfg(test)]
mod tests {
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
#[test]
#[cfg(feature = "huggingface")]
fn test_chat_message_helpers() {
let system_msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(system_msg.role, "system");
assert_eq!(system_msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
assert_eq!(user_msg.content, "Hello!");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
assert_eq!(assistant_msg.content, "Hi there!");
}
#[test]
#[cfg(feature = "huggingface")]
fn test_llama_style_template() {
// Test a Llama-style chat template
let template = r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<|begin_of_text|>".to_string()),
Some("<|end_of_text|>".to_string()),
);
let messages = vec![
ChatMessage::system("You are a helpful assistant"),
ChatMessage::user("What is 2+2?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check that the result contains expected markers
assert!(result.contains("<|begin_of_text|>"));
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
assert!(result.contains("You are a helpful assistant"));
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
assert!(result.contains("What is 2+2?"));
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_chatml_template() {
// Test a ChatML-style template
let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check ChatML format
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
assert!(result.ends_with("<|im_start|>assistant\n"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_template_without_generation_prompt() {
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![ChatMessage::user("Test")];
// Test without generation prompt
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result.trim(), "user: Test");
// Test with generation prompt
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
assert!(result_with_prompt.contains("assistant:"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_template_with_special_tokens() {
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Hello")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "<s>Hello</s>");
}
#[test]
#[cfg(feature = "huggingface")]
fn test_empty_messages() {
let template =
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "");
}
// Integration test with actual tokenizer file loading would go here
// but requires a real tokenizer_config.json file
}
#[cfg(test)]
mod tests {
use std::fs;
use tempfile::TempDir;
#[test]
#[cfg(feature = "huggingface")]
fn test_load_chat_template_from_file() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
let template_path = temp_dir.path().join("template.jinja");
// Write a test template
let template_content = r#"
{%- for message in messages %}
{{- '<|' + message['role'] + '|>' + message['content'] }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>' }}
{%- endif %}
"#;
fs::write(&template_path, template_content).unwrap();
// Create a mock tokenizer config
let tokenizer_config = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"hello": 0,
"world": 1,
"<s>": 2,
"</s>": 3
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_config).unwrap();
// Load tokenizer with custom chat template
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(template_path.to_str().unwrap()),
)
.unwrap();
// Test that the custom template is used
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there"),
];
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
// Verify the custom template format
assert!(result.contains("<|user|>Hello"));
assert!(result.contains("<|assistant|>Hi there"));
assert!(result.ends_with("<|assistant|>"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_override_existing_template() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
// Create tokenizer config with a built-in template
let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json");
let config_with_template = r#"{
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
}"#;
fs::write(&tokenizer_config_path, config_with_template).unwrap();
// Create the actual tokenizer file
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Create custom template that should override
let custom_template_path = temp_dir.path().join("custom.jinja");
let custom_template =
r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
fs::write(&custom_template_path, custom_template).unwrap();
// Load with custom template - should override the built-in one
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(custom_template_path.to_str().unwrap()),
)
.unwrap();
let messages = vec![ChatMessage::user("Test")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
// Should use CUSTOM template, not built-in
assert!(result.starts_with("CUSTOM:"));
assert!(result.contains("[user]: Test"));
assert!(!result.contains("built-in:"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_set_chat_template_after_creation() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory and tokenizer file
let temp_dir = TempDir::new().unwrap();
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Load tokenizer without custom template
let mut tokenizer =
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
// Set a template after creation (mimics Python's behavior)
let new_template =
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
tokenizer.set_chat_template(new_template.to_string());
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
assert!(result.starts_with("NEW:"));
assert!(result.contains("user: Hello;"));
assert!(result.contains("assistant: World;"));
}
}