use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};

/// Tiktoken tokenizer wrapper for OpenAI GPT models
pub struct TiktokenTokenizer {
    tokenizer: CoreBPE,
    #[allow(dead_code)]
    model: TiktokenModel,
    special_tokens: SpecialTokens,
    vocab_size: usize,
}

/// Supported Tiktoken models
#[derive(Debug, Clone, Copy)]
pub enum TiktokenModel {
    /// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
    Cl100kBase,
    /// Codex models, text-davinci-002, text-davinci-003
    P50kBase,
    /// Edit models such as text-davinci-edit-001 and code-davinci-edit-001
    P50kEdit,
    /// GPT-3 models such as davinci
    R50kBase,
}

impl TiktokenTokenizer {
    /// Create a new Tiktoken tokenizer for the specified model
    pub fn new(model: TiktokenModel) -> Result<Self> {
        let tokenizer = match model {
            TiktokenModel::Cl100kBase => cl100k_base()
                .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
            TiktokenModel::P50kBase => p50k_base()
                .map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
            TiktokenModel::P50kEdit => p50k_edit()
                .map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
            TiktokenModel::R50kBase => r50k_base()
                .map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
        };

        // Extract special tokens (tiktoken-rs doesn't expose them directly).
        // We use common ones for GPT models.
        let special_tokens = Self::get_special_tokens_for_model(model);

        // Get the vocabulary size (an approximation)
        let vocab_size = match model {
            TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
            TiktokenModel::R50kBase => 50257, // r50k has ~50k tokens
        };

        Ok(TiktokenTokenizer {
            tokenizer,
            model,
            special_tokens,
            vocab_size,
        })
    }

    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
    pub fn from_model_name(model_name: &str) -> Result<Self> {
        let model = Self::model_from_name(model_name)?;
        Self::new(model)
    }

    /// Determine the appropriate model from a model name
    fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
        // Based on OpenAI's model-to-encoding mapping
        if model_name.contains("gpt-4")
            || model_name.contains("gpt-3.5")
            || model_name.contains("turbo")
        {
            Ok(TiktokenModel::Cl100kBase)
        } else if model_name.contains("davinci-002")
            || model_name.contains("davinci-003")
            || model_name.contains("codex")
        {
            Ok(TiktokenModel::P50kBase)
        } else if model_name.contains("edit") {
            Ok(TiktokenModel::P50kEdit)
        } else if model_name.contains("davinci")
            || model_name.contains("curie")
            || model_name.contains("babbage")
            || model_name.contains("ada")
        {
            Ok(TiktokenModel::R50kBase)
        } else {
            // Return an error for unrecognized model names to prevent silent failures
            Err(anyhow::anyhow!(
                "Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
                model_name
            ))
        }
    }

    /// Get special tokens for a specific model
    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
        // These are common special tokens for GPT models.
        // The actual token IDs might vary by model.
        match model {
            TiktokenModel::Cl100kBase => SpecialTokens {
                bos_token: Some("<|endoftext|>".to_string()),
                eos_token: Some("<|endoftext|>".to_string()),
                unk_token: None,
                sep_token: None,
                pad_token: Some("<|endoftext|>".to_string()),
                cls_token: None,
                mask_token: None,
                additional_special_tokens: vec![
                    "<|fim_prefix|>".to_string(),
                    "<|fim_middle|>".to_string(),
                    "<|fim_suffix|>".to_string(),
                    "<|endofprompt|>".to_string(),
                ],
            },
            _ => SpecialTokens {
                bos_token: Some("<|endoftext|>".to_string()),
                eos_token: Some("<|endoftext|>".to_string()),
                unk_token: None,
                sep_token: None,
                pad_token: Some("<|endoftext|>".to_string()),
                cls_token: None,
                mask_token: None,
                additional_special_tokens: vec![],
            },
        }
    }
}
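// A minimal usage sketch (an illustration, not part of the original module;
// `count_tokens` is a hypothetical helper). It assumes `Encoding::token_ids()`
// returns the encoded IDs, as the tests below rely on.
#[allow(dead_code)]
fn count_tokens(model_name: &str, text: &str) -> Result<usize> {
    let tokenizer = TiktokenTokenizer::from_model_name(model_name)?;
    let encoding = tokenizer.encode(text)?;
    Ok(encoding.token_ids().len())
}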
impl Encoder for TiktokenTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let tokens = self.tokenizer.encode_ordinary(input);
        Ok(Encoding::Tiktoken(tokens))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for TiktokenTokenizer {
    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
        // Convert u32 to usize for tiktoken-rs
        let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
        self.tokenizer
            .decode(tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for TiktokenTokenizer {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, _token: &str) -> Option<u32> {
        // Tiktoken doesn't provide a direct token-to-id mapping.
        // We'd need to encode the token and check that it produces a single ID.
        None
    }

    fn id_to_token(&self, _id: u32) -> Option<String> {
        // Tiktoken doesn't provide a direct id-to-token mapping.
        // We can only decode IDs to text.
        None
    }
}
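// A hedged sketch of the approach mentioned in `token_to_id` above (an
// assumption, not part of the original API): encode the candidate string with
// `encode_ordinary` and accept the result only when it yields exactly one ID.
#[allow(dead_code)]
fn lookup_single_token_id(tokenizer: &TiktokenTokenizer, token: &str) -> Option<u32> {
    let ids = tokenizer.tokenizer.encode_ordinary(token);
    if ids.len() == 1 {
        Some(ids[0] as u32)
    } else {
        // Strings that split into multiple tokens have no single ID
        None
    }
}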
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("davinci").unwrap(),
            TiktokenModel::R50kBase
        ));
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let text = "Hello, world!";
        let encoding = tokenizer.encode(text).unwrap();
        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts).unwrap();
        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();
        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        // Test that unrecognized model names return an error
        let result = TiktokenTokenizer::from_model_name("distilgpt-2");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("llama-7b");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }
    }

    #[test]
    fn test_recognized_model_names() {
        // Test that recognized model names work correctly
        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
    }
}
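// A supplementary test sketch (an assumption, not from the original suite):
// because `encode` uses `encode_ordinary`, special-token text such as
// "<|endoftext|>" should be tokenized as plain text rather than as a single
// special token, so decoding still round-trips the literal string.
#[cfg(test)]
mod special_token_handling_sketch {
    use super::*;

    #[test]
    fn test_special_text_treated_as_plain_text() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let text = "prefix <|endoftext|> suffix";
        let encoding = tokenizer.encode(text).unwrap();
        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }
}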