use super::traits::{self, TokenIdType};
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;

/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
    /// Normal text output
    Text(String),
    /// Text is being held due to a partial stop sequence match
    Held,
    /// Stop sequence matched (hidden - not included in output)
    Stopped,
    /// Stop sequence matched with text (visible - included in output)
    StoppedWithText(String),
}

/// Configuration for stop sequences
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
    /// Token IDs that trigger a stop
    pub stop_tokens: HashSet<TokenIdType>,
    /// String sequences that trigger a stop
    pub stop_sequences: Vec<String>,
    /// Token IDs for visible stops (included in output)
    pub visible_stop_tokens: HashSet<TokenIdType>,
    /// String sequences for visible stops (included in output)
    pub visible_stop_sequences: Vec<String>,
}

impl StopSequenceConfig {
    /// Builder pattern - add a stop token
    pub fn with_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a stop sequence
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Builder pattern - add a visible stop token
    pub fn with_visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a visible stop sequence
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}
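
// Example (illustrative sketch; the token ID and strings are made up, not
// from any real vocabulary): a config that stops silently on token 999 or
// on "\n\nUser:", but keeps "</answer>" in the output when it matches.
//
//     let config = StopSequenceConfig::default()
//         .with_stop_token(999)
//         .with_stop_sequence("\n\nUser:")
//         .with_visible_stop_sequence("</answer>");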

/// Decoder that handles stop sequences
pub struct StopSequenceDecoder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    /// Buffer for partial matches (the "jail")
    jail_buffer: String,
    /// Accumulated tokens
    token_buffer: Vec<TokenIdType>,
    /// Offset where the prefix text starts (for context)
    prefix_offset: usize,
    /// Offset marking the end of previously decoded text
    read_offset: usize,
    /// Whether we've stopped
    stopped: bool,
    skip_special_tokens: bool,
}

impl StopSequenceDecoder {
    /// Create a new stop sequence decoder
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        config: StopSequenceConfig,
        skip_special_tokens: bool,
    ) -> Self {
        StopSequenceDecoder {
            tokenizer,
            config,
            jail_buffer: String::new(),
            token_buffer: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
            stopped: false,
            skip_special_tokens,
        }
    }

    /// Process a single token
    pub fn process_token(&mut self, token_id: TokenIdType) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Check for token-level stops first
        if self.config.stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Flush any jailed text before stopping
            if !self.jail_buffer.is_empty() {
                let output = self.jail_buffer.clone();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        if self.config.visible_stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Include jailed text plus the stop token
            let stop_text = self
                .tokenizer
                .decode(&[token_id], self.skip_special_tokens)?;
            let output = format!("{}{}", self.jail_buffer, stop_text);
            self.jail_buffer.clear();
            return Ok(SequenceDecoderOutput::StoppedWithText(output));
        }

        // Add token to buffer
        self.token_buffer.push(token_id);

        // Use incremental decoding like DecodeStream.
        // First decode the previous context (what we've already output)
        let prefix_text = if self.read_offset > self.prefix_offset {
            self.tokenizer.decode(
                &self.token_buffer[self.prefix_offset..self.read_offset],
                self.skip_special_tokens,
            )?
        } else {
            String::new()
        };

        // Now decode from the prefix to the current position
        let new_full_text = self.tokenizer.decode(
            &self.token_buffer[self.prefix_offset..],
            self.skip_special_tokens,
        )?;

        // Check for an incomplete UTF-8 sequence
        if new_full_text.ends_with("�") {
            // Wait for more tokens to complete the sequence
            return Ok(SequenceDecoderOutput::Held);
        }

        // Calculate only the NEW text since the last successful decode
        let new_text = if new_full_text.len() > prefix_text.len() {
            &new_full_text[prefix_text.len()..]
        } else {
            // No new text produced (can happen with special tokens)
            return Ok(SequenceDecoderOutput::Held);
        };

        // Combine the jail buffer with the new text for checking
        let check_text = format!("{}{}", self.jail_buffer, new_text);

        // Check for complete stop sequences
        for stop_seq in &self.config.stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Output text before the stop sequence
                let output = check_text[..pos].to_string();
                self.jail_buffer.clear();
                return Ok(if output.is_empty() {
                    SequenceDecoderOutput::Stopped
                } else {
                    SequenceDecoderOutput::StoppedWithText(output)
                });
            }
        }

        // Check for visible stop sequences
        for stop_seq in &self.config.visible_stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Include the stop sequence in the output
                let end_pos = pos + stop_seq.len();
                let output = check_text[..end_pos].to_string();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
        }

        // Check for partial matches at the end of check_text
        let mut partial_match_len = 0;
        for stop_seq in self
            .config
            .stop_sequences
            .iter()
            .chain(&self.config.visible_stop_sequences)
        {
            // Check all suffixes of check_text that could be a prefix of stop_seq
            for i in 1..=check_text.len().min(stop_seq.len() - 1) {
                let start = check_text.len() - i;
                // Only slice at char boundaries to avoid panicking on
                // multi-byte UTF-8 text; a &str match must start on a
                // boundary anyway.
                if !check_text.is_char_boundary(start) {
                    continue;
                }
                let suffix = &check_text[start..];
                if stop_seq.starts_with(suffix) {
                    partial_match_len = partial_match_len.max(i);
                }
            }
        }

        if partial_match_len > 0 {
            // Split: output the safe text, jail the potential match
            let safe_end = check_text.len() - partial_match_len;
            let safe_text = &check_text[..safe_end];
            self.jail_buffer = check_text[safe_end..].to_string();

            // Update offsets for the next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            if safe_text.is_empty() {
                Ok(SequenceDecoderOutput::Held)
            } else {
                Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
            }
        } else {
            // No partial matches - output everything
            self.jail_buffer.clear();

            // Update offsets for the next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            Ok(SequenceDecoderOutput::Text(check_text))
        }
    }

    /// Process multiple tokens
    pub fn process_tokens(
        &mut self,
        token_ids: &[TokenIdType],
    ) -> Result<Vec<SequenceDecoderOutput>> {
        let mut outputs = Vec::new();
        for &token_id in token_ids {
            outputs.push(self.process_token(token_id)?);
        }
        Ok(outputs)
    }

    /// Flush any held text
    pub fn flush(&mut self) -> SequenceDecoderOutput {
        if !self.jail_buffer.is_empty() {
            let output = self.jail_buffer.clone();
            self.jail_buffer.clear();
            SequenceDecoderOutput::Text(output)
        } else {
            SequenceDecoderOutput::Text(String::new())
        }
    }

    /// Check if decoding has stopped
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Reset the decoder state
    pub fn reset(&mut self) {
        self.jail_buffer.clear();
        self.token_buffer.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
        self.stopped = false;
    }
}
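
// Example (illustrative sketch): streaming token IDs through the decoder.
// `tokenizer` stands for any `Arc<dyn traits::Tokenizer>`, the token IDs are
// made up, and the surrounding function is assumed to return `anyhow::Result`.
//
//     let mut decoder = StopSequenceDecoderBuilder::new(tokenizer)
//         .stop_sequence("###")
//         .build();
//     for token_id in [12, 34, 56] {
//         match decoder.process_token(token_id)? {
//             SequenceDecoderOutput::Text(t) => print!("{t}"),
//             SequenceDecoderOutput::Held => {} // potential stop match jailed
//             SequenceDecoderOutput::Stopped => break,
//             SequenceDecoderOutput::StoppedWithText(t) => {
//                 print!("{t}");
//                 break;
//             }
//         }
//     }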

/// Builder for StopSequenceDecoder
pub struct StopSequenceDecoderBuilder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    skip_special_tokens: bool,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        StopSequenceDecoderBuilder {
            tokenizer,
            config: StopSequenceConfig::default(),
            skip_special_tokens: true,
        }
    }

    pub fn stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config.stop_tokens.insert(token_id);
        self
    }

    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.stop_sequences.push(sequence.into());
        self
    }

    pub fn visible_stop_token(mut self, token_id: TokenIdType) -> Self {
        self.config.visible_stop_tokens.insert(token_id);
        self
    }

    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.visible_stop_sequences.push(sequence.into());
        self
    }

    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
        self.skip_special_tokens = skip;
        self
    }

    pub fn build(self) -> StopSequenceDecoder {
        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // stop token id
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before the stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process the stop token
        let result = decoder.process_token(999).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();
        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // This test verifies the critical fix: no repeated output
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens one by one and collect the outputs
        let mut outputs = Vec::new();

        // Token 1: "Hello"
        let result = decoder.process_token(1).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 2: "world"
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 3: "test"
        let result = decoder.process_token(3).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // CRITICAL: each output should be unique (no accumulation).
        // The fix ensures we only output NEW text, not accumulated text.
        assert_eq!(outputs.len(), 3);

        // Verify no text is repeated
        for i in 0..outputs.len() {
            for j in i + 1..outputs.len() {
                // No output should contain an earlier one (no accumulation)
                assert!(!outputs[j].contains(&outputs[i]));
            }
        }
    }
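
    // Added sketch of the jailing behavior: "worldX" is a made-up stop
    // sequence chosen so that "world" (token 2 in MockTokenizer) is a strict
    // prefix of it and should be held rather than emitted; flush() then
    // releases whatever is still jailed. Assertions are deliberately loose
    // because MockTokenizer's exact spacing is an implementation detail.
    #[test]
    fn test_partial_match_jailed_then_flushed() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("worldX");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        decoder.process_token(1).unwrap(); // "Hello"

        // "world" could still grow into "worldX", so it is jailed (Held) or
        // only the safe part of the new text is emitted.
        let result = decoder.process_token(2).unwrap(); // "world"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
        ));

        // No stop sequence ever completed, so flush releases the jailed text
        // and the decoder never stopped.
        let flushed = decoder.flush();
        assert!(matches!(flushed, SequenceDecoderOutput::Text(_)));
        assert!(!decoder.is_stopped());
    }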

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test", which should trigger a stop
        let result = decoder.process_token(3).unwrap(); // "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any text remaining in the jail
        let result = decoder.flush();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - the stop sequence should be included in the output
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get one result per token
        assert_eq!(results.len(), 3);

        // Each result should be Text or Held (no stops configured)
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }
}