Unverified Commit f5d30dae authored by Simo Lin, committed by GitHub

[router] Refactor StopSequenceDecoder to Use Sequence for Incremental Decoding (#11676)

parent 2479b894
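At its core, the change removes StopSequenceDecoder's hand-rolled token_buffer / prefix_offset / read_offset bookkeeping and reuses Sequence, which already implements the same two-offset incremental decoding. The program below is a minimal, self-contained sketch of that scheme; it uses a hypothetical toy string vocabulary in place of the router's TokenizerTrait and omits the UTF-8 boundary handling that the real Sequence::append_token also performs.

// Toy sketch of two-offset incremental decoding (hypothetical vocab, not TokenizerTrait).
struct ToySequence {
    vocab: Vec<&'static str>, // token id -> text piece
    token_ids: Vec<usize>,
    prefix_offset: usize, // start of the decoding context window
    read_offset: usize,   // end of the text already emitted
}

impl ToySequence {
    fn decode(&self, ids: &[usize]) -> String {
        ids.iter().map(|&id| self.vocab[id]).collect()
    }

    /// Append one token and return only the newly produced text.
    fn append_token(&mut self, id: usize) -> String {
        let old_read_offset = self.read_offset;
        self.token_ids.push(id);
        self.read_offset = self.token_ids.len();

        // Decode the window we already emitted, then the window including the
        // new token, and return just the suffix that was not emitted before.
        let prefix = self.decode(&self.token_ids[self.prefix_offset..old_read_offset]);
        let full = self.decode(&self.token_ids[self.prefix_offset..]);
        self.prefix_offset = old_read_offset;
        full[prefix.len()..].to_string()
    }
}

fn main() {
    let mut seq = ToySequence {
        vocab: vec!["Hel", "lo", ", wor", "ld"],
        token_ids: Vec::new(),
        prefix_offset: 0,
        read_offset: 0,
    };
    let streamed: String = (0..4).map(|id| seq.append_token(id)).collect();
    assert_eq!(streamed, "Hello, world");
    println!("{streamed}");
}
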
@@ -16,6 +16,9 @@ pub struct Sequence {
     /// Current position in the sequence
     read_offset: usize,
+    /// Whether to skip special tokens when decoding
+    skip_special_tokens: bool,
 }
 
 impl std::fmt::Debug for Sequence {
@@ -45,22 +48,38 @@ impl std::fmt::Debug for Sequence {
 impl Sequence {
     /// Create a new empty sequence
     pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
+        Self::new_with_options(tokenizer, false)
+    }
+
+    /// Create a new empty sequence with skip_special_tokens option
+    pub fn new_with_options(tokenizer: Arc<dyn TokenizerTrait>, skip_special_tokens: bool) -> Self {
         Self {
             tokenizer,
             token_ids: Vec::new(),
             prefix_offset: 0,
             read_offset: 0,
+            skip_special_tokens,
         }
     }
 
     /// Create a sequence with initial tokens
     pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<TokenIdType>) -> Self {
+        Self::with_tokens_and_options(tokenizer, token_ids, false)
+    }
+
+    /// Create a sequence with initial tokens and skip_special_tokens option
+    pub fn with_tokens_and_options(
+        tokenizer: Arc<dyn TokenizerTrait>,
+        token_ids: Vec<TokenIdType>,
+        skip_special_tokens: bool,
+    ) -> Self {
         let len = token_ids.len();
         Self {
             tokenizer,
             token_ids,
             prefix_offset: 0,
             read_offset: len,
+            skip_special_tokens,
         }
     }
@@ -99,7 +118,9 @@ impl Sequence {
         // If this is the first token or we're at the beginning, decode everything
         if self.prefix_offset == 0 && old_read_offset == 0 {
-            let text = self.tokenizer.decode(&self.token_ids, false)?;
+            let text = self
+                .tokenizer
+                .decode(&self.token_ids, self.skip_special_tokens)?;
             if text.ends_with("�") {
                 // Incomplete UTF-8 sequence, wait for more tokens
                 return Ok(String::new());
@@ -109,14 +130,16 @@ impl Sequence {
         }
 
         // Decode the text up to the previous position
-        let prefix_text = self
-            .tokenizer
-            .decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;
+        let prefix_text = self.tokenizer.decode(
+            &self.token_ids[self.prefix_offset..old_read_offset],
+            self.skip_special_tokens,
+        )?;
 
         // Decode the text including the new token
-        let new_text = self
-            .tokenizer
-            .decode(&self.token_ids[self.prefix_offset..], false)?;
+        let new_text = self.tokenizer.decode(
+            &self.token_ids[self.prefix_offset..],
+            self.skip_special_tokens,
+        )?;
 
         // Handle multi-byte character boundaries
         let mut prefix_text_len = prefix_text.len();
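The "�" check above exists because a multi-byte character can be split across two tokens: decoding only the first half yields U+FFFD (the replacement character), so append_token holds its output until the character completes. Below is a self-contained illustration of the effect, using std's lossy UTF-8 decoding directly rather than the router's tokenizer.

fn main() {
    // "é" is two bytes in UTF-8 (0xC3 0xA9). Pretend each byte arrives with a separate token.
    let first_half = String::from_utf8_lossy(&[0xC3]); // incomplete sequence
    assert!(first_half.ends_with('\u{FFFD}'));         // decodes to the replacement char "�"

    let complete = String::from_utf8_lossy(&[0xC3, 0xA9]);
    assert_eq!(&*complete, "é");

    // Sequence::append_token applies the same test: if the decoded text ends with "�",
    // it returns an empty string and waits for more tokens instead of emitting garbage.
    println!("held {first_half:?}, emitted {complete:?}");
}
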
@@ -151,7 +174,8 @@ impl Sequence {
     /// Decode the entire sequence to text
     pub fn text(&self) -> Result<String> {
-        self.tokenizer.decode(&self.token_ids, false)
+        self.tokenizer
+            .decode(&self.token_ids, self.skip_special_tokens)
     }
 
     /// Get the prefix offset
@@ -163,6 +187,11 @@ impl Sequence {
     pub fn read_offset(&self) -> usize {
         self.read_offset
     }
+
+    /// Get whether special tokens are skipped during decoding
+    pub fn skip_special_tokens(&self) -> bool {
+        self.skip_special_tokens
+    }
 }
 
 #[cfg(test)]
+use super::sequence::Sequence;
 use super::traits::{self, TokenIdType};
 use anyhow::Result;
 use std::collections::HashSet;
@@ -57,19 +58,13 @@ impl StopSequenceConfig {
 /// Decoder that handles stop sequences
 pub struct StopSequenceDecoder {
-    tokenizer: Arc<dyn traits::Tokenizer>,
+    /// Sequence for incremental decoding (replaces token_buffer + offsets)
+    sequence: Sequence,
     config: StopSequenceConfig,
     /// Buffer for partial matches (the "jail")
     jail_buffer: String,
-    /// Accumulated tokens
-    token_buffer: Vec<TokenIdType>,
-    /// Offset where the prefix text starts (for context)
-    prefix_offset: usize,
-    /// Offset marking the end of previously decoded text
-    read_offset: usize,
     /// Whether we've stopped
     stopped: bool,
-    skip_special_tokens: bool,
 }
 
 impl StopSequenceDecoder {
@@ -80,14 +75,10 @@ impl StopSequenceDecoder {
         skip_special_tokens: bool,
     ) -> Self {
         StopSequenceDecoder {
-            tokenizer,
+            sequence: Sequence::new_with_options(tokenizer, skip_special_tokens),
             config,
             jail_buffer: String::new(),
-            token_buffer: Vec::new(),
-            prefix_offset: 0,
-            read_offset: 0,
             stopped: false,
-            skip_special_tokens,
         }
     }
@@ -115,57 +106,24 @@ impl StopSequenceDecoder {
                 // Include jailed text plus the stop token
                 let stop_text = self
-                    .tokenizer
-                    .decode(&[token_id], self.skip_special_tokens)?;
+                    .sequence
+                    .tokenizer()
+                    .decode(&[token_id], self.sequence.skip_special_tokens())?;
                 let output = format!("{}{}", self.jail_buffer, stop_text);
                 self.jail_buffer.clear();
                 return Ok(SequenceDecoderOutput::StoppedWithText(output));
             }
 
-        // Add token to buffer
-        self.token_buffer.push(token_id);
-
-        // Use incremental decoding like DecodeStream
-        // First decode the previous context (what we've already output)
-        let prefix_text = if self.read_offset > self.prefix_offset {
-            self.tokenizer.decode(
-                &self.token_buffer[self.prefix_offset..self.read_offset],
-                self.skip_special_tokens,
-            )?
-        } else {
-            String::new()
-        };
-
-        // Now decode from prefix to current position
-        let new_full_text = self.tokenizer.decode(
-            &self.token_buffer[self.prefix_offset..],
-            self.skip_special_tokens,
-        )?;
-
-        // Check for incomplete UTF-8 sequence
-        if new_full_text.ends_with("�") {
-            // Wait for more tokens to complete the sequence
-            return Ok(SequenceDecoderOutput::Held);
-        }
-
-        // Calculate only the NEW text since last successful decode
-        let new_text = if new_full_text.len() > prefix_text.len() {
-            &new_full_text[prefix_text.len()..]
-        } else {
-            // No new text produced (can happen with special tokens)
-            return Ok(SequenceDecoderOutput::Held);
-        };
-
-        // Combine jail buffer with new text for checking
-        let check_text = format!("{}{}", self.jail_buffer, new_text);
-
-        // Check for complete stop sequences
+        // Use Sequence for incremental decoding
+        let new_text = self.sequence.append_token(token_id)?;
+        self.jail_buffer.push_str(&new_text);
+
+        // Check for hidden stop sequences
         for stop_seq in &self.config.stop_sequences {
-            if let Some(pos) = check_text.find(stop_seq) {
+            if let Some(pos) = self.jail_buffer.find(stop_seq) {
                 self.stopped = true;
-
-                // Output text before the stop sequence
-                let output = check_text[..pos].to_string();
+                let output = self.jail_buffer[..pos].to_string();
                 self.jail_buffer.clear();
                 return Ok(if output.is_empty() {
                     SequenceDecoderOutput::Stopped
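Because every piece of not-yet-released text now passes through jail_buffer, detecting a hidden stop sequence reduces to a plain substring search: the text before the match is released and the match itself is swallowed (the visible variant in the next hunk keeps it). A small self-contained sketch of that slicing on a plain string, with an assumed "</s>" stop marker:

fn main() {
    let mut jail_buffer = String::from("The answer is 42</s>ignored tail");
    let stop_seq = "</s>";

    if let Some(pos) = jail_buffer.find(stop_seq) {
        // Hidden stop: emit only the text before the match, drop everything else.
        let output = jail_buffer[..pos].to_string();
        jail_buffer.clear();
        assert_eq!(output, "The answer is 42");
        assert!(jail_buffer.is_empty());
        println!("stopped, emitted {output:?}");
    }
}
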
@@ -177,58 +135,54 @@ impl StopSequenceDecoder {
         // Check for visible stop sequences
         for stop_seq in &self.config.visible_stop_sequences {
-            if let Some(pos) = check_text.find(stop_seq) {
+            if let Some(pos) = self.jail_buffer.find(stop_seq) {
                 self.stopped = true;
-
-                // Include the stop sequence in output
                 let end_pos = pos + stop_seq.len();
-                let output = check_text[..end_pos].to_string();
+                let output = self.jail_buffer[..end_pos].to_string();
                 self.jail_buffer.clear();
                 return Ok(SequenceDecoderOutput::StoppedWithText(output));
             }
         }
 
-        // Check for partial matches at the end of check_text
-        let mut partial_match_len = 0;
+        // Check for partial matches: is the end of jail_buffer the start of any stop_seq?
+        // This handles stop sequences split across tokens
+        let mut longest_partial = 0;
         for stop_seq in self
             .config
             .stop_sequences
             .iter()
            .chain(&self.config.visible_stop_sequences)
         {
-            // Check all possible suffixes that could be a prefix of stop_seq
-            for i in 1..=check_text.len().min(stop_seq.len() - 1) {
-                let suffix = &check_text[check_text.len() - i..];
+            // Check suffixes of jail_buffer that match prefixes of stop_seq
+            // We check up to stop_seq.len() - 1 to avoid rechecking exact matches
+            let max_len = self.jail_buffer.len().min(stop_seq.len() - 1);
+            for len in 1..=max_len {
+                let suffix = &self.jail_buffer[self.jail_buffer.len() - len..];
                 if stop_seq.starts_with(suffix) {
-                    partial_match_len = partial_match_len.max(i);
+                    longest_partial = longest_partial.max(len);
                 }
             }
         }
 
-        if partial_match_len > 0 {
-            // Split: output safe text, jail the potential match
-            let safe_end = check_text.len() - partial_match_len;
-            let safe_text = &check_text[..safe_end];
-            self.jail_buffer = check_text[safe_end..].to_string();
-
-            // Update offsets for next iteration
-            self.prefix_offset = self.read_offset;
-            self.read_offset = self.token_buffer.len();
-
-            if safe_text.is_empty() {
+        if longest_partial > 0 {
+            // Hold the partial match, flush the rest
+            let split_pos = self.jail_buffer.len() - longest_partial;
+            let to_output = self.jail_buffer[..split_pos].to_string();
+            self.jail_buffer = self.jail_buffer[split_pos..].to_string();
+
+            if to_output.is_empty() {
                 Ok(SequenceDecoderOutput::Held)
             } else {
-                Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
+                Ok(SequenceDecoderOutput::Text(to_output))
             }
         } else {
-            // No partial matches - output everything
-            self.jail_buffer.clear();
-
-            // Update offsets for next iteration
-            self.prefix_offset = self.read_offset;
-            self.read_offset = self.token_buffer.len();
-
-            Ok(SequenceDecoderOutput::Text(check_text))
+            // No partial matches - flush everything
+            let output = std::mem::take(&mut self.jail_buffer);
+            if output.is_empty() {
+                Ok(SequenceDecoderOutput::Held)
+            } else {
+                Ok(SequenceDecoderOutput::Text(output))
+            }
         }
     }
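The partial-match pass is what lets a stop sequence arrive split across several tokens: the longest suffix of jail_buffer that is also a proper prefix of some stop sequence stays jailed, and only the text in front of it is flushed. The following is a self-contained sketch of that check on plain strings, mirroring the loop above, with ASCII stop sequences assumed so the byte arithmetic stays on char boundaries:

fn main() {
    let stop_sequences = ["</s>", "END"];
    let mut jail_buffer = String::from("partial output ends with </");

    // Longest suffix of jail_buffer that could still grow into a stop sequence.
    let mut longest_partial = 0;
    for stop_seq in &stop_sequences {
        let max_len = jail_buffer.len().min(stop_seq.len() - 1);
        for len in 1..=max_len {
            let suffix = &jail_buffer[jail_buffer.len() - len..];
            if stop_seq.starts_with(suffix) {
                longest_partial = longest_partial.max(len);
            }
        }
    }

    // "</" could be the start of "</s>", so it stays jailed; the rest is flushed.
    let split_pos = jail_buffer.len() - longest_partial;
    let to_output = jail_buffer[..split_pos].to_string();
    jail_buffer = jail_buffer[split_pos..].to_string();

    assert_eq!(to_output, "partial output ends with ");
    assert_eq!(jail_buffer, "</");
    println!("flushed {to_output:?}, jailed {jail_buffer:?}");
}
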
@@ -263,9 +217,7 @@ impl StopSequenceDecoder {
     /// Reset the decoder state
     pub fn reset(&mut self) {
         self.jail_buffer.clear();
-        self.token_buffer.clear();
-        self.prefix_offset = 0;
-        self.read_offset = 0;
+        self.sequence.clear();
         self.stopped = false;
     }
 }