Unverified Commit d08663ee authored by Simo Lin, committed by GitHub

[router] tokenizer factory, hf tokenizer, and stop sequence detector (#9293)


Co-authored-by: Chang Su <chang.s.su@oracle.com>
parent 716e6827
......@@ -3,6 +3,10 @@ name = "sglang_router_rs"
version = "0.0.0"
edition = "2021"
[features]
default = ["huggingface"]
huggingface = ["tokenizers"]
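# The optional `huggingface` feature (enabled by default) pulls in the `tokenizers` crate;
# a build without it can be produced with `cargo build --no-default-features`.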
[lib]
name = "sglang_router_rs"
# Pure Rust library: Just omit crate-type (defaults to rlib)
......@@ -44,7 +48,7 @@ thiserror = "2.0.12"
url = "2.5.4"
tokio-stream = { version = "0.1", features = ["sync"] }
anyhow = "1.0"
tokenizers = "0.21.4"
tokenizers = { version = "0.21.4", optional = true }
[dev-dependencies]
criterion = { version = "0.5", features = ["html_reports"] }
......
use super::traits;
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;
#[cfg(feature = "huggingface")]
use super::huggingface::HuggingFaceTokenizer;
/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
HuggingFace(String),
Mock,
// Future: SentencePiece, GGUF, Tiktoken
}
/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types:
/// - `json`: HuggingFace tokenizer
/// - the special paths `"mock"` and `"test"` return a mock tokenizer for testing
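///
/// # Example
///
/// Minimal sketch; `"mock"` is the special test path handled below, while real
/// callers would pass a path to a HuggingFace `tokenizer.json`.
/// ```ignore
/// let tokenizer = create_tokenizer_from_file("mock")?;
/// assert_eq!(tokenizer.vocab_size(), 8); // mock tokenizer has 8 tokens
/// ```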
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
// Special case for testing
if file_path == "mock" || file_path == "test" {
return Ok(Arc::new(super::mock::MockTokenizer::new()));
}
let path = Path::new(file_path);
// Check if file exists
if !path.exists() {
return Err(Error::msg(format!("File not found: {}", file_path)));
}
// Try to determine tokenizer type from extension
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
match extension.as_deref() {
Some("json") => {
#[cfg(feature = "huggingface")]
{
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
Ok(Arc::new(tokenizer))
}
#[cfg(not(feature = "huggingface"))]
{
Err(Error::msg(
"HuggingFace support not enabled. Enable the 'huggingface' feature.",
))
}
}
Some("model") => {
// SentencePiece model file
Err(Error::msg("SentencePiece models not yet supported"))
}
Some("gguf") => {
// GGUF format
Err(Error::msg("GGUF format not yet supported"))
}
_ => {
// Try to auto-detect by reading file content
auto_detect_tokenizer(file_path)
}
}
}
/// Auto-detect tokenizer type by examining file content
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
// Check for JSON (HuggingFace format)
if is_likely_json(&buffer) {
#[cfg(feature = "huggingface")]
{
let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
return Ok(Arc::new(tokenizer));
}
#[cfg(not(feature = "huggingface"))]
{
return Err(Error::msg(
"File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
));
}
}
// Check for GGUF magic number
if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
return Err(Error::msg("GGUF format detected but not yet supported"));
}
// Check for SentencePiece model
if is_likely_sentencepiece(&buffer) {
return Err(Error::msg(
"SentencePiece model detected but not yet supported",
));
}
Err(Error::msg(format!(
"Unable to determine tokenizer type for file: {}",
file_path
)))
}
/// Check if the buffer likely contains JSON data
fn is_likely_json(buffer: &[u8]) -> bool {
// Skip UTF-8 BOM if present
let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
&buffer[3..]
} else {
buffer
};
// Find first non-whitespace character without allocation
if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
*first_byte == b'{' || *first_byte == b'['
} else {
false
}
}
/// Check if the buffer likely contains a SentencePiece model
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
// SentencePiece models often start with specific patterns
// This is a simplified check
buffer.len() >= 12
&& (buffer.starts_with(b"\x0a\x09")
|| buffer.starts_with(b"\x08\x00")
|| buffer.windows(4).any(|w| w == b"<unk")
|| buffer.windows(4).any(|w| w == b"<s>")
|| buffer.windows(4).any(|w| w == b"</s>"))
}
/// Factory function to create tokenizer from a model name or path
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
// Check if it's a file path
let path = Path::new(model_name_or_path);
if path.exists() {
return create_tokenizer_from_file(model_name_or_path);
}
// Otherwise, try to load from HuggingFace Hub
#[cfg(feature = "huggingface")]
{
// This would download from HF Hub - not implemented yet
Err(Error::msg(
"Loading from HuggingFace Hub not yet implemented",
))
}
#[cfg(not(feature = "huggingface"))]
{
Err(Error::msg(format!(
"Model '{}' not found locally and HuggingFace support is not enabled",
model_name_or_path
)))
}
}
/// Get information about a tokenizer file
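///
/// # Example
///
/// Sketch; the path is illustrative, and currently only JSON (HuggingFace) files are recognized.
/// ```ignore
/// match get_tokenizer_info("path/to/tokenizer.json")? {
///     TokenizerType::HuggingFace(path) => println!("HuggingFace tokenizer: {}", path),
///     TokenizerType::Mock => println!("mock tokenizer"),
/// }
/// ```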
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
let path = Path::new(file_path);
if !path.exists() {
return Err(Error::msg(format!("File not found: {}", file_path)));
}
let extension = path
.extension()
.and_then(std::ffi::OsStr::to_str)
.map(|s| s.to_lowercase());
match extension.as_deref() {
Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
_ => {
// Try auto-detection
use std::fs::File;
use std::io::Read;
let mut file = File::open(file_path)?;
let mut buffer = vec![0u8; 512];
let bytes_read = file.read(&mut buffer)?;
buffer.truncate(bytes_read);
if is_likely_json(&buffer) {
Ok(TokenizerType::HuggingFace(file_path.to_string()))
} else {
Err(Error::msg("Unknown tokenizer type"))
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_json_detection() {
assert!(is_likely_json(b"{\"test\": \"value\"}"));
assert!(is_likely_json(b" \n\t{\"test\": \"value\"}"));
assert!(is_likely_json(b"[1, 2, 3]"));
assert!(!is_likely_json(b"not json"));
assert!(!is_likely_json(b""));
}
#[test]
fn test_mock_tokenizer_creation() {
let tokenizer = create_tokenizer_from_file("mock").unwrap();
assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
}
#[test]
fn test_file_not_found() {
let result = create_tokenizer_from_file("/nonexistent/file.json");
assert!(result.is_err());
if let Err(e) = result {
assert!(e.to_string().contains("File not found"));
}
}
}
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;
/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
tokenizer: HfTokenizer,
special_tokens: SpecialTokens,
vocab: HashMap<String, u32>,
reverse_vocab: HashMap<u32, String>,
}
impl HuggingFaceTokenizer {
/// Create a tokenizer from a HuggingFace tokenizer JSON file
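///
/// # Example
///
/// Sketch; the path is illustrative and must point at an existing `tokenizer.json`.
/// ```ignore
/// let hf = HuggingFaceTokenizer::from_file("path/to/tokenizer.json")?;
/// println!("vocab size: {}", hf.vocab_size());
/// ```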
pub fn from_file(file_path: &str) -> Result<Self> {
let tokenizer = HfTokenizer::from_file(file_path)
.map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;
// Extract special tokens
let special_tokens = Self::extract_special_tokens(&tokenizer);
// Build vocab mappings
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<u32, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
Ok(HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
})
}
/// Create from an existing HuggingFace tokenizer
pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
let special_tokens = Self::extract_special_tokens(&tokenizer);
let vocab = tokenizer.get_vocab(false);
let reverse_vocab: HashMap<u32, String> = vocab
.iter()
.map(|(token, &id)| (id, token.clone()))
.collect();
HuggingFaceTokenizer {
tokenizer,
special_tokens,
vocab,
reverse_vocab,
}
}
/// Extract special tokens from the tokenizer
fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
// Try to get special tokens from the tokenizer
// This is a simplified version - actual implementation would need to handle various formats
let vocab = tokenizer.get_vocab(true);
let find_token = |patterns: &[&str]| -> Option<String> {
for pattern in patterns {
if vocab.contains_key(*pattern) {
return Some(pattern.to_string());
}
}
None
};
SpecialTokens {
bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
additional_special_tokens: vec![],
}
}
/// Apply chat template if available
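///
/// # Example
///
/// Sketch of the current placeholder behaviour, which simply renders `role: content` lines.
/// ```ignore
/// let prompt = tokenizer.apply_chat_template(&[
///     ChatMessage::system("You are a helpful assistant"),
///     ChatMessage::user("Hello!"),
/// ])?;
/// assert!(prompt.starts_with("system: You are a helpful assistant"));
/// ```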
pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
// This is a placeholder - actual implementation would handle templates
let mut result = String::new();
for msg in messages {
result.push_str(&format!("{}: {}\n", msg.role, msg.content));
}
Ok(result)
}
}
impl Encoder for HuggingFaceTokenizer {
fn encode(&self, input: &str) -> Result<Encoding> {
let encoding = self
.tokenizer
.encode(input, false)
.map_err(|e| Error::msg(format!("Encoding failed: {}", e)))?;
Ok(Encoding::Hf(Box::new(encoding)))
}
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
let encodings = self
.tokenizer
.encode_batch(inputs.to_vec(), false)
.map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
Ok(encodings
.into_iter()
.map(|e| Encoding::Hf(Box::new(e)))
.collect())
}
}
impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
self.tokenizer
.decode(token_ids, skip_special_tokens)
.map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
}
}
impl TokenizerTrait for HuggingFaceTokenizer {
fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(false)
}
fn get_special_tokens(&self) -> &SpecialTokens {
&self.special_tokens
}
fn token_to_id(&self, token: &str) -> Option<u32> {
self.vocab.get(token).copied()
}
fn id_to_token(&self, id: u32) -> Option<String> {
self.reverse_vocab.get(&id).cloned()
}
}
/// Represents a chat message for template application
#[derive(Debug, Clone)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
ChatMessage {
role: role.into(),
content: content.into(),
}
}
pub fn system(content: impl Into<String>) -> Self {
Self::new("system", content)
}
pub fn user(content: impl Into<String>) -> Self {
Self::new("user", content)
}
pub fn assistant(content: impl Into<String>) -> Self {
Self::new("assistant", content)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_chat_message_creation() {
let msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(msg.role, "system");
assert_eq!(msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
}
// Note: Actual tokenizer tests would require a real tokenizer file
// These would be integration tests rather than unit tests
}
......@@ -2,26 +2,36 @@ use anyhow::Result;
use std::ops::Deref;
use std::sync::Arc;
pub mod factory;
pub mod mock;
pub mod stop;
pub mod stream;
pub mod traits;
// Feature-gated modules
#[cfg(feature = "huggingface")]
pub mod huggingface;
#[cfg(test)]
mod tests;
// Re-exports
pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
pub use stream::DecodeStream;
pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
#[cfg(feature = "huggingface")]
pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
/// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
#[derive(Clone)]
pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
impl Tokenizer {
/// Create a tokenizer from a file path
/// Uses the tokenizer factory to choose an implementation based on the file type
pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
// TODO: Implement factory pattern in Phase 3
unimplemented!("Factory pattern will be implemented in Phase 3")
pub fn from_file(file_path: &str) -> Result<Tokenizer> {
Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
}
/// Create a tokenizer from an Arc<dyn Tokenizer>
......
use super::traits;
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;
/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
/// Normal text output
Text(String),
/// Text is being held due to partial stop sequence match
Held,
/// Stop sequence matched (hidden - not included in output)
Stopped,
/// Stop sequence matched with text (visible - included in output)
StoppedWithText(String),
}
/// Configuration for stop sequences
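///
/// # Example
///
/// Minimal sketch: a hidden stop on token id 999 plus a visible stop on the literal `"###"`
/// (both values are illustrative).
/// ```ignore
/// let config = StopSequenceConfig::default()
///     .with_stop_token(999)
///     .with_visible_stop_sequence("###");
/// ```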
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
/// Token IDs that trigger a stop
pub stop_tokens: HashSet<u32>,
/// String sequences that trigger a stop
pub stop_sequences: Vec<String>,
/// Token IDs for visible stops (included in output)
pub visible_stop_tokens: HashSet<u32>,
/// String sequences for visible stops (included in output)
pub visible_stop_sequences: Vec<String>,
}
impl StopSequenceConfig {
/// Builder pattern - add a stop token
pub fn with_stop_token(mut self, token_id: u32) -> Self {
self.stop_tokens.insert(token_id);
self
}
/// Builder pattern - add a stop sequence
pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.stop_sequences.push(sequence.into());
self
}
/// Builder pattern - add a visible stop token
pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
self.visible_stop_tokens.insert(token_id);
self
}
/// Builder pattern - add a visible stop sequence
pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.visible_stop_sequences.push(sequence.into());
self
}
}
/// Decoder that handles stop sequences
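///
/// # Example
///
/// Sketch of the streaming loop: token ids are fed one at a time and only "safe" text is
/// emitted; `tokenizer`, `config`, and `token_ids` are assumed to exist in scope.
/// ```ignore
/// let mut decoder = StopSequenceDecoder::new(tokenizer, config, true);
/// for token_id in token_ids {
///     match decoder.process_token(token_id)? {
///         SequenceDecoderOutput::Text(t) => print!("{}", t),
///         SequenceDecoderOutput::Held => {} // partial stop match jailed; wait for more tokens
///         SequenceDecoderOutput::StoppedWithText(t) => {
///             print!("{}", t);
///             break;
///         }
///         SequenceDecoderOutput::Stopped => break,
///     }
/// }
/// ```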
pub struct StopSequenceDecoder {
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
/// Buffer for partial matches (the "jail")
jail_buffer: String,
/// Accumulated tokens
token_buffer: Vec<u32>,
/// Offset where the prefix text starts (for context)
prefix_offset: usize,
/// Offset marking the end of previously decoded text
read_offset: usize,
/// Whether we've stopped
stopped: bool,
skip_special_tokens: bool,
}
impl StopSequenceDecoder {
/// Create a new stop sequence decoder
pub fn new(
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
skip_special_tokens: bool,
) -> Self {
StopSequenceDecoder {
tokenizer,
config,
jail_buffer: String::new(),
token_buffer: Vec::new(),
prefix_offset: 0,
read_offset: 0,
stopped: false,
skip_special_tokens,
}
}
/// Process a single token
pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
if self.stopped {
return Ok(SequenceDecoderOutput::Stopped);
}
// Check for token-level stops first
if self.config.stop_tokens.contains(&token_id) {
self.stopped = true;
// Flush any jailed text before stopping
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
return Ok(SequenceDecoderOutput::Stopped);
}
if self.config.visible_stop_tokens.contains(&token_id) {
self.stopped = true;
// Include jailed text plus the stop token
let stop_text = self
.tokenizer
.decode(&[token_id], self.skip_special_tokens)?;
let output = format!("{}{}", self.jail_buffer, stop_text);
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
// Add token to buffer
self.token_buffer.push(token_id);
// Use incremental decoding like DecodeStream
// First decode the previous context (what we've already output)
let prefix_text = if self.read_offset > self.prefix_offset {
self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?
} else {
String::new()
};
// Now decode from prefix to current position
let new_full_text = self.tokenizer.decode(
&self.token_buffer[self.prefix_offset..],
self.skip_special_tokens,
)?;
// Check for incomplete UTF-8 sequence
if new_full_text.ends_with("�") {
// Wait for more tokens to complete the sequence
return Ok(SequenceDecoderOutput::Held);
}
// Calculate only the NEW text since last successful decode
let new_text = if new_full_text.len() > prefix_text.len() {
&new_full_text[prefix_text.len()..]
} else {
// No new text produced (can happen with special tokens)
return Ok(SequenceDecoderOutput::Held);
};
// Combine jail buffer with new text for checking
let check_text = format!("{}{}", self.jail_buffer, new_text);
// Check for complete stop sequences
for stop_seq in &self.config.stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
self.stopped = true;
// Output text before the stop sequence
let output = check_text[..pos].to_string();
self.jail_buffer.clear();
return Ok(if output.is_empty() {
SequenceDecoderOutput::Stopped
} else {
SequenceDecoderOutput::StoppedWithText(output)
});
}
}
// Check for visible stop sequences
for stop_seq in &self.config.visible_stop_sequences {
if let Some(pos) = check_text.find(stop_seq) {
self.stopped = true;
// Include the stop sequence in output
let end_pos = pos + stop_seq.len();
let output = check_text[..end_pos].to_string();
self.jail_buffer.clear();
return Ok(SequenceDecoderOutput::StoppedWithText(output));
}
}
// Check for partial matches at the end of check_text
let mut partial_match_len = 0;
for stop_seq in self
.config
.stop_sequences
.iter()
.chain(&self.config.visible_stop_sequences)
{
// Check all possible suffixes that could be a prefix of stop_seq
for i in 1..=check_text.len().min(stop_seq.len() - 1) {
let suffix = &check_text[check_text.len() - i..];
if stop_seq.starts_with(suffix) {
partial_match_len = partial_match_len.max(i);
}
}
}
if partial_match_len > 0 {
// Split: output safe text, jail the potential match
let safe_end = check_text.len() - partial_match_len;
let safe_text = &check_text[..safe_end];
self.jail_buffer = check_text[safe_end..].to_string();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
if safe_text.is_empty() {
Ok(SequenceDecoderOutput::Held)
} else {
Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
}
} else {
// No partial matches - output everything
self.jail_buffer.clear();
// Update offsets for next iteration
self.prefix_offset = self.read_offset;
self.read_offset = self.token_buffer.len();
Ok(SequenceDecoderOutput::Text(check_text))
}
}
/// Process multiple tokens
pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
let mut outputs = Vec::new();
for &token_id in token_ids {
outputs.push(self.process_token(token_id)?);
}
Ok(outputs)
}
/// Flush any held text
pub fn flush(&mut self) -> SequenceDecoderOutput {
if !self.jail_buffer.is_empty() {
let output = self.jail_buffer.clone();
self.jail_buffer.clear();
SequenceDecoderOutput::Text(output)
} else {
SequenceDecoderOutput::Text(String::new())
}
}
/// Check if decoding has stopped
pub fn is_stopped(&self) -> bool {
self.stopped
}
/// Reset the decoder state
pub fn reset(&mut self) {
self.jail_buffer.clear();
self.token_buffer.clear();
self.prefix_offset = 0;
self.read_offset = 0;
self.stopped = false;
}
}
/// Builder for StopSequenceDecoder
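///
/// # Example
///
/// Sketch of the same configuration expressed via the builder; `tokenizer` is assumed in scope.
/// ```ignore
/// let decoder = StopSequenceDecoderBuilder::new(tokenizer)
///     .stop_token(999)
///     .stop_sequence("STOP")
///     .skip_special_tokens(true)
///     .build();
/// ```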
pub struct StopSequenceDecoderBuilder {
tokenizer: Arc<dyn traits::Tokenizer>,
config: StopSequenceConfig,
skip_special_tokens: bool,
}
impl StopSequenceDecoderBuilder {
pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
StopSequenceDecoderBuilder {
tokenizer,
config: StopSequenceConfig::default(),
skip_special_tokens: true,
}
}
pub fn stop_token(mut self, token_id: u32) -> Self {
self.config.stop_tokens.insert(token_id);
self
}
pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.config.stop_sequences.push(sequence.into());
self
}
pub fn visible_stop_token(mut self, token_id: u32) -> Self {
self.config.visible_stop_tokens.insert(token_id);
self
}
pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
self.config.visible_stop_sequences.push(sequence.into());
self
}
pub fn skip_special_tokens(mut self, skip: bool) -> Self {
self.skip_special_tokens = skip;
self
}
pub fn build(self) -> StopSequenceDecoder {
StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizer::mock::MockTokenizer;
#[test]
fn test_stop_token_detection() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process tokens before stop
let result = decoder.process_token(1).unwrap(); // "Hello"
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
// Process stop token
let result = decoder.process_token(999).unwrap(); // <eos>
assert_eq!(result, SequenceDecoderOutput::Stopped);
// Further tokens should also return Stopped
let result = decoder.process_token(2).unwrap();
assert_eq!(result, SequenceDecoderOutput::Stopped);
}
#[test]
fn test_visible_stop_token() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_visible_stop_token(999);
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
let result = decoder.process_token(999).unwrap();
assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
}
#[test]
fn test_builder_pattern() {
let tokenizer = Arc::new(MockTokenizer::new());
let decoder = StopSequenceDecoderBuilder::new(tokenizer)
.stop_token(999)
.stop_sequence("STOP")
.visible_stop_token(1000)
.skip_special_tokens(true)
.build();
assert!(!decoder.is_stopped());
}
#[test]
fn test_incremental_decoding_no_repetition() {
// This test verifies the critical fix: no repeated output
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default();
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process tokens one by one and collect outputs
let mut outputs = Vec::new();
// Token 1: "Hello"
let result = decoder.process_token(1).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// Token 2: "world"
let result = decoder.process_token(2).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// Token 3: "test"
let result = decoder.process_token(3).unwrap();
if let SequenceDecoderOutput::Text(text) = result {
outputs.push(text.clone());
}
// CRITICAL: Each output should be unique (no accumulation)
// The fix ensures we only output NEW text, not accumulated text
assert_eq!(outputs.len(), 3);
// Verify no text is repeated
for i in 0..outputs.len() {
for j in i + 1..outputs.len() {
// No output should contain another (no accumulation)
assert!(!outputs[j].contains(&outputs[i]));
}
}
}
#[test]
fn test_stop_sequence_detection() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_sequence("test");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process "Hello world"
decoder.process_token(1).unwrap(); // "Hello"
decoder.process_token(2).unwrap(); // "world"
// Process "test" which should trigger stop
let result = decoder.process_token(3).unwrap(); // "test"
// Should stop when we hit "test"
assert!(matches!(
result,
SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
));
}
#[test]
fn test_flush_after_partial() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process a token
decoder.process_token(1).unwrap(); // "Hello"
// Flush should return any remaining text in jail
let result = decoder.flush();
// After processing, flush should work
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
}
#[test]
fn test_reset_functionality() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_stop_token(999);
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process and stop
decoder.process_token(1).unwrap();
decoder.process_token(999).unwrap();
assert!(decoder.is_stopped());
// Reset should clear everything
decoder.reset();
assert!(!decoder.is_stopped());
// Should be able to process again
let result = decoder.process_token(2).unwrap();
assert!(matches!(result, SequenceDecoderOutput::Text(_)));
}
#[test]
fn test_visible_stop_sequence() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process "Hello"
decoder.process_token(1).unwrap();
// Process "world" - should include it in output
let result = decoder.process_token(2).unwrap();
if let SequenceDecoderOutput::StoppedWithText(text) = result {
// Should include "world" in the output
assert!(text.contains("world"));
} else {
panic!("Expected StoppedWithText with visible stop sequence");
}
}
#[test]
fn test_multiple_tokens_processing() {
let tokenizer = Arc::new(MockTokenizer::new());
let config = StopSequenceConfig::default();
let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);
// Process multiple tokens at once
let results = decoder.process_tokens(&[1, 2, 3]).unwrap();
// Should get results for each token
assert_eq!(results.len(), 3);
// Each result should be Text (no stops configured)
for result in results {
assert!(matches!(
result,
SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
));
}
}
}