"src/vscode:/vscode.git/clone" did not exist on "f73d0b6becf8f6b985c237bdfc94e31f2d96e956"
Unverified commit ff0cf51c authored by Simo Lin, committed by GitHub

[router] introducing tokenizer trait (#9287)

parent a1c7f742
@@ -43,6 +43,8 @@ uuid = { version = "1.10", features = ["v4", "serde"] }
thiserror = "2.0.12"
url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
tokenizers = "0.21.4"
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
@@ -10,6 +10,7 @@ pub mod policies;
pub mod routers;
pub mod server;
pub mod service_discovery;
pub mod tokenizer;
pub mod tree;
use crate::metrics::PrometheusConfig;
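// src/tokenizer/mock.rs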
//! Mock tokenizer implementation for testing
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::Result;
use std::collections::HashMap;
/// Mock tokenizer for testing purposes
pub struct MockTokenizer {
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
special_tokens: SpecialTokens,
}
impl Default for MockTokenizer {
fn default() -> Self {
Self::new()
}
}
impl MockTokenizer {
pub fn new() -> Self {
let mut vocab = HashMap::new();
let mut reverse_vocab = HashMap::new();
// Add some basic tokens
let tokens = vec![
("Hello", 1),
("world", 2),
("test", 3),
("token", 4),
(" ", 5),
(".", 6),
("<eos>", 999),
("<bos>", 1000),
];
for (token, id) in tokens {
vocab.insert(token.to_string(), id);
reverse_vocab.insert(id, token.to_string());
}
let special_tokens = SpecialTokens {
bos_token: Some("<bos>".to_string()),
eos_token: Some("<eos>".to_string()),
unk_token: Some("<unk>".to_string()),
sep_token: None,
pad_token: None,
cls_token: None,
mask_token: None,
additional_special_tokens: vec![],
};
Self {
vocab,
reverse_vocab,
special_tokens,
}
}
}
impl Encoder for MockTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
// Simple word-based tokenization for testing
let tokens: Vec<u32> = input
.split_whitespace()
.filter_map(|word| self.vocab.get(word).copied())
.collect();
Ok(Encoding::Sp(tokens))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
inputs.iter().map(|input| self.encode(input)).collect()
}
}
impl Decoder for MockTokenizer {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
let tokens: Vec<String> = token_ids
.iter()
.filter_map(|id| {
self.reverse_vocab.get(id).and_then(|token| {
if skip_special_tokens && (token == "<eos>" || token == "<bos>") {
None
} else {
Some(token.clone())
}
})
})
.collect();
Ok(tokens.join(" "))
}
}
impl TokenizerTrait for MockTokenizer {
fn vocab_size(&self) -> usize {
self.vocab.len()
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.reverse_vocab.get(&id).cloned()
}
}
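// src/tokenizer/mod.rs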
use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod mock;
pub mod stream;
pub mod traits;
#[cfg(test)]
mod tests;
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
impl Tokenizer {
/// Create a tokenizer from a file path
/// Will be implemented in Phase 3 with factory pattern
pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
// TODO: Implement factory pattern in Phase 3
unimplemented!("Factory pattern will be implemented in Phase 3")
}
/// Create a tokenizer from an Arc<dyn Tokenizer>
pub fn from_arc(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
}
/// Create a stateful sequence object for decoding token_ids into text
pub fn decode_stream(
&self,
prompt_token_ids: &[u32],
skip_special_tokens: bool,
) -> DecodeStream {
DecodeStream::new(self.0.clone(), prompt_token_ids, skip_special_tokens)
}
/// Direct encode method
pub fn encode(&self, input: &str) -> Result<Encoding> {
self.0.encode(input)
}
/// Direct batch encode method
pub fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
self.0.encode_batch(inputs)
}
/// Direct decode method
pub fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
self.0.decode(token_ids, skip_special_tokens)
}
/// Get vocabulary size
pub fn vocab_size(&self) -> usize {
self.0.vocab_size()
}
/// Get special tokens
pub fn get_special_tokens(&self) -> &SpecialTokens {
self.0.get_special_tokens()
}
/// Convert token string to ID
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.0.token_to_id(token)
}
/// Convert ID to token string
pub fn id_to_token(&self, id: u32) -> Option<String> {
self.0.id_to_token(id)
}
}
impl Deref for Tokenizer {
type Target = Arc<dyn traits::Tokenizer>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl From<Arc<dyn traits::Tokenizer>> for Tokenizer {
fn from(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
Tokenizer(tokenizer)
}
}
// src/tokenizer/stream.rs
use super::traits;
use anyhow::Result;
use std::sync::Arc;
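/// Number of trailing prompt tokens kept in the initial decode window so the
/// first emitted chunk is compared against a stable text prefix.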
const INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET: usize = 5;
/// DecodeStream keeps the state necessary to produce individual string chunks
/// from an incoming stream of token_ids
pub struct DecodeStream {
/// The tokenizer used to decode token_ids
tokenizer: Arc<dyn traits::Tokenizer>,
skip_special_tokens: bool,
/// Buffer of all token_ids seen so far; overlapping slices of it are
/// re-decoded to produce valid string chunks
all_token_ids: Vec<u32>,
prefix_offset: usize,
read_offset: usize,
}
impl DecodeStream {
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
prompt_token_ids: &[u32],
skip_special_tokens: bool,
) -> Self {
let num_input_tokens = prompt_token_ids.len();
let prompt_token_ids = prompt_token_ids.to_vec();
Self {
tokenizer,
skip_special_tokens,
all_token_ids: prompt_token_ids,
prefix_offset: num_input_tokens
.saturating_sub(INITIAL_INCREMENTAL_DETOKENIZATION_OFFSET),
read_offset: num_input_tokens,
}
}
/// Step appends a token_id to the internal state and tries to produce a text chunk.
/// Returning `None` means the given id is not enough to produce a chunk.
pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.all_token_ids.push(id);
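// Decode two overlapping windows: `prefix_text` covers the ids whose text
// has already been emitted, while `new_text` below also includes everything
// added since then (including the id just pushed).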
let prefix_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?;
let new_text = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..],
self.skip_special_tokens,
)?;
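// Emit only the delta, and only when the longer decode does not end in the
// UTF-8 replacement character; a trailing "�" means the last token stops
// mid multi-byte sequence and more ids are needed.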
if new_text.len() > prefix_text.len() && !new_text.ends_with("�") {
let new_text = new_text[prefix_text.len()..].to_string();
self.prefix_offset = self.read_offset;
self.read_offset = self.all_token_ids.len();
Ok(Some(new_text))
} else {
Ok(None)
}
}
/// Process multiple tokens at once
pub fn step_batch(&mut self, token_ids: &[u32]) -> Result<Vec<String>> {
let mut chunks = Vec::new();
for &token_id in token_ids {
if let Some(text) = self.step(token_id)? {
chunks.push(text);
}
}
Ok(chunks)
}
/// Force flush any remaining text
pub fn flush(&mut self) -> Result<Option<String>> {
if self.read_offset < self.all_token_ids.len() {
let remaining = self.tokenizer.decode(
&self.all_token_ids[self.read_offset..],
self.skip_special_tokens,
)?;
self.read_offset = self.all_token_ids.len();
if !remaining.is_empty() {
return Ok(Some(remaining));
}
}
Ok(None)
}
/// Get all tokens processed so far
pub fn tokens(&self) -> &[u32] {
&self.all_token_ids
}
}
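// Illustrative sketch (not part of this commit): how a generation loop might
// drive a DecodeStream. The MockTokenizer and the hard-coded ids below are
// stand-ins for a real tokenizer backend and a model's sampled tokens.
//
// use std::sync::Arc;
// use crate::tokenizer::{mock::MockTokenizer, Tokenizer};
//
// fn stream_sketch() -> anyhow::Result<()> {
//     let tokenizer = Tokenizer::from_arc(Arc::new(MockTokenizer::new()));
//     // Prompt ids seed the stream so the first chunk has decoding context.
//     let mut stream = tokenizer.decode_stream(&[1, 2], false);
//     for &id in &[3, 4, 6] {
//         // `step` returns Some(chunk) only once the accumulated ids decode
//         // to a longer, valid UTF-8 string than what was already emitted.
//         if let Some(chunk) = stream.step(id)? {
//             print!("{chunk}");
//         }
//     }
//     // Emit anything still buffered once generation ends.
//     if let Some(rest) = stream.flush()? {
//         print!("{rest}");
//     }
//     Ok(())
// }

// src/tokenizer/tests.rs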
#[cfg(test)]
use super::*;
#[cfg(test)]
use std::sync::Arc;
#[test]
fn test_mock_tokenizer_encode() {
let tokenizer = mock::MockTokenizer::new();
let encoding = tokenizer.encode("Hello world").unwrap();
let token_ids = encoding.token_ids();
assert_eq!(token_ids, &[1, 2]); // "Hello" -> 1, "world" -> 2
}
#[test]
fn test_mock_tokenizer_decode() {
let tokenizer = mock::MockTokenizer::new();
let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world");
}
#[test]
fn test_mock_tokenizer_decode_skip_special() {
let tokenizer = mock::MockTokenizer::new();
// With special tokens
let text = tokenizer.decode(&[1000, 1, 2, 999], false).unwrap();
assert_eq!(text, "<bos> Hello world <eos>");
// Without special tokens
let text = tokenizer.decode(&[1000, 1, 2, 999], true).unwrap();
assert_eq!(text, "Hello world");
}
#[test]
fn test_tokenizer_wrapper() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Test encoding
let encoding = tokenizer.encode("Hello world").unwrap();
assert_eq!(encoding.token_ids(), &[1, 2]);
// Test decoding
let text = tokenizer.decode(&[1, 2], false).unwrap();
assert_eq!(text, "Hello world");
// Test vocab size
assert_eq!(tokenizer.vocab_size(), 8);
// Test token to ID
assert_eq!(tokenizer.token_to_id("Hello"), Some(1));
assert_eq!(tokenizer.token_to_id("unknown"), None);
// Test ID to token
assert_eq!(tokenizer.id_to_token(1), Some("Hello".to_string()));
assert_eq!(tokenizer.id_to_token(9999), None);
}
#[test]
fn test_decode_stream_basic() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Create a decode stream with initial tokens
let initial_tokens = vec![1, 2]; // "Hello world"
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
// Add a new token
let result = stream.step(3).unwrap(); // "test"
// Since we're using a mock, the actual incremental behavior depends on implementation
// For now, we just verify it doesn't crash
assert!(result.is_some() || result.is_none());
}
#[test]
fn test_decode_stream_flush() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
let initial_tokens = vec![1];
let mut stream = tokenizer.decode_stream(&initial_tokens, false);
// Add tokens
stream.step(2).unwrap();
stream.step(3).unwrap();
// Flush remaining
let flushed = stream.flush().unwrap();
// The flush behavior depends on the implementation
assert!(flushed.is_some() || flushed.is_none());
}
#[test]
fn test_special_tokens() {
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
let special_tokens = tokenizer.get_special_tokens();
assert_eq!(special_tokens.bos_token, Some("<bos>".to_string()));
assert_eq!(special_tokens.eos_token, Some("<eos>".to_string()));
assert_eq!(special_tokens.unk_token, Some("<unk>".to_string()));
assert!(special_tokens.sep_token.is_none());
assert!(special_tokens.pad_token.is_none());
}
#[test]
fn test_batch_encode() {
let tokenizer = mock::MockTokenizer::new();
let inputs = vec!["Hello", "world", "test"];
let encodings = tokenizer.encode_batch(&inputs).unwrap();
assert_eq!(encodings.len(), 3);
assert_eq!(encodings[0].token_ids(), &[1]); // "Hello" -> 1
assert_eq!(encodings[1].token_ids(), &[2]); // "world" -> 2
assert_eq!(encodings[2].token_ids(), &[3]); // "test" -> 3
}
#[test]
fn test_thread_safety() {
use std::thread;
let mock_tokenizer = Arc::new(mock::MockTokenizer::new());
let tokenizer = Tokenizer::from_arc(mock_tokenizer);
// Spawn multiple threads that use the same tokenizer
let handles: Vec<_> = (0..10)
.map(|i| {
let tokenizer_clone = tokenizer.clone();
thread::spawn(move || {
let text = "Hello test".to_string();
let encoding = tokenizer_clone.encode(&text).unwrap();
let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
assert!(decoded.contains("Hello") || decoded.contains("test"));
i
})
})
.collect();
// Wait for all threads to complete
for handle in handles {
handle.join().unwrap();
}
}
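// src/tokenizer/traits.rs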
use anyhow::Result;
/// Core encoding trait - separate from decoding for modularity
pub trait Encoder: Send + Sync {
fn encode(&self, input: &str) -> Result<Encoding>;
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
}
/// Core decoding trait - can be implemented independently
pub trait Decoder: Send + Sync {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String>;
}
/// Combined tokenizer trait
pub trait Tokenizer: Encoder + Decoder {
fn vocab_size(&self) -> usize;
fn get_special_tokens(&self) -> &SpecialTokens;
fn token_to_id(&self, token: &str) -> Option<u32>;
fn id_to_token(&self, id: u32) -> Option<String>;
}
/// The result of tokenizing text. The Hugging Face variant carries token IDs,
/// string tokens, and their spans; the SentencePiece variant carries raw token IDs.
#[derive(Debug, Clone)]
pub enum Encoding {
/// Hugging Face
Hf(Box<tokenizers::tokenizer::Encoding>),
/// SentencePiece
Sp(Vec<u32>),
}
impl Encoding {
pub fn token_ids(&self) -> &[u32] {
match self {
Encoding::Hf(inner) => inner.get_ids(),
Encoding::Sp(inner) => inner,
}
}
}
#[derive(Debug, Clone)]
pub struct SpecialTokens {
pub bos_token: Option<String>,
pub eos_token: Option<String>,
pub unk_token: Option<String>,
pub sep_token: Option<String>,
pub pad_token: Option<String>,
pub cls_token: Option<String>,
pub mask_token: Option<String>,
pub additional_special_tokens: Vec<String>,
}
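// Illustrative sketch (not part of this commit): because Encoder and Decoder
// are separate traits, a consumer that only turns ids back into text can
// accept a bare `&dyn Decoder`. `PassthroughDecoder` and `render` are made-up
// names used purely for illustration.

struct PassthroughDecoder;

impl Decoder for PassthroughDecoder {
    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
        // Render ids as space-separated numbers instead of real text.
        Ok(token_ids
            .iter()
            .map(|id| id.to_string())
            .collect::<Vec<_>>()
            .join(" "))
    }
}

// A caller that only needs decoding can depend on `Decoder` alone rather
// than on the full `Tokenizer` trait.
fn render(decoder: &dyn Decoder, ids: &[u32]) -> Result<String> {
    decoder.decode(ids, true)
}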