"profiler/profile_conv_fwd.cpp" did not exist on "e823d518cb46ad61ddb3c70eac8529e0a58af1f8"
Unverified Commit 5fbad308 authored by Simo Lin's avatar Simo Lin Committed by GitHub
Browse files

[router] add tokenizer chat template support (#9370)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 7638f5e4
......@@ -5,7 +5,7 @@ edition = "2021"
[features]
default = ["huggingface", "grpc-client"]
huggingface = ["tokenizers"]
huggingface = ["tokenizers", "minijinja"]
tiktoken = ["tiktoken-rs"]
grpc-client = []
grpc-server = []
......@@ -52,7 +52,8 @@ url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
tokenizers = { version = "0.21.4", optional = true }
tiktoken-rs = { version = "0.5", optional = true }
tiktoken-rs = { version = "0.7.0", optional = true }
minijinja = { version = "2.0", optional = true }
# gRPC and Protobuf dependencies
tonic = { version = "0.12", features = ["tls", "gzip", "transport"] }
......@@ -71,6 +72,7 @@ criterion = { version = "0.5", features = ["html_reports"] }
tower = { version = "0.5", features = ["util"] }
http-body-util = "0.1"
portpicker = "0.1"
tempfile = "3.8"
[[bench]]
name = "request_processing"
......
//! Chat template support for tokenizers using Jinja2 templates
//!
//! This module provides functionality to apply chat templates to messages,
//! similar to HuggingFace transformers' apply_chat_template method.
use anyhow::{anyhow, Result};
#[cfg(feature = "huggingface")]
use minijinja::{context, Environment, Value};
use serde::{Deserialize, Serialize};
use serde_json;
/// Represents a chat message with role and content
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
/// Chat template processor using Jinja2
#[cfg(feature = "huggingface")]
pub struct ChatTemplateProcessor {
template: String,
bos_token: Option<String>,
eos_token: Option<String>,
}
#[cfg(feature = "huggingface")]
impl ChatTemplateProcessor {
/// Create a new chat template processor
pub fn new(template: String, bos_token: Option<String>, eos_token: Option<String>) -> Self {
ChatTemplateProcessor {
template,
bos_token,
eos_token,
}
}
/// Apply the chat template to a list of messages
///
/// This mimics the behavior of HuggingFace's apply_chat_template method
/// but returns the formatted string instead of token IDs.
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
let mut env = Environment::new();
// Register the template
env.add_template("chat", &self.template)
.map_err(|e| anyhow!("Failed to add template: {}", e))?;
// Get the template
let tmpl = env
.get_template("chat")
.map_err(|e| anyhow!("Failed to get template: {}", e))?;
// Convert messages to a format Jinja can work with
let messages_value: Vec<Value> = messages
.iter()
.map(|msg| {
context! {
role => msg.role.clone(),
content => msg.content.clone()
}
})
.collect();
// Render the template
let rendered = tmpl
.render(context! {
messages => messages_value,
add_generation_prompt => add_generation_prompt,
bos_token => self.bos_token.clone().unwrap_or_default(),
eos_token => self.eos_token.clone().unwrap_or_default()
})
.map_err(|e| anyhow!("Failed to render template: {}", e))?;
Ok(rendered)
}
}
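A minimal usage sketch of the processor above, mirroring the ChatML-style template exercised in the tests later in this diff (the crate paths match those tests; the template itself is just an illustration):

```rust
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};

fn main() -> anyhow::Result<()> {
    // A ChatML-style template, as exercised in the tests below.
    let template = r#"{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}"#;

    let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
    let messages = vec![
        ChatMessage::system("You are helpful"),
        ChatMessage::user("Hello"),
    ];

    // Renders the prompt string; tokenization is left to the tokenizer itself.
    let prompt = processor.apply_chat_template(&messages, true)?;
    assert!(prompt.ends_with("<|im_start|>assistant\n"));
    Ok(())
}
```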
/// Load chat template from tokenizer config JSON
#[cfg(feature = "huggingface")]
pub fn load_chat_template_from_config(config_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(config_path)?;
let config: serde_json::Value = serde_json::from_str(&content)?;
// Look for chat_template in the config
if let Some(template) = config.get("chat_template") {
if let Some(template_str) = template.as_str() {
return Ok(Some(template_str.to_string()));
}
}
Ok(None)
}
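A hedged sketch of calling this loader directly (the config path is hypothetical):

```rust
use sglang_router_rs::tokenizer::chat_template::load_chat_template_from_config;

fn main() -> anyhow::Result<()> {
    // Returns Ok(None) when the config exists but has no "chat_template" key.
    match load_chat_template_from_config("models/llama/tokenizer_config.json")? {
        Some(template) => println!("loaded chat_template ({} chars)", template.len()),
        None => println!("no chat_template in config"),
    }
    Ok(())
}
```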
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
#[cfg(feature = "huggingface")]
#[test]
fn test_simple_chat_template() {
// Simple template that formats messages
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::system("You are helpful"),
ChatMessage::user("Hello"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
assert!(result.contains("system: You are helpful"));
assert!(result.contains("user: Hello"));
assert!(result.contains("assistant:"));
}
#[cfg(feature = "huggingface")]
#[test]
fn test_chat_template_with_tokens() {
// Template that uses special tokens
let template = r#"
{{ bos_token }}
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}{{ eos_token }}
{% endfor -%}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Test")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert!(result.contains("<s>"));
assert!(result.contains("</s>"));
}
}
......@@ -26,6 +26,14 @@ pub enum TokenizerType {
/// - json: HuggingFace tokenizer
/// - For testing: can return mock tokenizer
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
create_tokenizer_with_chat_template(file_path, None)
}
/// Create a tokenizer from a file path with an optional chat template
pub fn create_tokenizer_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Arc<dyn traits::Tokenizer>> {
let start_time = Instant::now();
// Special case for testing
......@@ -51,7 +59,10 @@ pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tok
Some("json") => {
#[cfg(feature = "huggingface")]
{
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
file_path,
chat_template_path,
)?;
TokenizerMetrics::record_factory_load("json");
TokenizerMetrics::set_vocab_size("huggingface", tokenizer.vocab_size());
......
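A sketch of the new factory entry point with hypothetical file paths; the `TokenizerTrait` re-export shown later in this diff brings `vocab_size` into scope:

```rust
use sglang_router_rs::tokenizer::{create_tokenizer_with_chat_template, TokenizerTrait};

fn main() -> anyhow::Result<()> {
    // Passing None for the template path falls back to tokenizer_config.json,
    // which is exactly what create_tokenizer_from_file now delegates to.
    let tokenizer = create_tokenizer_with_chat_template(
        "models/llama/tokenizer.json",
        Some("models/llama/chat_template.jinja"),
    )?;
    println!("vocab size: {}", tokenizer.vocab_size());
    Ok(())
}
```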
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use crate::metrics::TokenizerMetrics;
use anyhow::{Error, Result};
use std::collections::HashMap;
use std::time::Instant;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
#[cfg(feature = "minijinja")]
use super::chat_template::{ChatMessage, ChatTemplateProcessor};
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
tokenizer: HfTokenizer,
special_tokens: SpecialTokens,
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
vocab: HashMap<String, TokenIdType>,
reverse_vocab: HashMap<TokenIdType, String>,
#[cfg(feature = "minijinja")]
chat_template: Option<String>,
}
impl HuggingFaceTokenizer {
/// Create a tokenizer from a HuggingFace tokenizer JSON file
pub fn from_file(file_path: &str) -> Result<Self> {
Self::from_file_with_chat_template(file_path, None)
}
/// Create a tokenizer from a HuggingFace tokenizer JSON file with an optional chat template
pub fn from_file_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Self> {
let tokenizer = HfTokenizer::from_file(file_path)
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
......@@ -24,16 +39,28 @@ impl HuggingFaceTokenizer {
// Build vocab mappings
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<u32, String> = vocab
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
// Load chat template
#[cfg(feature = "minijinja")]
let chat_template = if let Some(template_path) = chat_template_path {
// Load from specified .jinja file
Self::load_chat_template_from_file(template_path)?
} else {
// Try to load from tokenizer_config.json
Self::load_chat_template(file_path)
};
Ok(HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
#[cfg(feature = "minijinja")]
chat_template,
})
}
......@@ -41,7 +68,7 @@ impl HuggingFaceTokenizer {
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
let special_tokens = Self::extract_special_tokens(&tokenizer);
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<u32, String> = vocab
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
......@@ -51,6 +78,8 @@ impl HuggingFaceTokenizer {
special_tokens,
vocab,
reverse_vocab,
#[cfg(feature = "minijinja")]
chat_template: None,
}
}
......@@ -81,13 +110,86 @@ impl HuggingFaceTokenizer {
}
}
/// Try to load chat template from tokenizer_config.json
#[cfg(feature = "minijinja")]
fn load_chat_template(tokenizer_path: &str) -> Option<String> {
// Try to find tokenizer_config.json in the same directory
let path = std::path::Path::new(tokenizer_path);
let dir = path.parent()?;
let config_path = dir.join("tokenizer_config.json");
if config_path.exists() {
if let Ok(template) =
super::chat_template::load_chat_template_from_config(config_path.to_str()?)
{
return template;
}
}
None
}
/// Load chat template from a .jinja file
#[cfg(feature = "minijinja")]
fn load_chat_template_from_file(template_path: &str) -> Result<Option<String>> {
use std::fs;
let content = fs::read_to_string(template_path)
.map_err(|e| Error::msg(format!("Failed to read chat template file: {}", e)))?;
// Clean up the template (similar to Python implementation)
let template = content.trim().replace("\\n", "\n");
Ok(Some(template))
}
/// Set or override the chat template
#[cfg(feature = "minijinja")]
pub fn set_chat_template(&mut self, template: String) {
self.chat_template = Some(template);
}
/// Apply chat template if available
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
// This is a placeholder - actual implementation would handle templates
#[cfg(feature = "minijinja")]
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
if let Some(ref template) = self.chat_template {
let processor = ChatTemplateProcessor::new(
template.clone(),
self.special_tokens.bos_token.clone(),
self.special_tokens.eos_token.clone(),
);
processor.apply_chat_template(messages, add_generation_prompt)
} else {
// Fallback to simple formatting if no template is available
let mut result = String::new();
for msg in messages {
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
if add_generation_prompt {
result.push_str("assistant: ");
}
Ok(result)
}
}
/// Apply chat template if available (without minijinja feature)
#[cfg(not(feature = "minijinja"))]
pub fn apply_chat_template(
&self,
messages: &[ChatMessage],
add_generation_prompt: bool,
) -> Result<String> {
// Fallback to simple formatting
let mut result = String::new();
for msg in messages {
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
if add_generation_prompt {
result.push_str("assistant: ");
}
Ok(result)
}
}
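A sketch of the fallback path above, assuming a tokenizer.json with no chat template anywhere beside it (paths hypothetical):

```rust
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;

fn main() -> anyhow::Result<()> {
    // With no .jinja file and no chat_template in tokenizer_config.json,
    // apply_chat_template degrades to plain "role: content" lines.
    let tokenizer = HuggingFaceTokenizer::from_file("models/basic/tokenizer.json")?;
    let messages = [ChatMessage::user("Hello")];
    let prompt = tokenizer.apply_chat_template(&messages, true)?;
    assert_eq!(prompt, "user: Hello\nassistant: ");
    Ok(())
}
```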
......@@ -133,7 +235,7 @@ impl Encoder for HuggingFaceTokenizer {
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
let start = Instant::now();
TokenizerMetrics::record_decode_request("huggingface");
......@@ -160,47 +262,21 @@ impl TokenizerTrait for HuggingFaceTokenizer {
&self.special_tokens
}
fn token_to_id(&self, token: &str) -> Option<u32> {
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
self.reverse_vocab.get(&id).cloned()
}
}
/// Represents a chat message for template application
#[derive(Debug, Clone)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[cfg(feature = "minijinja")]
use super::ChatMessage;
#[cfg(feature = "minijinja")]
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
......
......@@ -10,6 +10,9 @@ pub mod stream;
pub mod traits;
// Feature-gated modules
#[cfg(feature = "huggingface")]
pub mod chat_template;
#[cfg(feature = "huggingface")]
pub mod huggingface;
......@@ -20,14 +23,20 @@ pub mod tiktoken;
mod tests;
// Re-exports
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
pub use factory::{
create_tokenizer, create_tokenizer_from_file, create_tokenizer_with_chat_template,
TokenizerType,
};
pub use sequence::Sequence;
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
#[cfg(feature = "huggingface")]
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
pub use huggingface::HuggingFaceTokenizer;
#[cfg(feature = "huggingface")]
pub use chat_template::ChatMessage;
#[cfg(feature = "tiktoken")]
pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
......@@ -42,6 +51,17 @@ impl Tokenizer {
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
}
/// Create a tokenizer from a file path with an optional chat template
pub fn from_file_with_chat_template(
file_path: &str,
chat_template_path: Option<&str>,
) -> Result<Tokenizer> {
Ok(Tokenizer(factory::create_tokenizer_with_chat_template(
file_path,
chat_template_path,
)?))
}
/// Create a tokenizer from an Arc<dyn Tokenizer>
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
......
use super::traits::Tokenizer as TokenizerTrait;
use super::traits::{TokenIdType, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::sync::Arc;
......@@ -9,7 +9,7 @@ pub struct Sequence {
tokenizer: Arc<dyn TokenizerTrait>,
/// The current sequence of token ids
token_ids: Vec<u32>,
token_ids: Vec<TokenIdType>,
/// The position in the current sequence the last decoded token completed
prefix_offset: usize,
......@@ -54,7 +54,7 @@ impl Sequence {
}
/// Create a sequence with initial tokens
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
let len = token_ids.len();
Self {
tokenizer,
......@@ -90,7 +90,7 @@ impl Sequence {
/// Append a single token to the sequence and return newly decoded text
/// Based on HuggingFace TGI incremental decoding
pub fn append_token(&mut self, token_id: u32) -> Result<String> {
pub fn append_token(&mut self, token_id: TokenIdType) -> Result<String> {
// Store the old read offset before adding the new token
let old_read_offset = self.read_offset;
......@@ -145,7 +145,7 @@ impl Sequence {
}
/// Get the current token ids
pub fn token_ids(&self) -> &[u32] {
pub fn token_ids(&self) -> &[TokenIdType] {
&self.token_ids
}
......
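A minimal incremental-decoding sketch against the retyped `Sequence` API (the tokenizer path and token IDs are placeholders):

```rust
use sglang_router_rs::tokenizer::{create_tokenizer_from_file, Sequence};

fn main() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer_from_file("models/llama/tokenizer.json")?;
    let mut seq = Sequence::with_tokens(tokenizer, vec![1, 2, 3]);

    // Each append returns only the newly completed text, TGI-style.
    for id in [42, 43, 44] {
        print!("{}", seq.append_token(id)?);
    }
    println!("\n{} ids buffered", seq.token_ids().len());
    Ok(())
}
```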
use super::traits;
use super::traits::{self, TokenIdType};
use crate::metrics::TokenizerMetrics;
use anyhow::Result;
use std::collections::HashSet;
......@@ -22,18 +22,18 @@ pub enum SequenceDecoderOutput {
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
/// Token IDs that trigger a stop
pub stop_tokens: HashSet<u32>,
pub stop_tokens: HashSet<TokenIdType>,
/// String sequences that trigger a stop
pub stop_sequences: Vec<String>,
/// Token IDs for visible stops (included in output)
pub visible_stop_tokens: HashSet<u32>,
pub visible_stop_tokens: HashSet<TokenIdType>,
/// String sequences for visible stops (included in output)
pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
/// Builder pattern - add a stop token
pub fn with_stop_token(mut self, token_id: u32) -> Self {
pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
self.stop_tokens.insert(token_id);
self
}
......@@ -45,7 +45,7 @@ impl StopSequenceConfig {
}
/// Builder pattern - add a visible stop token
pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
self.visible_stop_tokens.insert(token_id);
self
}
......@@ -64,7 +64,7 @@ pub struct StopSequenceDecoder {
/// Buffer for partial matches (the "jail")
jail_buffer: String,
/// Accumulated tokens
token_buffer: Vec<u32>,
token_buffer: Vec<TokenIdType>,
/// Offset where the prefix text starts (for context)
prefix_offset: usize,
/// Offset marking the end of previously decoded text
......@@ -94,7 +94,7 @@ impl StopSequenceDecoder {
}
/// Process a single token
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
let start = Instant::now();
if self.stopped {
......@@ -252,7 +252,10 @@ impl StopSequenceDecoder {
}
/// Process multiple tokens
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
pub fn process_tokens(
&mut self,
token_ids: &[TokenIdType],
) -> Result<Vec<SequenceDecoderOutput>> {
let mut outputs = Vec::new();
for &token_id in token_ids {
outputs.push(self.process_token(token_id)?);
......@@ -302,7 +305,7 @@ impl StopSequenceDecoderBuilder {
}
}
pub fn stop_token(mut self, token_id: u32) -> Self {
pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
self.config.stop_tokens.insert(token_id);
self
}
......@@ -312,7 +315,7 @@ impl StopSequenceDecoderBuilder {
self
}
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
self.config.visible_stop_tokens.insert(token_id);
self
}
......
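A config-only sketch of the retyped builder; how the config is attached to a `StopSequenceDecoder` lies outside these hunks, so only the methods visible above are used (the token IDs are made up):

```rust
use sglang_router_rs::tokenizer::StopSequenceConfig;

fn main() {
    // A hidden stop (e.g. an EOS id) plus a visible stop kept in the output.
    let config = StopSequenceConfig::default()
        .with_stop_token(2)
        .with_visible_stop_token(128_009);

    assert!(config.stop_tokens.contains(&2));
    assert!(config.visible_stop_tokens.contains(&128_009));
}
```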
// src/tokenizer/stream.rs
use super::traits;
use super::traits::{self, TokenIdType};
use crate::metrics::TokenizerMetrics;
use anyhow::Result;
use std::sync::Arc;
......@@ -18,7 +18,7 @@ pub struct DecodeStream {
/// A temporary buffer of the necessary token_ids needed
/// to produce valid string chunks
all_token_ids: Vec<u32>,
all_token_ids: Vec<TokenIdType>,
prefix_offset: usize,
read_offset: usize,
......@@ -27,7 +27,7 @@ pub struct DecodeStream {
impl DecodeStream {
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
prompt_token_ids: &[u32],
prompt_token_ids: &[TokenIdType],
skip_special_tokens: bool,
) -> Self {
let num_input_tokens = prompt_token_ids.len();
......@@ -44,7 +44,7 @@ impl DecodeStream {
/// Step appends a token_id to the internal state and tries to produce a text chunk.
/// Returning `None` means the given id is not enough to produce a chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
pub fn step(&mut self, id: TokenIdType) -> Result<Option<String>> {
let start = Instant::now();
self.all_token_ids.push(id);
......
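A streaming sketch of the retyped `DecodeStream` (placeholder path and IDs); `step` buffers ids until they decode to a valid chunk:

```rust
use sglang_router_rs::tokenizer::{create_tokenizer_from_file, DecodeStream};

fn main() -> anyhow::Result<()> {
    let tokenizer = create_tokenizer_from_file("models/llama/tokenizer.json")?;
    let mut stream = DecodeStream::new(tokenizer, &[1, 2, 3], true);

    for id in [42, 43, 44] {
        // None means `id` did not yet complete a printable chunk.
        if let Some(chunk) = stream.step(id)? {
            print!("{chunk}");
        }
    }
    Ok(())
}
```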
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use super::traits::{
Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer as TokenizerTrait,
};
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};
......@@ -140,12 +142,10 @@ impl Encoder for TiktokenTokenizer {
}
impl Decoder for TiktokenTokenizer {
fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
// Convert u32 to usize for tiktoken-rs
let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
fn decode(&self, token_ids: &[TokenIdType], _skip_special_tokens: bool) -> Result<String> {
// tiktoken-rs 0.7.0 now uses u32 (Rank type)
self.tokenizer
.decode(tokens)
.decode(token_ids.to_vec())
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
}
}
......@@ -159,13 +159,13 @@ impl TokenizerTrait for TiktokenTokenizer {
&self.special_tokens
}
fn token_to_id(&self, _token: &str) -> Option<u32> {
fn token_to_id(&self, _token: &str) -> Option<TokenIdType> {
// Tiktoken doesn't provide direct token-to-id mapping
// We'd need to encode the token and check if it produces a single ID
None
}
fn id_to_token(&self, _id: u32) -> Option<String> {
fn id_to_token(&self, _id: TokenIdType) -> Option<String> {
// Tiktoken doesn't provide direct id-to-token mapping
// We can only decode IDs to text
None
......
use anyhow::Result;
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// Type alias for token IDs
pub type TokenIdType = u32;
/// Core encoding trait - separate from decoding for modularity
pub trait Encoder: Send + Sync {
......@@ -8,15 +13,15 @@ pub trait Encoder: Send + Sync {
/// Core decoding trait - can be implemented independently
pub trait Decoder: Send + Sync {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
}
/// Combined tokenizer trait
pub trait Tokenizer: Encoder + Decoder {
fn vocab_size(&self) -> usize;
fn get_special_tokens(&self) -> &SpecialTokens;
fn token_to_id(&self, token: &str) -> Option<u32>;
fn id_to_token(&self, id: u32) -> Option<String>;
fn token_to_id(&self, token: &str) -> Option<TokenIdType>;
fn id_to_token(&self, id: TokenIdType) -> Option<String>;
}
/// Contains the results of tokenizing text: token IDs, string tokens, and their spans
......@@ -25,29 +30,45 @@ pub enum Encoding {
/// Hugging Face
Hf(Box<tokenizers::tokenizer::Encoding>),
/// Sentence Piece
Sp(Vec<u32>),
/// Tiktoken (for GPT models)
Tiktoken(Vec<usize>),
Sp(Vec<TokenIdType>),
/// Tiktoken (for GPT models) - now uses u32 in tiktoken-rs 0.7.0
Tiktoken(Vec<TokenIdType>),
}
impl Encoding {
pub fn token_ids(&self) -> Vec<u32> {
/// Returns the token IDs as an owned Vec; use token_ids_ref for a borrowed view
pub fn token_ids(&self) -> Vec<TokenIdType> {
match self {
Encoding::Hf(inner) => inner.get_ids().to_vec(),
Encoding::Sp(inner) => inner.clone(),
Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
Encoding::Tiktoken(inner) => inner.clone(),
}
}
pub fn token_ids_ref(&self) -> &[u32] {
/// Returns a borrowed view of the token IDs
pub fn token_ids_ref(&self) -> &[TokenIdType] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Encoding::Sp(inner) => inner,
Encoding::Tiktoken(_) => {
// Tiktoken uses usize, we can't return a reference to u32
// This is a limitation - callers should use token_ids() for Tiktoken
&[]
}
Encoding::Tiktoken(inner) => inner, // Now works with tiktoken-rs 0.7.0!
}
}
/// Get a hash of the token IDs for caching purposes
pub fn get_hash(&self) -> u64 {
let mut hasher = DefaultHasher::new();
self.hash(&mut hasher);
hasher.finish()
}
}
/// Hash implementation for Encoding
impl Hash for Encoding {
fn hash<H: Hasher>(&self, state: &mut H) {
match self {
Encoding::Hf(inner) => inner.get_ids().hash(state),
Encoding::Sp(inner) => inner.hash(state),
Encoding::Tiktoken(inner) => inner.hash(state),
}
}
}
......
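Because the `Hash` impl above hashes only the token IDs, equal ID sequences from different backend variants produce the same cache key; a small sketch:

```rust
use sglang_router_rs::tokenizer::Encoding;

fn main() {
    // Same IDs, different variants: identical hashes, usable as a cache key.
    let sp = Encoding::Sp(vec![1, 2, 3]);
    let tk = Encoding::Tiktoken(vec![1, 2, 3]);
    assert_eq!(sp.get_hash(), tk.get_hash());
}
```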
#[cfg(test)]
mod tests {
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
#[test]
#[cfg(feature = "huggingface")]
fn test_chat_message_helpers() {
let system_msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(system_msg.role, "system");
assert_eq!(system_msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
assert_eq!(user_msg.content, "Hello!");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
assert_eq!(assistant_msg.content, "Hi there!");
}
#[test]
#[cfg(feature = "huggingface")]
fn test_llama_style_template() {
// Test a Llama-style chat template
let template = r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<|begin_of_text|>".to_string()),
Some("<|end_of_text|>".to_string()),
);
let messages = vec![
ChatMessage::system("You are a helpful assistant"),
ChatMessage::user("What is 2+2?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check that the result contains expected markers
assert!(result.contains("<|begin_of_text|>"));
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
assert!(result.contains("You are a helpful assistant"));
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
assert!(result.contains("What is 2+2?"));
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_chatml_template() {
// Test a ChatML-style template
let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check ChatML format
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
assert!(result.ends_with("<|im_start|>assistant\n"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_template_without_generation_prompt() {
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![ChatMessage::user("Test")];
// Test without generation prompt
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result.trim(), "user: Test");
// Test with generation prompt
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
assert!(result_with_prompt.contains("assistant:"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_template_with_special_tokens() {
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Hello")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "<s>Hello</s>");
}
#[test]
#[cfg(feature = "huggingface")]
fn test_empty_messages() {
let template =
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "");
}
// Integration test with actual tokenizer file loading would go here
// but requires a real tokenizer_config.json file
}
......
#[cfg(test)]
mod tests {
use std::fs;
use tempfile::TempDir;
#[test]
#[cfg(feature = "huggingface")]
fn test_load_chat_template_from_file() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
let template_path = temp_dir.path().join("template.jinja");
// Write a test template
let template_content = r#"
{%- for message in messages %}
{{- '<|' + message['role'] + '|>' + message['content'] }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>' }}
{%- endif %}
"#;
fs::write(&template_path, template_content).unwrap();
// Create a mock tokenizer config
let tokenizer_config = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"hello": 0,
"world": 1,
"<s>": 2,
"</s>": 3
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_config).unwrap();
// Load tokenizer with custom chat template
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(template_path.to_str().unwrap()),
)
.unwrap();
// Test that the custom template is used
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there"),
];
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
// Verify the custom template format
assert!(result.contains("<|user|>Hello"));
assert!(result.contains("<|assistant|>Hi there"));
assert!(result.ends_with("<|assistant|>"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_override_existing_template() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
// Create tokenizer config with a built-in template
let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json");
let config_with_template = r#"{
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
}"#;
fs::write(&tokenizer_config_path, config_with_template).unwrap();
// Create the actual tokenizer file
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Create custom template that should override
let custom_template_path = temp_dir.path().join("custom.jinja");
let custom_template =
r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
fs::write(&custom_template_path, custom_template).unwrap();
// Load with custom template - should override the built-in one
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(custom_template_path.to_str().unwrap()),
)
.unwrap();
let messages = vec![ChatMessage::user("Test")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
// Should use CUSTOM template, not built-in
assert!(result.starts_with("CUSTOM:"));
assert!(result.contains("[user]: Test"));
assert!(!result.contains("built-in:"));
}
#[test]
#[cfg(feature = "huggingface")]
fn test_set_chat_template_after_creation() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory and tokenizer file
let temp_dir = TempDir::new().unwrap();
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Load tokenizer without custom template
let mut tokenizer =
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
// Set a template after creation (mimics Python's behavior)
let new_template =
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
tokenizer.set_chat_template(new_template.to_string());
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
assert!(result.starts_with("NEW:"));
assert!(result.contains("user: Hello;"));
assert!(result.contains("assistant: World;"));
}
}