Unverified commit a7ae61ed, authored by Simo Lin, committed by GitHub

[router] Add Configurable L0 and L1 Tokenizer Caching (#11688)

parent fda0cb2a
......@@ -38,6 +38,7 @@ futures-util = "0.3"
futures = "0.3"
pyo3 = { version = "0.25.1", features = ["extension-module"] }
dashmap = "6.1.0"
blake3 = "1.5"
http = "1.1.0"
tokio = { version = "1.42.0", features = ["full"] }
async-trait = "0.1"
......@@ -53,6 +54,7 @@ metrics-exporter-prometheus = "0.17.0"
uuid = { version = "1.10", features = ["v4", "serde"] }
ulid = "1.2.1"
parking_lot = "0.12.4"
rayon = "1.10"
thiserror = "2.0.12"
regex = "1.10"
url = "2.5.4"
......
......@@ -87,6 +87,11 @@ class RouterArgs:
model_path: Optional[str] = None
tokenizer_path: Optional[str] = None
chat_template: Optional[str] = None
# Tokenizer cache configuration
tokenizer_cache_enable_l0: bool = False
tokenizer_cache_l0_max_entries: int = 10000
tokenizer_cache_enable_l1: bool = False
tokenizer_cache_l1_max_memory: int = 50 * 1024 * 1024 # 50MB
reasoning_parser: Optional[str] = None
tool_call_parser: Optional[str] = None
# Backend selection
......@@ -467,6 +472,30 @@ class RouterArgs:
default=None,
help="Chat template path (optional)",
)
parser.add_argument(
f"--{prefix}tokenizer-cache-enable-l0",
action="store_true",
default=RouterArgs.tokenizer_cache_enable_l0,
help="Enable L0 (whole-string exact match) tokenizer cache (default: False)",
)
parser.add_argument(
f"--{prefix}tokenizer-cache-l0-max-entries",
type=int,
default=RouterArgs.tokenizer_cache_l0_max_entries,
help="Maximum number of entries in L0 tokenizer cache (default: 10000)",
)
parser.add_argument(
f"--{prefix}tokenizer-cache-enable-l1",
action="store_true",
default=RouterArgs.tokenizer_cache_enable_l1,
help="Enable L1 (prefix matching) tokenizer cache (default: False)",
)
parser.add_argument(
f"--{prefix}tokenizer-cache-l1-max-memory",
type=int,
default=RouterArgs.tokenizer_cache_l1_max_memory,
help="Maximum memory for L1 tokenizer cache in bytes (default: 50MB)",
)
parser.add_argument(
f"--{prefix}reasoning-parser",
type=str,
......
......@@ -81,6 +81,53 @@ pub struct RouterConfig {
pub reasoning_parser: Option<String>,
/// Parser for handling tool-call interactions
pub tool_call_parser: Option<String>,
/// Tokenizer cache configuration
#[serde(default)]
pub tokenizer_cache: TokenizerCacheConfig,
}
/// Tokenizer cache configuration
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct TokenizerCacheConfig {
/// Enable L0 cache (whole-string exact match)
#[serde(default = "default_enable_l0")]
pub enable_l0: bool,
/// Maximum number of entries in L0 cache
#[serde(default = "default_l0_max_entries")]
pub l0_max_entries: usize,
/// Enable L1 cache (prefix matching at special token boundaries)
#[serde(default = "default_enable_l1")]
pub enable_l1: bool,
/// Maximum memory for L1 cache in bytes
#[serde(default = "default_l1_max_memory")]
pub l1_max_memory: usize,
}
fn default_enable_l0() -> bool {
false
}
fn default_l0_max_entries() -> usize {
10_000
}
fn default_enable_l1() -> bool {
false
}
fn default_l1_max_memory() -> usize {
50 * 1024 * 1024 // 50MB
}
impl Default for TokenizerCacheConfig {
fn default() -> Self {
Self {
enable_l0: default_enable_l0(),
l0_max_entries: default_l0_max_entries(),
enable_l1: default_enable_l1(),
l1_max_memory: default_l1_max_memory(),
}
}
}
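The per-field `serde` defaults mean a partially specified config section deserializes cleanly. A minimal sketch, assuming `serde_json` is available as a dev-dependency (it is not shown in this diff):

```rust
// Sketch: omitted fields fall back to the default_* functions above.
#[test]
fn tokenizer_cache_partial_deserialize() {
    let cfg: TokenizerCacheConfig =
        serde_json::from_str(r#"{ "enable_l1": true }"#).unwrap();
    assert!(!cfg.enable_l0);
    assert_eq!(cfg.l0_max_entries, 10_000);
    assert!(cfg.enable_l1);
    assert_eq!(cfg.l1_max_memory, 50 * 1024 * 1024);
}
```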
fn default_history_backend() -> HistoryBackend {
......@@ -459,6 +506,7 @@ impl Default for RouterConfig {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
}
}
}
......@@ -1004,6 +1052,7 @@ mod tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(config.mode.is_pd_mode());
......@@ -1072,6 +1121,7 @@ mod tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(!config.mode.is_pd_mode());
......@@ -1136,6 +1186,7 @@ mod tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: TokenizerCacheConfig::default(),
};
assert!(config.has_service_discovery());
......
......@@ -42,6 +42,9 @@ impl ConfigValidator {
}
}
// Validate tokenizer cache configuration
Self::validate_tokenizer_cache(&config.tokenizer_cache)?;
Ok(())
}
......@@ -446,6 +449,29 @@ impl ConfigValidator {
Ok(())
}
/// Validate tokenizer cache configuration
fn validate_tokenizer_cache(cache: &TokenizerCacheConfig) -> ConfigResult<()> {
// Validate L0 max entries when L0 is enabled
if cache.enable_l0 && cache.l0_max_entries == 0 {
return Err(ConfigError::InvalidValue {
field: "tokenizer_cache.l0_max_entries".to_string(),
value: cache.l0_max_entries.to_string(),
reason: "Must be > 0 when L0 cache is enabled".to_string(),
});
}
// Validate L1 max memory when L1 is enabled
if cache.enable_l1 && cache.l1_max_memory == 0 {
return Err(ConfigError::InvalidValue {
field: "tokenizer_cache.l1_max_memory".to_string(),
value: cache.l1_max_memory.to_string(),
reason: "Must be > 0 when L1 cache is enabled".to_string(),
});
}
Ok(())
}
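A hypothetical test for the two rejection paths (assumed to live in this module, so the private function is visible):

```rust
#[test]
fn rejects_zero_capacity_when_enabled() {
    let bad_l0 = TokenizerCacheConfig {
        enable_l0: true,
        l0_max_entries: 0,
        ..TokenizerCacheConfig::default()
    };
    assert!(ConfigValidator::validate_tokenizer_cache(&bad_l0).is_err());

    let bad_l1 = TokenizerCacheConfig {
        enable_l1: true,
        l1_max_memory: 0,
        ..TokenizerCacheConfig::default()
    };
    assert!(ConfigValidator::validate_tokenizer_cache(&bad_l1).is_err());

    // Disabled caches skip these checks entirely.
    assert!(ConfigValidator::validate_tokenizer_cache(&TokenizerCacheConfig::default()).is_ok());
}
```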
/// Validate compatibility between different configuration sections
fn validate_compatibility(config: &RouterConfig) -> ConfigResult<()> {
// IGW mode is independent - skip other compatibility checks when enabled
......
......@@ -198,6 +198,10 @@ struct Router {
model_path: Option<String>,
tokenizer_path: Option<String>,
chat_template: Option<String>,
tokenizer_cache_enable_l0: bool,
tokenizer_cache_l0_max_entries: usize,
tokenizer_cache_enable_l1: bool,
tokenizer_cache_l1_max_memory: usize,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
backend: BackendType,
......@@ -350,6 +354,12 @@ impl Router {
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: config::TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
}
}
......@@ -415,6 +425,10 @@ impl Router {
model_path = None,
tokenizer_path = None,
chat_template = None,
tokenizer_cache_enable_l0 = false,
tokenizer_cache_l0_max_entries = 10000,
tokenizer_cache_enable_l1 = false,
tokenizer_cache_l1_max_memory = 52428800,
reasoning_parser = None,
tool_call_parser = None,
backend = BackendType::Sglang,
......@@ -480,6 +494,10 @@ impl Router {
model_path: Option<String>,
tokenizer_path: Option<String>,
chat_template: Option<String>,
tokenizer_cache_enable_l0: bool,
tokenizer_cache_l0_max_entries: usize,
tokenizer_cache_enable_l1: bool,
tokenizer_cache_l1_max_memory: usize,
reasoning_parser: Option<String>,
tool_call_parser: Option<String>,
backend: BackendType,
......@@ -559,6 +577,10 @@ impl Router {
model_path,
tokenizer_path,
chat_template,
tokenizer_cache_enable_l0,
tokenizer_cache_l0_max_entries,
tokenizer_cache_enable_l1,
tokenizer_cache_l1_max_memory,
reasoning_parser,
tool_call_parser,
backend,
......
......@@ -5,7 +5,7 @@ use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConfigError, ConfigResult, ConnectionMode, DiscoveryConfig,
HealthCheckConfig, HistoryBackend, MetricsConfig, OracleConfig, PolicyConfig, RetryConfig,
RouterConfig, RoutingMode,
RouterConfig, RoutingMode, TokenizerCacheConfig,
},
metrics::PrometheusConfig,
server::{self, ServerConfig},
......@@ -270,6 +270,18 @@ struct CliArgs {
#[arg(long)]
chat_template: Option<String>,
#[arg(long, default_value_t = false)]
tokenizer_cache_enable_l0: bool,
#[arg(long, default_value_t = 10000)]
tokenizer_cache_l0_max_entries: usize,
#[arg(long, default_value_t = false)]
tokenizer_cache_enable_l1: bool,
#[arg(long, default_value_t = 52428800)]
tokenizer_cache_l1_max_memory: usize,
#[arg(long, default_value = "memory", value_parser = ["memory", "none", "oracle"])]
history_backend: String,
......@@ -581,6 +593,12 @@ impl CliArgs {
oracle,
reasoning_parser: self.reasoning_parser.clone(),
tool_call_parser: self.tool_call_parser.clone(),
tokenizer_cache: TokenizerCacheConfig {
enable_l0: self.tokenizer_cache_enable_l0,
l0_max_entries: self.tokenizer_cache_l0_max_entries,
enable_l1: self.tokenizer_cache_enable_l1,
l1_max_memory: self.tokenizer_cache_l1_max_memory,
},
})
}
......
......@@ -26,6 +26,7 @@ use crate::{
generate::GenerateFinishReason,
},
tokenizer::{
cache::CachedTokenizer,
chat_template::{ChatTemplateContentFormat, ChatTemplateParams},
traits::Tokenizer,
HuggingFaceTokenizer,
......@@ -317,9 +318,24 @@ pub fn process_chat_messages(
tokenizer: &dyn Tokenizer,
) -> Result<ProcessedMessages, String> {
// Use the tokenizer's chat template - we require HuggingFace tokenizer for gRPC
let formatted_text = if let Some(hf_tokenizer) =
tokenizer.as_any().downcast_ref::<HuggingFaceTokenizer>()
{
// First try direct downcast, then try via CachedTokenizer wrapper
let hf_tokenizer = tokenizer
.as_any()
.downcast_ref::<HuggingFaceTokenizer>()
.or_else(|| {
// If direct downcast fails, try to get inner tokenizer from CachedTokenizer
tokenizer
.as_any()
.downcast_ref::<CachedTokenizer>()
.and_then(|cached| {
cached
.inner()
.as_any()
.downcast_ref::<HuggingFaceTokenizer>()
})
});
let formatted_text = if let Some(hf_tokenizer) = hf_tokenizer {
// Get content format and transform messages accordingly
let content_format = hf_tokenizer.chat_template_content_format();
let mut transformed_messages = process_content_format(&request.messages, content_format)?;
......
......@@ -48,7 +48,11 @@ use crate::{
reasoning_parser::ParserFactory as ReasoningParserFactory,
routers::{router_manager::RouterManager, RouterTrait},
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
tokenizer::{
cache::{CacheConfig, CachedTokenizer},
factory as tokenizer_factory,
traits::Tokenizer,
},
tool_parser::ParserFactory as ToolParserFactory,
};
......@@ -864,7 +868,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
.to_string()
})?;
let tokenizer = Some(
let base_tokenizer =
tokenizer_factory::create_tokenizer_with_chat_template_blocking(
&tokenizer_path,
config.router_config.chat_template.as_deref(),
......@@ -876,8 +880,23 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
or a HuggingFace model ID. For directories, ensure they contain tokenizer files.",
tokenizer_path, e
)
})?,
);
})?;
// Conditionally wrap with caching layer if at least one cache is enabled
let tokenizer = if config.router_config.tokenizer_cache.enable_l0
|| config.router_config.tokenizer_cache.enable_l1
{
let cache_config = CacheConfig {
enable_l0: config.router_config.tokenizer_cache.enable_l0,
l0_max_entries: config.router_config.tokenizer_cache.l0_max_entries,
enable_l1: config.router_config.tokenizer_cache.enable_l1,
l1_max_memory: config.router_config.tokenizer_cache.l1_max_memory,
};
Some(Arc::new(CachedTokenizer::new(base_tokenizer, cache_config)) as Arc<dyn Tokenizer>)
} else {
// Use base tokenizer directly without caching
Some(base_tokenizer)
};
let reasoning_parser_factory = Some(ReasoningParserFactory::new());
let tool_parser_factory = Some(ToolParserFactory::new());
......
//! Tokenizer Fingerprinting for Cache Invalidation
//!
//! Creates a unique fingerprint of a tokenizer's configuration to detect
//! when the tokenizer has changed and the cache needs to be cleared.
use std::{
collections::hash_map::DefaultHasher,
hash::{Hash, Hasher},
};
use super::super::traits::Tokenizer;
/// A fingerprint of a tokenizer's configuration
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct TokenizerFingerprint {
/// Size of the vocabulary
pub vocab_size: usize,
/// Hash of a sample of vocabulary tokens (for speed)
pub vocab_hash: u64,
/// Hash of special tokens
pub special_tokens_hash: u64,
}
impl TokenizerFingerprint {
/// Create a fingerprint from a tokenizer
pub fn from_tokenizer(tokenizer: &dyn Tokenizer) -> Self {
let vocab_size = tokenizer.vocab_size();
let vocab_hash = Self::compute_vocab_hash(tokenizer);
let special_tokens_hash = Self::compute_special_tokens_hash(tokenizer);
Self {
vocab_size,
vocab_hash,
special_tokens_hash,
}
}
/// Compute a hash of the vocabulary by sampling tokens
fn compute_vocab_hash(tokenizer: &dyn Tokenizer) -> u64 {
let mut hasher = DefaultHasher::new();
let vocab_size = tokenizer.vocab_size();
// Sample up to 1000 tokens for speed
let sample_size = vocab_size.min(1000);
let step = if sample_size > 0 {
vocab_size / sample_size
} else {
1
};
for i in (0..vocab_size).step_by(step.max(1)) {
if let Some(token) = tokenizer.id_to_token(i as u32) {
token.hash(&mut hasher);
}
}
hasher.finish()
}
/// Compute a hash of special tokens
fn compute_special_tokens_hash(tokenizer: &dyn Tokenizer) -> u64 {
let mut hasher = DefaultHasher::new();
let special_tokens = tokenizer.get_special_tokens();
special_tokens.bos_token.hash(&mut hasher);
special_tokens.eos_token.hash(&mut hasher);
special_tokens.unk_token.hash(&mut hasher);
special_tokens.sep_token.hash(&mut hasher);
special_tokens.pad_token.hash(&mut hasher);
special_tokens.cls_token.hash(&mut hasher);
special_tokens.mask_token.hash(&mut hasher);
special_tokens.additional_special_tokens.hash(&mut hasher);
hasher.finish()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_fingerprint_equality() {
let tokenizer1 = MockTokenizer::new();
let tokenizer2 = MockTokenizer::new();
let fp1 = TokenizerFingerprint::from_tokenizer(&tokenizer1);
let fp2 = TokenizerFingerprint::from_tokenizer(&tokenizer2);
// Same tokenizer config should produce same fingerprint
assert_eq!(fp1, fp2);
}
#[test]
fn test_fingerprint_consistency() {
let tokenizer = MockTokenizer::new();
let fp1 = TokenizerFingerprint::from_tokenizer(&tokenizer);
let fp2 = TokenizerFingerprint::from_tokenizer(&tokenizer);
// Fingerprint should be consistent
assert_eq!(fp1, fp2);
assert_eq!(fp1.vocab_size, tokenizer.vocab_size());
}
}
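The diff does not show a caller for the fingerprint; a hedged sketch of the intended invalidation flow, using the `CachedTokenizer` accessors added later in this commit:

```rust
// Sketch only: recompute and compare before trusting the caches,
// e.g. after hot-swapping the underlying tokenizer.
fn invalidate_if_changed(cached: &CachedTokenizer) {
    let fresh = TokenizerFingerprint::from_tokenizer(cached.inner().as_ref());
    if &fresh != cached.fingerprint() {
        cached.clear_cache();
    }
}
```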
//! L0 Cache: Whole-string exact match cache
//!
//! This is the simplest and most effective cache layer.
//! Key: input string → Value: full encoding result
//!
//! Expected hit rate: 60-90% for workloads with repeated system prompts
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc,
};
use dashmap::DashMap;
use super::super::traits::Encoding;
/// L0 cache implementation using DashMap for lock-free reads
pub struct L0Cache {
/// The cache map: input string → encoding
map: Arc<DashMap<String, Encoding>>,
/// Maximum number of entries before eviction
max_entries: usize,
/// Cache hit counter
hits: AtomicU64,
/// Cache miss counter
misses: AtomicU64,
}
impl L0Cache {
/// Create a new L0 cache with the specified capacity
pub fn new(max_entries: usize) -> Self {
Self {
map: Arc::new(DashMap::with_capacity(max_entries.min(1024))),
max_entries,
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
}
}
/// Get an encoding from the cache
pub fn get(&self, key: &str) -> Option<Encoding> {
match self.map.get(key) {
Some(entry) => {
self.hits.fetch_add(1, Ordering::Relaxed);
Some(entry.value().clone())
}
None => {
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
}
}
/// Insert an encoding into the cache
pub fn insert(&self, key: String, value: Encoding) {
// Simple eviction: if we're at capacity, remove a random entry
// DashMap doesn't support LRU directly, so we use a simple strategy
if self.map.len() >= self.max_entries {
// Get the key to remove in a separate scope to ensure iterator is dropped
let key_to_remove = { self.map.iter().next().map(|entry| entry.key().clone()) };
// Iterator fully dropped here, all locks released
// Now remove it
if let Some(k) = key_to_remove {
self.map.remove(&k);
}
}
self.map.insert(key, value);
}
/// Get the current number of entries in the cache
pub fn len(&self) -> usize {
self.map.len()
}
/// Check if the cache is empty
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
/// Get cache statistics
pub fn stats(&self) -> CacheStats {
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total_requests = hits + misses;
CacheStats {
hits,
misses,
entries: self.len(),
hit_rate: if total_requests > 0 {
hits as f64 / total_requests as f64
} else {
0.0
},
}
}
/// Clear the cache
pub fn clear(&self) {
self.map.clear();
self.hits.store(0, Ordering::Relaxed);
self.misses.store(0, Ordering::Relaxed);
}
/// Estimate memory usage in bytes
pub fn memory_usage(&self) -> usize {
// Rough estimate:
// - Each entry: key (string) + value (encoding ~250 tokens * 4 bytes) + overhead
// - Average: ~2.2KB per entry
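// At the default cap of 10,000 entries this works out to ~22MB, matching the CacheConfig::default note in mod.rs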
self.len() * 2200
}
}
#[derive(Debug, Clone)]
pub struct CacheStats {
pub hits: u64,
pub misses: u64,
pub entries: usize,
pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::traits::Encoding;
fn mock_encoding(tokens: Vec<u32>) -> Encoding {
Encoding::Sp(tokens)
}
#[test]
fn test_basic_get_set() {
let cache = L0Cache::new(10);
// Miss
assert!(cache.get("hello").is_none());
// Insert
cache.insert("hello".to_string(), mock_encoding(vec![1, 2, 3]));
// Hit
let result = cache.get("hello");
assert!(result.is_some());
assert_eq!(result.unwrap().token_ids(), &[1, 2, 3]);
}
#[test]
fn test_eviction() {
let cache = L0Cache::new(2);
cache.insert("a".to_string(), mock_encoding(vec![1]));
cache.insert("b".to_string(), mock_encoding(vec![2]));
// Should evict when adding third
cache.insert("c".to_string(), mock_encoding(vec![3]));
// Cache should have exactly 2 entries
assert_eq!(cache.len(), 2);
}
#[test]
fn test_stats() {
let cache = L0Cache::new(10);
cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));
// 1 miss (lookup of a key that was never inserted)
let _ = cache.get("missing");
// 1 hit
let _ = cache.get("test");
let stats = cache.stats();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
assert_eq!(stats.hit_rate, 0.5);
}
#[test]
fn test_clear() {
let cache = L0Cache::new(10);
cache.insert("test".to_string(), mock_encoding(vec![1, 2, 3]));
assert_eq!(cache.len(), 1);
cache.clear();
assert_eq!(cache.len(), 0);
assert!(cache.get("test").is_none());
}
#[test]
fn test_concurrent_access() {
use std::thread;
let cache = Arc::new(L0Cache::new(1000));
let mut handles = vec![];
// Spawn 10 threads
for i in 0..10 {
let cache_clone = cache.clone();
handles.push(thread::spawn(move || {
// Each thread inserts and reads
let key = format!("key_{}", i);
cache_clone.insert(key.clone(), mock_encoding(vec![i as u32]));
// Read it back
let result = cache_clone.get(&key);
assert!(result.is_some());
}));
}
for handle in handles {
handle.join().unwrap();
}
// Should have 10 entries
assert_eq!(cache.len(), 10);
}
}
//! L1 Cache: Special-token boundary prefix cache
//!
//! Caches tokenization results at ALL special token boundaries.
//! Special tokens (like `<|im_start|>`, `<|im_end|>`) are atomic in BPE tokenizers (special: true, normalized: false),
//! making them the ONLY safe split points that guarantee correctness.
//!
//! **Design**: Cache at every special token boundary (not at fixed granularity intervals)
//! - Simple: No granularity parameter, no search windows
//! - Efficient: Fewer cache entries (10 instead of 64 for a typical 8KB prompt)
//! - Natural: Aligns with actual chat template structure
//!
//! Example:
//!
//! Template: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\n{query}<|im_end|>"
//!
//! Request 1: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>"
//! Request 2: "<|im_start|>system\nYou are helpful.<|im_end|><|im_start|>user\nHello!<|im_end|>"
//!
//! Cache points: After each "<|im_end|>" (atomic tokens, guaranteed safe)
//! Result: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
use std::{
mem::size_of,
sync::{
atomic::{AtomicU64, Ordering},
Arc,
},
};
use blake3;
use dashmap::DashMap;
use super::super::traits::TokenIdType;
/// Hash type for cache keys
type Blake3Hash = [u8; 32];
/// Number of shards for concurrent access
const NUM_SHARDS: usize = 16;
/// Find ALL special token boundaries in the text
///
/// **ONLY uses special tokens** - these are atomic (special: true, normalized: false) in BPE,
/// guaranteeing: tokenize(prefix) + tokenize(suffix) == tokenize(prefix + suffix)
///
/// No fallback to whitespace/punctuation - better to not cache than risk corruption.
///
/// Common special tokens:
/// - ChatML: `<|im_start|>`, `<|im_end|>`
/// - Llama 3: `<|begin_of_text|>`, `<|end_of_text|>`, `<|eot_id|>`
/// - GPT: `<|endoftext|>`
/// - Custom: `<|reserved_special_token_N|>`
///
/// Returns positions immediately after each special token (where prefixes can be cached).
fn find_special_token_boundaries(text: &str, special_tokens: &[&str]) -> Vec<usize> {
if special_tokens.is_empty() {
return Vec::new();
}
let mut boundaries = Vec::new();
// Find all special token end positions
for &token in special_tokens {
let mut start = 0;
while let Some(pos) = text[start..].find(token) {
let boundary = start + pos + token.len();
// Only cache boundaries that leave some suffix to tokenize
if boundary < text.len() {
boundaries.push(boundary);
}
start = boundary;
}
}
// Sort and deduplicate (in case multiple special tokens end at same position)
boundaries.sort_unstable();
boundaries.dedup();
boundaries
}
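A worked example of the boundary math (offsets are byte positions just past each special token; a boundary at end-of-string is skipped because no suffix remains):

```rust
// Hypothetical check, not a shipped test.
#[test]
fn boundary_positions_example() {
    let text = "<|im_start|>user\nHi<|im_end|>";
    let b = find_special_token_boundaries(text, &["<|im_start|>", "<|im_end|>"]);
    // "<|im_start|>" is 12 bytes, so one boundary at offset 12; the trailing
    // "<|im_end|>" ends exactly at text.len() (29) and is skipped.
    assert_eq!(b, vec![12]);
}
```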
/// A cached prefix entry
#[derive(Debug, Clone)]
struct CachedPrefix {
/// The pre-computed token IDs for this prefix
tokens: Vec<TokenIdType>,
/// Last access timestamp (for LRU eviction)
last_accessed: Arc<AtomicU64>,
/// Size in bytes (for memory tracking during eviction)
size_bytes: usize,
}
/// L1 cache implementation with special-token-boundary prefix matching
pub struct L1Cache {
/// Sharded maps for concurrent access
/// Key: Blake3 hash of bytes[0..boundary]
/// Value: Cached token IDs for that prefix
shards: Vec<Arc<DashMap<Blake3Hash, CachedPrefix>>>,
/// Maximum memory in bytes
max_memory: usize,
/// Current memory usage estimate
current_memory: AtomicU64,
/// Cache hit counter
hits: AtomicU64,
/// Cache miss counter
misses: AtomicU64,
/// Monotonic counter for LRU timestamps
access_counter: AtomicU64,
}
impl L1Cache {
/// Create a new L1 cache with the specified memory limit
pub fn new(max_memory: usize) -> Self {
let shards = (0..NUM_SHARDS).map(|_| Arc::new(DashMap::new())).collect();
Self {
shards,
max_memory,
current_memory: AtomicU64::new(0),
hits: AtomicU64::new(0),
misses: AtomicU64::new(0),
access_counter: AtomicU64::new(0),
}
}
/// Try to find the longest prefix match at special token boundaries
/// Returns (cached_tokens, byte_offset) if found
///
/// Uses pre-computed tokens cached during insertion.
pub fn longest_prefix_match(
&self,
input: &str,
special_tokens: &[&str],
) -> Option<(Vec<TokenIdType>, usize)> {
let boundaries = find_special_token_boundaries(input, special_tokens);
if boundaries.is_empty() {
self.misses.fetch_add(1, Ordering::Relaxed);
return None;
}
// Search backwards from the longest boundary to find the best match
for &boundary_pos in boundaries.iter().rev() {
let prefix = &input[0..boundary_pos];
let prefix_bytes = prefix.as_bytes();
let hash = blake3::hash(prefix_bytes);
let hash_bytes: Blake3Hash = *hash.as_bytes();
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
if let Some(entry) = self.shards[shard_idx].get(&hash_bytes) {
// Update last accessed timestamp for LRU
let timestamp = self.access_counter.fetch_add(1, Ordering::Relaxed);
entry.last_accessed.store(timestamp, Ordering::Relaxed);
self.hits.fetch_add(1, Ordering::Relaxed);
return Some((entry.tokens.clone(), boundary_pos));
}
}
self.misses.fetch_add(1, Ordering::Relaxed);
None
}
/// Insert prefix entries at ALL special token boundaries
///
/// Re-tokenizes each prefix to ensure correctness (BPE tokenization is not prefix-stable).
/// This is more expensive on cache misses but provides correct tokens for cache hits.
///
/// Optimized for workloads with high prefix reuse (e.g., chat templates with repeated system prompts).
pub fn insert_at_boundaries<E: super::super::traits::Encoder + ?Sized>(
&self,
input: &str,
tokenizer: &E,
special_tokens: &[&str],
) -> anyhow::Result<()> {
let boundaries = find_special_token_boundaries(input, special_tokens);
if boundaries.is_empty() {
return Ok(());
}
// Calculate how much memory we need and tokenize each prefix
let mut entries_to_insert = Vec::new();
for &boundary_pos in &boundaries {
// Extract prefix up to this special token boundary
let prefix = &input[0..boundary_pos];
let prefix_bytes = prefix.as_bytes();
let hash = blake3::hash(prefix_bytes);
let hash_bytes: Blake3Hash = *hash.as_bytes();
// Re-tokenize the prefix for guaranteed correctness
// This is the only way to know the exact token boundaries
let prefix_encoding = tokenizer.encode(prefix)?;
let prefix_tokens = prefix_encoding.token_ids().to_vec();
// Size = text bytes + token storage
let size_bytes = boundary_pos + prefix_tokens.len() * size_of::<TokenIdType>();
entries_to_insert.push((hash_bytes, prefix_tokens, size_bytes));
}
if entries_to_insert.is_empty() {
return Ok(());
}
let total_size_needed: usize = entries_to_insert.iter().map(|(_, _, size)| size).sum();
// Evict if necessary
let current = self.current_memory.load(Ordering::Relaxed) as usize;
if current + total_size_needed > self.max_memory {
self.evict_lru(total_size_needed);
}
// Insert all entries
for (hash_bytes, prefix_tokens, size_bytes) in entries_to_insert {
let shard_idx = hash_bytes[0] as usize % NUM_SHARDS;
let cached = CachedPrefix {
tokens: prefix_tokens,
last_accessed: Arc::new(AtomicU64::new(
self.access_counter.load(Ordering::Relaxed),
)),
size_bytes,
};
self.shards[shard_idx].insert(hash_bytes, cached);
self.current_memory
.fetch_add(size_bytes as u64, Ordering::Relaxed);
}
Ok(())
}
/// Evict least recently used entries using approximate LRU via random sampling
///
/// This uses an approximate LRU strategy that's much faster than true LRU:
/// - Samples K random entries from the cache (K=32)
/// - Evicts the oldest entry among the samples
/// - Repeats until enough space is freed
///
/// This provides O(samples) complexity instead of O(total_entries * log(total_entries)),
/// avoiding latency spikes when eviction is triggered on large caches.
///
/// The approximation is excellent in practice - sampling 32 entries from a large cache
/// gives high probability of finding very old entries.
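/// Concretely, if the oldest 10% of entries are the targets, at least one of
/// 32 uniform samples lands in that set with probability 1 - 0.9^32 ≈ 97%.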
fn evict_lru(&self, space_needed: usize) {
const SAMPLE_SIZE: usize = 32; // Number of entries to sample per eviction round
let mut freed = 0usize;
let mut iteration = 0usize;
// Keep evicting until we have enough space
while freed < space_needed {
// Collect samples from shards
let mut samples: Vec<(usize, Blake3Hash, u64, usize)> = Vec::with_capacity(SAMPLE_SIZE);
// Sample entries across different shards
for i in 0..SAMPLE_SIZE {
// Distribute samples across shards using iteration and index for variety
let shard_idx = (iteration * SAMPLE_SIZE + i) % NUM_SHARDS;
// Get first entry from that shard (DashMap iteration order is arbitrary)
if let Some(entry) = self.shards[shard_idx].iter().next() {
let hash = *entry.key();
let timestamp = entry.value().last_accessed.load(Ordering::Relaxed);
let size = entry.value().size_bytes;
samples.push((shard_idx, hash, timestamp, size));
}
}
if samples.is_empty() {
// Cache is empty, nothing to evict
break;
}
// Find the oldest entry among samples
if let Some((shard_idx, hash, _, _)) =
samples.iter().min_by_key(|(_, _, ts, _)| ts).copied()
{
// Remove it
if let Some((_, removed)) = self.shards[shard_idx].remove(&hash) {
freed += removed.size_bytes;
self.current_memory
.fetch_sub(removed.size_bytes as u64, Ordering::Relaxed);
}
}
iteration += 1;
}
}
/// Get the number of entries in the cache
pub fn len(&self) -> usize {
self.shards.iter().map(|s| s.len()).sum()
}
/// Check if the cache is empty
pub fn is_empty(&self) -> bool {
self.shards.iter().all(|s| s.is_empty())
}
/// Get cache statistics
pub fn stats(&self) -> L1CacheStats {
let hits = self.hits.load(Ordering::Relaxed);
let misses = self.misses.load(Ordering::Relaxed);
let total_requests = hits + misses;
L1CacheStats {
hits,
misses,
entries: self.len(),
memory_bytes: self.current_memory.load(Ordering::Relaxed) as usize,
hit_rate: if total_requests > 0 {
hits as f64 / total_requests as f64
} else {
0.0
},
}
}
/// Clear the cache
pub fn clear(&self) {
for shard in &self.shards {
shard.clear();
}
self.current_memory.store(0, Ordering::Relaxed);
self.hits.store(0, Ordering::Relaxed);
self.misses.store(0, Ordering::Relaxed);
}
}
#[derive(Debug, Clone)]
pub struct L1CacheStats {
pub hits: u64,
pub misses: u64,
pub entries: usize,
pub memory_bytes: usize,
pub hit_rate: f64,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_basic_prefix_match() {
let cache = L1Cache::new(1024 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();
// Realistic ChatML template with special tokens
let input1 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there! How are you doing today?<|im_end|>";
// Insert at special token boundaries (re-tokenizes prefixes)
cache
.insert_at_boundaries(input1, &tokenizer, special_tokens)
.unwrap();
// Should have cached at special token boundaries
assert!(!cache.is_empty());
// Search with same prefix but different user query
let input2 = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nWhat is 2+2?<|im_end|>";
let result = cache.longest_prefix_match(input2, special_tokens);
// Should find a match at the special token boundary (after system message)
assert!(result.is_some());
let (tokens, offset) = result.unwrap();
assert!(offset > 0);
assert!(!tokens.is_empty());
}
#[test]
fn test_short_input_with_boundaries() {
let cache = L1Cache::new(1024 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();
// Short input with special tokens
let input = "<|im_start|>user\nHi<|im_end|>";
cache
.insert_at_boundaries(input, &tokenizer, special_tokens)
.unwrap();
// Should cache at <|im_start|> boundary (has suffix left)
assert!(!cache.is_empty());
// Should find a match
let result = cache.longest_prefix_match(input, special_tokens);
assert!(result.is_some());
}
#[test]
fn test_longest_match() {
let cache = L1Cache::new(1024 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();
// Create multi-turn conversation with multiple special token boundaries (~400 bytes)
let input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|><|im_start|>assistant\nI'm doing well, thank you! I'd be happy to explain tokenization. Tokenization is the process of breaking text into smaller units called tokens.<|im_end|>";
cache
.insert_at_boundaries(input, &tokenizer, special_tokens)
.unwrap();
// Should have multiple entries at special token boundaries
assert!(cache.len() >= 2); // At least 2 boundaries
// Search with partial conversation - should match at a special token boundary
let partial_input = "<|im_start|>system\nYou are a helpful AI assistant that provides detailed and accurate responses.<|im_end|><|im_start|>user\nHello there! How are you today? Can you help me understand how tokenization works in language models?<|im_end|>";
let result = cache.longest_prefix_match(partial_input, special_tokens);
// Should find a match at a special token boundary
assert!(result.is_some());
let (_, offset) = result.unwrap();
assert!(offset > 0);
assert!(offset <= partial_input.len());
}
#[test]
fn test_stats() {
let cache = L1Cache::new(1024 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();
// ChatML input with special tokens
let input = "<|im_start|>system\nYou are a helpful assistant that provides detailed answers.<|im_end|><|im_start|>user\nHello there! How are you today?<|im_end|>";
cache
.insert_at_boundaries(input, &tokenizer, special_tokens)
.unwrap();
// Try to find match
let _ = cache.longest_prefix_match(input, special_tokens);
let stats = cache.stats();
// Should have at least one hit (the longest special token boundary should match)
assert!(stats.hits >= 1);
assert_eq!(stats.hit_rate, 1.0);
}
#[test]
fn test_clear() {
let cache = L1Cache::new(1024 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>"];
let tokenizer = MockTokenizer::new();
// ChatML input with special tokens
let input = "<|im_start|>system\nYou are a helpful assistant that provides clear and detailed responses.<|im_end|><|im_start|>user\nHello there!<|im_end|>";
cache
.insert_at_boundaries(input, &tokenizer, special_tokens)
.unwrap();
assert!(!cache.is_empty());
cache.clear();
assert!(cache.is_empty());
let stats = cache.stats();
assert_eq!(stats.hits, 0);
assert_eq!(stats.misses, 0);
}
#[test]
fn test_lru_eviction() {
// Create a small cache (5KB) to trigger eviction
let cache = L1Cache::new(5 * 1024);
let special_tokens = &["<|im_start|>", "<|im_end|>", "<|eot_id|>"];
let tokenizer = MockTokenizer::new();
// Insert first conversation
let input1 = "<|im_start|>system\nYou are a helpful assistant specialized in mathematics.<|im_end|><|im_start|>user\nCan you explain calculus to me?<|im_end|><|im_start|>assistant\nCertainly! Calculus is a branch of mathematics that studies continuous change.<|im_end|><|eot_id|>";
cache
.insert_at_boundaries(input1, &tokenizer, special_tokens)
.unwrap();
// Access the first entry to update its timestamp
let result = cache.longest_prefix_match(input1, special_tokens);
assert!(result.is_some());
// Insert second conversation
let input2 = "<|im_start|>system\nYou are a helpful assistant specialized in physics.<|im_end|><|im_start|>user\nWhat is quantum mechanics?<|im_end|><|im_start|>assistant\nQuantum mechanics is the fundamental theory describing nature at atomic and subatomic scales.<|im_end|><|eot_id|>";
cache
.insert_at_boundaries(input2, &tokenizer, special_tokens)
.unwrap();
// Access the second entry to make it more recent
let result = cache.longest_prefix_match(input2, special_tokens);
assert!(result.is_some());
// Insert third conversation (should trigger eviction of oldest)
let input3 = "<|im_start|>system\nYou are a helpful assistant specialized in chemistry.<|im_end|><|im_start|>user\nExplain the periodic table to me please.<|im_end|><|im_start|>assistant\nThe periodic table is a tabular arrangement of chemical elements organized by atomic number and electron configuration.<|im_end|><|eot_id|>";
cache
.insert_at_boundaries(input3, &tokenizer, special_tokens)
.unwrap();
// Verify cache didn't exceed max memory
let stats = cache.stats();
assert!(stats.memory_bytes <= 5 * 1024);
// The most recently accessed entries should still be present
let result = cache.longest_prefix_match(input3, special_tokens);
assert!(result.is_some());
}
}
//! Tokenizer Caching Layer
//!
//! Provides a caching wrapper around any tokenizer implementation to speed up
//! repeated tokenization of the same strings (e.g., system prompts).
//!
//! # Architecture
//! - **L0 Cache**: Whole-string exact match (90% of wins)
//! - **L1 Cache**: Prefix matching at special token boundaries
//!
//! # Usage
//! ```ignore
//! let tokenizer = Arc::new(HuggingFaceTokenizer::from_file("tokenizer.json")?);
//! let cached = Arc::new(CachedTokenizer::new(tokenizer, CacheConfig::default()));
//! let encoding = cached.encode("Hello world")?;
//! ```
mod fingerprint;
mod l0;
mod l1;
use std::sync::Arc;
use anyhow::Result;
pub use fingerprint::TokenizerFingerprint;
pub use l0::{CacheStats, L0Cache};
pub use l1::{L1Cache, L1CacheStats};
use rayon::prelude::*;
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, TokenIdType, Tokenizer};
/// Configuration for the tokenizer cache
#[derive(Debug, Clone)]
pub struct CacheConfig {
/// Enable L0 (whole-string) cache
pub enable_l0: bool,
/// Maximum number of entries in L0 cache
pub l0_max_entries: usize,
/// Enable L1 (prefix) cache
pub enable_l1: bool,
/// Maximum memory for L1 cache in bytes
pub l1_max_memory: usize,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
enable_l0: true,
l0_max_entries: 10_000, // ~22MB memory for typical prompts
enable_l1: false, // Opt-in for now
l1_max_memory: 50 * 1024 * 1024, // 50MB
}
}
}
/// A caching wrapper around any tokenizer
pub struct CachedTokenizer {
/// The underlying tokenizer
inner: Arc<dyn Tokenizer>,
/// L0 cache (whole-string exact match)
l0: Option<L0Cache>,
/// L1 cache (prefix matching at special token boundaries)
l1: Option<L1Cache>,
/// Configuration
#[allow(dead_code)]
config: CacheConfig,
/// Fingerprint for cache invalidation
fingerprint: TokenizerFingerprint,
/// Cached special token strings (extracted once at construction)
special_token_strings: Vec<String>,
}
impl CachedTokenizer {
/// Create a new cached tokenizer
pub fn new(inner: Arc<dyn Tokenizer>, config: CacheConfig) -> Self {
let fingerprint = TokenizerFingerprint::from_tokenizer(inner.as_ref());
let l0 = if config.enable_l0 {
Some(L0Cache::new(config.l0_max_entries))
} else {
None
};
let l1 = if config.enable_l1 {
Some(L1Cache::new(config.l1_max_memory))
} else {
None
};
// Extract special tokens once at construction time
let special_token_strings = Self::extract_special_token_strings(&inner);
Self {
inner,
l0,
l1,
config,
fingerprint,
special_token_strings,
}
}
/// Extract all special token strings from the tokenizer (called once at construction)
fn extract_special_token_strings(tokenizer: &Arc<dyn Tokenizer>) -> Vec<String> {
let special_tokens = tokenizer.get_special_tokens();
let mut tokens = Vec::new();
if let Some(ref token) = special_tokens.bos_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.eos_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.unk_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.sep_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.pad_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.cls_token {
tokens.push(token.clone());
}
if let Some(ref token) = special_tokens.mask_token {
tokens.push(token.clone());
}
tokens.extend(special_tokens.additional_special_tokens.iter().cloned());
tokens
}
/// Get L0 cache statistics
pub fn cache_stats(&self) -> Option<CacheStats> {
self.l0.as_ref().map(|cache| cache.stats())
}
/// Get L1 cache statistics
pub fn l1_cache_stats(&self) -> Option<L1CacheStats> {
self.l1.as_ref().map(|cache| cache.stats())
}
/// Clear the cache
pub fn clear_cache(&self) {
if let Some(l0) = &self.l0 {
l0.clear();
}
if let Some(l1) = &self.l1 {
l1.clear();
}
}
/// Get the fingerprint of the underlying tokenizer
pub fn fingerprint(&self) -> &TokenizerFingerprint {
&self.fingerprint
}
/// Get a reference to the inner (wrapped) tokenizer
pub fn inner(&self) -> &Arc<dyn Tokenizer> {
&self.inner
}
}
impl Encoder for CachedTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// Collect special tokens once if L1 is enabled (avoid redundant allocation)
let special_tokens: Option<Vec<&str>> = self.l1.as_ref().map(|_| {
self.special_token_strings
.iter()
.map(|s| s.as_str())
.collect()
});
// L0 cache lookup (exact match)
if let Some(l0) = &self.l0 {
if let Some(cached) = l0.get(input) {
return Ok(cached);
}
}
// L1 cache lookup (prefix match at special token boundaries)
if let Some(l1) = &self.l1 {
let tokens = special_tokens.as_ref().unwrap();
if let Some((prefix_tokens, prefix_len)) = l1.longest_prefix_match(input, tokens) {
// We have a prefix match - tokenize the suffix
let suffix = &input[prefix_len..];
if !suffix.is_empty() {
let suffix_encoding = self.inner.encode(suffix)?;
// Merge prefix tokens + suffix tokens
// Safe because we're splitting at special token boundaries
let mut merged_tokens = prefix_tokens;
merged_tokens.extend_from_slice(suffix_encoding.token_ids());
let merged_encoding = Encoding::Sp(merged_tokens);
// Cache the full result in L0
if let Some(l0) = &self.l0 {
l0.insert(input.to_string(), merged_encoding.clone());
}
return Ok(merged_encoding);
}
}
}
// Full tokenization (both L0 and L1 miss)
let encoding = self.inner.encode(input)?;
// Cache in L0
if let Some(l0) = &self.l0 {
l0.insert(input.to_string(), encoding.clone());
}
// Cache in L1 at special token boundaries
// Re-tokenizes prefixes for correctness (optimized for high prefix reuse)
if let Some(l1) = &self.l1 {
let tokens = special_tokens.as_ref().unwrap();
let _ = l1.insert_at_boundaries(input, self.inner.as_ref(), tokens);
// Ignore errors in cache insertion - cache is best-effort
}
Ok(encoding)
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
// Process each input in parallel, leveraging thread-safe caches
// This maintains the parallelism from the underlying HuggingFaceTokenizer
inputs.par_iter().map(|&input| self.encode(input)).collect()
}
}
impl Decoder for CachedTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> {
// Decoding is not cached (it's fast enough and rarely repeated)
self.inner.decode(token_ids, skip_special_tokens)
}
}
impl Tokenizer for CachedTokenizer {
fn vocab_size(&self) -> usize {
self.inner.vocab_size()
}
fn get_special_tokens(&self) -> &SpecialTokens {
self.inner.get_special_tokens()
}
fn token_to_id(&self, token: &str) -> Option<TokenIdType> {
self.inner.token_to_id(token)
}
fn id_to_token(&self, id: TokenIdType) -> Option<String> {
self.inner.id_to_token(id)
}
fn as_any(&self) -> &dyn std::any::Any {
self
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_cache_hit() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let input = "Hello world";
// First call - miss
let result1 = cached.encode(input).unwrap();
// Second call - hit
let result2 = cached.encode(input).unwrap();
// Results should be identical
assert_eq!(result1.token_ids(), result2.token_ids());
// Check cache stats
let stats = cached.cache_stats().unwrap();
assert_eq!(stats.hits, 1);
assert_eq!(stats.misses, 1);
}
#[test]
fn test_cache_disabled() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = CacheConfig {
enable_l0: false,
l0_max_entries: 0,
enable_l1: false,
l1_max_memory: 0,
};
let cached = CachedTokenizer::new(tokenizer, config);
let input = "Hello world";
// Both calls should work even without cache
let result1 = cached.encode(input).unwrap();
let result2 = cached.encode(input).unwrap();
assert_eq!(result1.token_ids(), result2.token_ids());
// No cache stats available
assert!(cached.cache_stats().is_none());
}
#[test]
fn test_encode_batch() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let inputs = vec!["Hello", "world", "Hello"]; // "Hello" repeated
let results = cached.encode_batch(&inputs).unwrap();
assert_eq!(results.len(), 3);
// With parallel execution, duplicate inputs may be processed simultaneously
// and both see cache misses. Verify results are correct instead.
assert_eq!(results[0].token_ids(), results[2].token_ids()); // Both "Hello" should match
// After batch processing, cache should be populated
// Subsequent calls should hit the cache
let _ = cached.encode("Hello").unwrap();
let stats = cached.cache_stats().unwrap();
// Should have at least 1 hit from the call above (cache was populated by batch)
assert!(
stats.hits >= 1,
"Expected at least 1 cache hit after batch processing"
);
}
#[test]
fn test_decoder_passthrough() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer, CacheConfig::default());
let tokens = vec![1, 2, 3];
let decoded = cached.decode(&tokens, false).unwrap();
// Should just pass through to inner tokenizer
assert!(!decoded.is_empty());
}
#[test]
fn test_tokenizer_trait_methods() {
let tokenizer = Arc::new(MockTokenizer::new());
let cached = CachedTokenizer::new(tokenizer.clone(), CacheConfig::default());
// Should pass through to inner tokenizer
assert_eq!(cached.vocab_size(), tokenizer.vocab_size());
assert!(cached.token_to_id("Hello").is_some());
assert!(cached.id_to_token(1).is_some());
}
}
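Putting the pieces together outside the server startup path might look like this sketch; `HuggingFaceTokenizer::from_file` is the constructor referenced in the module docs above, and the path is illustrative:

```rust
use std::sync::Arc;

// Sketch: build a tokenizer with L0 on (library default) and L1 opted in.
fn build_cached(path: &str) -> anyhow::Result<Arc<dyn Tokenizer>> {
    let base: Arc<dyn Tokenizer> = Arc::new(HuggingFaceTokenizer::from_file(path)?);
    let cfg = CacheConfig {
        enable_l1: true,
        ..CacheConfig::default()
    };
    Ok(Arc::new(CachedTokenizer::new(base, cfg)))
}
```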
......@@ -407,7 +407,7 @@ mod tests {
#[test]
fn test_mock_tokenizer_creation() {
let tokenizer = create_tokenizer_from_file("mock").unwrap();
assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
assert_eq!(tokenizer.vocab_size(), 14); // Mock tokenizer has 14 tokens
}
#[test]
......
......@@ -44,8 +44,8 @@ impl HuggingFaceTokenizer {
// Extract special tokens
let special_tokens = Self::extract_special_tokens(&tokenizer);
// Build vocab mappings
let vocab = tokenizer.get_vocab(false);
// Build vocab mappings (include special tokens to get added_tokens like <|im_start|>)
let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
......@@ -80,7 +80,7 @@ impl HuggingFaceTokenizer {
/// Create from an existing HuggingFace tokenizer
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
let special_tokens = Self::extract_special_tokens(&tokenizer);
let vocab = tokenizer.get_vocab(false);
let vocab = tokenizer.get_vocab(true); // true = include special tokens and added_tokens
let reverse_vocab: HashMap<TokenIdType, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
......@@ -98,8 +98,7 @@ impl HuggingFaceTokenizer {
/// Extract special tokens from the tokenizer
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
// Try to get special tokens from the tokenizer
// This is a simplified version - actual implementation would need to handle various formats
// Get vocab with special tokens included (added_tokens like <|im_start|>)
let vocab = tokenizer.get_vocab(true);
let find_token = |patterns: &[&str]| -> Option<String> {
......@@ -111,6 +110,14 @@ impl HuggingFaceTokenizer {
None
};
// Extract additional special tokens using the tokenizers library API
let additional_special_tokens: Vec<String> = tokenizer
.get_added_tokens_decoder()
.iter()
.filter(|(_id, token)| token.special) // Only tokens marked as special: true
.map(|(_id, token)| token.content.clone())
.collect();
SpecialTokens {
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
......@@ -119,7 +126,7 @@ impl HuggingFaceTokenizer {
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
additional_special_tokens: vec![],
additional_special_tokens,
}
}
......
......@@ -34,6 +34,12 @@ impl MockTokenizer {
(".", 6),
("<eos>", 999),
("<bos>", 1000),
("<|im_start|>", 1001),
("<|im_end|>", 1002),
("<|eot_id|>", 1003),
("system", 7),
("user", 8),
("assistant", 9),
];
for (token, id) in tokens {
......@@ -62,7 +68,8 @@ impl MockTokenizer {
impl Encoder for MockTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// Simple word-based tokenization for testing
// Simple word-based tokenization using the vocab
// Split by whitespace and look up each word (decoder adds spaces back)
let tokens: Vec<u32> = input
.split_whitespace()
.filter_map(|word| self.vocab.get(word).copied())
......
......@@ -2,6 +2,7 @@ use std::{ops::Deref, sync::Arc};
use anyhow::Result;
pub mod cache;
pub mod factory;
pub mod hub;
pub mod mock;
......@@ -22,6 +23,7 @@ pub mod tiktoken;
mod tests;
// Re-exports
pub use cache::{CacheConfig, CacheStats, CachedTokenizer, TokenizerFingerprint};
pub use factory::{
create_tokenizer, create_tokenizer_async, create_tokenizer_async_with_chat_template,
create_tokenizer_from_file, create_tokenizer_with_chat_template,
......
......@@ -43,7 +43,7 @@ fn test_tokenizer_wrapper() {
let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world");
assert_eq!(tokenizer.vocab_size(), 8);
assert_eq!(tokenizer.vocab_size(), 14);
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
assert_eq!(tokenizer.token_to_id("unknown"), None);
......
......@@ -69,6 +69,7 @@ impl TestContext {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
Self::new_with_config(config, worker_configs).await
......@@ -1406,6 +1407,7 @@ mod error_tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = TestContext::new_with_config(
......@@ -1735,6 +1737,7 @@ mod pd_mode_tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
// Create app context
......@@ -1898,6 +1901,7 @@ mod request_id_tests {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = TestContext::new_with_config(
......
......@@ -84,6 +84,7 @@ async fn test_non_streaming_mcp_minimal_e2e_with_persistence() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
// Create router and context
......@@ -284,6 +285,7 @@ async fn test_conversations_crud_basic() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -619,6 +621,7 @@ async fn test_multi_turn_loop_with_mcp() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -795,6 +798,7 @@ async fn test_max_tool_calls_limit() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -937,6 +941,7 @@ async fn setup_streaming_mcp_test() -> (
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -1378,6 +1383,7 @@ async fn test_conversation_items_create_and_get() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -1479,6 +1485,7 @@ async fn test_conversation_items_delete() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -1586,6 +1593,7 @@ async fn test_conversation_items_max_limit() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -1663,6 +1671,7 @@ async fn test_conversation_items_unsupported_type() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......@@ -1739,6 +1748,7 @@ async fn test_conversation_items_multi_conversation_sharing() {
oracle: None,
reasoning_parser: None,
tool_call_parser: None,
tokenizer_cache: sglang_router_rs::config::TokenizerCacheConfig::default(),
};
let ctx = common::create_test_context(router_cfg);
......