Unverified Commit 2cabf441 authored by Biswa Panda's avatar Biswa Panda Committed by GitHub
Browse files

feat: Decoder clean-up for handling incomplete multi-byte sequence (#8022)

parent 20ce329b
...@@ -638,7 +638,7 @@ mod tests { ...@@ -638,7 +638,7 @@ mod tests {
&self, &self,
_token_ids: &[TokenIdType], _token_ids: &[TokenIdType],
_skip_special_tokens: bool, _skip_special_tokens: bool,
) -> anyhow::Result<String> { ) -> anyhow::Result<traits::DecodeResult> {
Err(anyhow::anyhow!( Err(anyhow::anyhow!(
"Unable to decode into a valid UTF-8 string: incomplete utf-8 byte sequence from index 6" "Unable to decode into a valid UTF-8 string: incomplete utf-8 byte sequence from index 6"
)) ))
......
...@@ -19,6 +19,7 @@ pub use anyhow::{Error, Result}; ...@@ -19,6 +19,7 @@ pub use anyhow::{Error, Result};
pub use fastokens::FastTokenizer; pub use fastokens::FastTokenizer;
pub use hf::HuggingFaceTokenizer; pub use hf::HuggingFaceTokenizer;
pub use tiktoken::TikTokenTokenizer; pub use tiktoken::TikTokenTokenizer;
pub use traits::DecodeResult;
/// Represents the type of tokenizer being used /// Represents the type of tokenizer being used
#[derive(Debug)] #[derive(Debug)]
...@@ -62,12 +63,66 @@ pub mod traits { ...@@ -62,12 +63,66 @@ pub mod traits {
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>; fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
} }
/// Implementations **must** use lossy UTF-8 conversion (e.g. `String::from_utf8_lossy`) /// Result of decoding token IDs to text.
/// so that partial multi-byte sequences produce U+FFFD (`�`) rather than returning `Err`. ///
/// `DecodeStream::step()` relies on the replacement character to detect incomplete /// Distinguishes between fully valid UTF-8 output and output that contains
/// sequences and buffer tokens until the full character arrives. /// trailing incomplete multi-byte sequences (represented as U+FFFD).
/// This lets callers like `DecodeStream::step()` decide whether to emit or
/// buffer without resorting to hardcoded replacement-character string checks.
// `strum::EnumIs` derives the `is_complete()` / `is_partial()` predicates that
// callers use to branch on the variant without matching.
#[derive(Debug, Clone, PartialEq, Eq, strum::EnumIs)]
pub enum DecodeResult {
/// The text carries no pending incomplete multi-byte sequence and can be emitted.
///
/// When built through `from_decoded` / `From<String>` this means the text does
/// not end with U+FFFD. A decoder that validates the raw bytes itself may also
/// return `Complete` for text that ends with a *legitimate* U+FFFD (valid UTF-8
/// bytes EF BF BD), so callers must not assume the last character is never U+FFFD.
///
/// Note: the string may still contain *interior* U+FFFD characters from
/// mid-stream invalid byte sequences; only trailing status is tracked here.
Complete(String),
/// The decoded string ends with U+FFFD, indicating incomplete trailing
/// multi-byte bytes that may be completed by subsequent tokens.
/// Incremental callers are expected to buffer rather than emit this text.
Partial(String),
}
impl DecodeResult {
/// Returns a reference to the inner string.
pub fn as_str(&self) -> &str {
match self {
DecodeResult::Complete(s) | DecodeResult::Partial(s) => s,
}
}
/// Construct from a decoded string: `Partial` if it ends with U+FFFD, else `Complete`.
pub fn from_decoded(text: String) -> Self {
if text.ends_with('\u{FFFD}') {
DecodeResult::Partial(text)
} else {
DecodeResult::Complete(text)
}
}
}
impl From<String> for DecodeResult {
fn from(text: String) -> Self {
DecodeResult::from_decoded(text)
}
}
/// Unwraps a [`DecodeResult`] into its text, discarding the
/// complete/partial distinction. The string is moved out, not copied.
impl From<DecodeResult> for String {
    fn from(result: DecodeResult) -> Self {
        // Both variants carry exactly one `String`, so this binding is irrefutable.
        let (DecodeResult::Complete(text) | DecodeResult::Partial(text)) = result;
        text
    }
}
/// Implementations must ensure that partial multi-byte sequences produce U+FFFD
/// (`\u{FFFD}`) in the output rather than returning `Err`. This is commonly achieved
/// via `String::from_utf8_lossy` (tiktoken) or library-internal byte-fallback handling
/// (HuggingFace). `DecodeStream::step()` relies on `DecodeResult::Partial` to detect
/// incomplete sequences and buffer tokens until the full character arrives.
pub trait Decoder: Send + Sync { pub trait Decoder: Send + Sync {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>; fn decode(
&self,
token_ids: &[TokenIdType],
skip_special_tokens: bool,
) -> Result<DecodeResult>;
} }
pub trait Tokenizer: Encoder + Decoder { pub trait Tokenizer: Encoder + Decoder {
...@@ -219,23 +274,27 @@ impl DecodeStream { ...@@ -219,23 +274,27 @@ impl DecodeStream {
pub fn step(&mut self, id: u32) -> Result<Option<String>> { pub fn step(&mut self, id: u32) -> Result<Option<String>> {
self.all_token_ids.push(id); self.all_token_ids.push(id);
let prefix_text = self.tokenizer.decode( let prefix_text: String = self
&self.all_token_ids[self.prefix_offset..self.read_offset], .tokenizer
self.skip_special_tokens, .decode(
)?; &self.all_token_ids[self.prefix_offset..self.read_offset],
self.skip_special_tokens,
)?
.into();
let new_text = self.tokenizer.decode( let new_result = self.tokenizer.decode(
&self.all_token_ids[self.prefix_offset..], &self.all_token_ids[self.prefix_offset..],
self.skip_special_tokens, self.skip_special_tokens,
)?; )?;
if new_text.len() > prefix_text.len() && !new_text.ends_with("�") { let new_text = new_result.as_str();
let new_text = new_text[prefix_text.len()..].to_string(); if new_text.len() > prefix_text.len() && !new_result.is_partial() {
let emitted = new_text[prefix_text.len()..].to_string();
self.prefix_offset = self.read_offset; self.prefix_offset = self.read_offset;
self.read_offset = self.all_token_ids.len(); self.read_offset = self.all_token_ids.len();
Ok(Some(new_text)) Ok(Some(emitted))
} else { } else {
Ok(None) Ok(None)
} }
...@@ -322,14 +381,17 @@ impl Sequence { ...@@ -322,14 +381,17 @@ impl Sequence {
self.token_ids.push(token_id); self.token_ids.push(token_id);
// log::trace!("pushed token_id: {}", token_id); // log::trace!("pushed token_id: {}", token_id);
let prefix_text = self let prefix_text: String = self
.tokenizer .tokenizer
.decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?; .decode(&self.token_ids[self.prefix_offset..self.read_offset], false)?
.into();
let new_text = self let new_result = self
.tokenizer .tokenizer
.decode(&self.token_ids[self.prefix_offset..], false)?; .decode(&self.token_ids[self.prefix_offset..], false)?;
let new_text = new_result.as_str();
// if the end character of the previous returned sequence is a multi-byte character // if the end character of the previous returned sequence is a multi-byte character
// then we can not split the text on that byte offset, so we roll back to the byte offset // then we can not split the text on that byte offset, so we roll back to the byte offset
// of the start of that character // of the start of that character
...@@ -340,11 +402,13 @@ impl Sequence { ...@@ -340,11 +402,13 @@ impl Sequence {
let prefix_text_len = prefix_text_len; let prefix_text_len = prefix_text_len;
if new_text.len() > prefix_text.len() { if new_text.len() > prefix_text.len() {
if new_text.ends_with("�") { if new_result.is_partial() {
return Ok("".to_string()); return Ok("".to_string());
} else { } else {
// shift and update the state // shift and update the state
let new_text = new_text[prefix_text_len..].to_string().replace("�", ""); let new_text = new_text[prefix_text_len..]
.to_string()
.replace('\u{FFFD}', "");
self.prefix_offset = self.read_offset; self.prefix_offset = self.read_offset;
self.read_offset = self.token_ids.len(); self.read_offset = self.token_ids.len();
return Ok(new_text); return Ok(new_text);
...@@ -366,7 +430,7 @@ impl Sequence { ...@@ -366,7 +430,7 @@ impl Sequence {
// let tokenizer = self.tokenizer.read().map_err(|err| { // let tokenizer = self.tokenizer.read().map_err(|err| {
// Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err)) // Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
// })?; // })?;
self.tokenizer.decode(&self.token_ids, false) Ok(self.tokenizer.decode(&self.token_ids, false)?.into())
} }
} }
......
...@@ -14,7 +14,7 @@ use rayon::prelude::*; ...@@ -14,7 +14,7 @@ use rayon::prelude::*;
use super::{ use super::{
Encoding, Error, Result, TokenIdType, Encoding, Error, Result, TokenIdType,
hf::HuggingFaceTokenizer, hf::HuggingFaceTokenizer,
traits::{Decoder, Encoder, Tokenizer}, traits::{DecodeResult, Decoder, Encoder, Tokenizer},
}; };
/// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace. /// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace.
...@@ -52,7 +52,7 @@ impl Encoder for FastTokenizer { ...@@ -52,7 +52,7 @@ impl Encoder for FastTokenizer {
} }
impl Decoder for FastTokenizer { impl Decoder for FastTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> { fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<DecodeResult> {
self.hf_decoder.decode(token_ids, skip_special_tokens) self.hf_decoder.decode(token_ids, skip_special_tokens)
} }
} }
...@@ -81,7 +81,7 @@ mod tests { ...@@ -81,7 +81,7 @@ mod tests {
let text = "Hello, world!"; let text = "Hello, world!";
let encoding = tokenizer.encode(text).unwrap(); let encoding = tokenizer.encode(text).unwrap();
assert!(!encoding.token_ids().is_empty()); assert!(!encoding.token_ids().is_empty());
let decoded = tokenizer.decode(encoding.token_ids(), true).unwrap(); let decoded: String = tokenizer.decode(encoding.token_ids(), true).unwrap().into();
assert!(!decoded.is_empty()); assert!(!decoded.is_empty());
// The decoded text should contain the same non-space characters // The decoded text should contain the same non-space characters
let enc_chars: String = text.chars().filter(|c| !c.is_whitespace()).collect(); let enc_chars: String = text.chars().filter(|c| !c.is_whitespace()).collect();
...@@ -149,8 +149,8 @@ mod tests { ...@@ -149,8 +149,8 @@ mod tests {
// decode(continuation) which lacks the surrounding context. // decode(continuation) which lacks the surrounding context.
let mut all_ids = prompt_ids.clone(); let mut all_ids = prompt_ids.clone();
all_ids.extend_from_slice(&cont_ids); all_ids.extend_from_slice(&cont_ids);
let full_text = wrapper.decode(&all_ids, true).unwrap(); let full_text: String = wrapper.decode(&all_ids, true).unwrap().into();
let prompt_text = wrapper.decode(&prompt_ids, true).unwrap(); let prompt_text: String = wrapper.decode(&prompt_ids, true).unwrap().into();
let expected = &full_text[prompt_text.len()..]; let expected = &full_text[prompt_text.len()..];
assert_eq!( assert_eq!(
accumulated, expected, accumulated, expected,
......
...@@ -5,7 +5,7 @@ use tokenizers::tokenizer::Tokenizer as HfTokenizer; ...@@ -5,7 +5,7 @@ use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use super::{ use super::{
Encoding, Error, Result, TokenIdType, Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer}, traits::{DecodeResult, Decoder, Encoder, Tokenizer},
}; };
pub struct HuggingFaceTokenizer { pub struct HuggingFaceTokenizer {
...@@ -52,14 +52,14 @@ impl Encoder for HuggingFaceTokenizer { ...@@ -52,14 +52,14 @@ impl Encoder for HuggingFaceTokenizer {
} }
impl Decoder for HuggingFaceTokenizer { impl Decoder for HuggingFaceTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> { fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<DecodeResult> {
// This calls into the library // This calls into the library
let text = self let text = self
.tokenizer .tokenizer
.decode(token_ids, skip_special_tokens) .decode(token_ids, skip_special_tokens)
.map_err(|err| Error::msg(format!("Error de-tokenizing input: {err}")))?; .map_err(|err| Error::msg(format!("Error de-tokenizing input: {err}")))?;
Ok(text) Ok(text.into())
} }
} }
......
...@@ -11,7 +11,7 @@ use tiktoken_rs::CoreBPE; ...@@ -11,7 +11,7 @@ use tiktoken_rs::CoreBPE;
use super::{ use super::{
Encoding, Error, Result, TokenIdType, Encoding, Error, Result, TokenIdType,
traits::{Decoder, Encoder, Tokenizer}, traits::{DecodeResult, Decoder, Encoder, Tokenizer},
}; };
/// Number of reserved special-token slots to generate when filling gaps in the vocabulary. /// Number of reserved special-token slots to generate when filling gaps in the vocabulary.
...@@ -89,7 +89,7 @@ impl Encoder for TikTokenTokenizer { ...@@ -89,7 +89,7 @@ impl Encoder for TikTokenTokenizer {
} }
impl Decoder for TikTokenTokenizer { impl Decoder for TikTokenTokenizer {
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String> { fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<DecodeResult> {
let ids: Vec<u32> = if skip_special_tokens { let ids: Vec<u32> = if skip_special_tokens {
token_ids token_ids
.iter() .iter()
...@@ -100,12 +100,22 @@ impl Decoder for TikTokenTokenizer { ...@@ -100,12 +100,22 @@ impl Decoder for TikTokenTokenizer {
token_ids.to_vec() token_ids.to_vec()
}; };
// Use lossy UTF-8 conversion so that partial multi-byte sequences become U+FFFD (�). // Try strict UTF-8 first: valid bytes get `Complete` with zero extra allocation
// This is critical for incremental detokenization: DecodeStream::step() relies on // (takes ownership of the Vec). This correctly handles vocabulary tokens whose
// the replacement character to detect incomplete sequences and buffer tokens until // raw bytes are EF BF BD (legitimate U+FFFD) -- they are valid UTF-8 and must
// a complete character arrives. CoreBPE::decode() would error on invalid UTF-8 instead. // not be confused with incomplete multi-byte sequences.
//
// On failure, fall back to lossy conversion so partial multi-byte sequences
// become U+FFFD, then classify via the trailing-FFFD heuristic. This path is
// only hit during incremental detokenization of byte-fallback tokens.
let bytes: Vec<u8> = self.bpe._decode_native_and_split(ids).flatten().collect(); let bytes: Vec<u8> = self.bpe._decode_native_and_split(ids).flatten().collect();
Ok(String::from_utf8_lossy(&bytes).into_owned()) match String::from_utf8(bytes) {
Ok(text) => Ok(DecodeResult::Complete(text)),
Err(e) => {
let text = String::from_utf8_lossy(e.as_bytes()).into_owned();
Ok(DecodeResult::from_decoded(text))
}
}
} }
} }
...@@ -351,7 +361,7 @@ mod tests { ...@@ -351,7 +361,7 @@ mod tests {
assert!(!ids.is_empty()); assert!(!ids.is_empty());
// Test decode roundtrip // Test decode roundtrip
let decoded = tokenizer.decode(ids, false).unwrap(); let decoded: String = tokenizer.decode(ids, false).unwrap().into();
assert_eq!(decoded, "hello world"); assert_eq!(decoded, "hello world");
} }
...@@ -393,11 +403,11 @@ mod tests { ...@@ -393,11 +403,11 @@ mod tests {
ids.push(22); // [EOS] ids.push(22); // [EOS]
// Decode with skip_special_tokens=true should strip special tokens // Decode with skip_special_tokens=true should strip special tokens
let decoded_skip = tokenizer.decode(&ids, true).unwrap(); let decoded_skip: String = tokenizer.decode(&ids, true).unwrap().into();
assert_eq!(decoded_skip, "hello"); assert_eq!(decoded_skip, "hello");
// Decode with skip_special_tokens=false should include them // Decode with skip_special_tokens=false should include them
let decoded_all = tokenizer.decode(&ids, false).unwrap(); let decoded_all: String = tokenizer.decode(&ids, false).unwrap().into();
assert!(decoded_all.contains("hello")); assert!(decoded_all.contains("hello"));
} }
...@@ -416,7 +426,7 @@ mod tests { ...@@ -416,7 +426,7 @@ mod tests {
let ids = encoding.token_ids(); let ids = encoding.token_ids();
assert!(!ids.is_empty()); assert!(!ids.is_empty());
let decoded = tokenizer.decode(ids, false).unwrap(); let decoded: String = tokenizer.decode(ids, false).unwrap().into();
assert_eq!(decoded, "hello world"); assert_eq!(decoded, "hello world");
} }
...@@ -490,6 +500,15 @@ mod tests { ...@@ -490,6 +500,15 @@ mod tests {
content.push_str(&format!("{encoded} {rank}\n")); content.push_str(&format!("{encoded} {rank}\n"));
} }
// Legitimate U+FFFD token: valid UTF-8 bytes EF BF BD (replacement character
// as an actual vocabulary entry, not an artifact of lossy conversion)
let fffd_token: Vec<(Vec<u8>, u32)> = vec![(vec![0xEF, 0xBF, 0xBD], 300)];
for (token, rank) in &fffd_token {
let encoded = engine.encode(token);
content.push_str(&format!("{encoded} {rank}\n"));
}
let file_path = dir.join("tiktoken.model"); let file_path = dir.join("tiktoken.model");
let mut file = std::fs::File::create(&file_path).unwrap(); let mut file = std::fs::File::create(&file_path).unwrap();
file.write_all(content.as_bytes()).unwrap(); file.write_all(content.as_bytes()).unwrap();
...@@ -516,11 +535,11 @@ mod tests { ...@@ -516,11 +535,11 @@ mod tests {
result.is_ok(), result.is_ok(),
"decode() should not error on incomplete UTF-8 bytes" "decode() should not error on incomplete UTF-8 bytes"
); );
let text = result.unwrap(); let decode_result = result.unwrap();
assert!( assert!(
text.contains('\u{FFFD}'), decode_result.is_partial(),
"incomplete UTF-8 byte should produce replacement character, got: {:?}", "incomplete UTF-8 byte should produce DecodeResult::Partial, got: {:?}",
text decode_result
); );
} }
...@@ -532,11 +551,11 @@ mod tests { ...@@ -532,11 +551,11 @@ mod tests {
let result = tokenizer.decode(&[100, 101], false); let result = tokenizer.decode(&[100, 101], false);
assert!(result.is_ok()); assert!(result.is_ok());
let text = result.unwrap(); let decode_result = result.unwrap();
assert!( assert!(
text.contains('\u{FFFD}'), decode_result.is_partial(),
"incomplete 2-of-3 UTF-8 bytes should produce replacement character, got: {:?}", "incomplete 2-of-3 UTF-8 bytes should produce DecodeResult::Partial, got: {:?}",
text decode_result
); );
} }
...@@ -550,7 +569,7 @@ mod tests { ...@@ -550,7 +569,7 @@ mod tests {
let result = tokenizer.decode(&[100, 101, 102], false); let result = tokenizer.decode(&[100, 101, 102], false);
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), "你"); assert_eq!(String::from(result.unwrap()), "你");
} }
/// All 4 emoji bytes together form valid UTF-8, so this passes both before and after /// All 4 emoji bytes together form valid UTF-8, so this passes both before and after
...@@ -562,7 +581,27 @@ mod tests { ...@@ -562,7 +581,27 @@ mod tests {
let result = tokenizer.decode(&[200, 201, 202, 203], false); let result = tokenizer.decode(&[200, 201, 202, 203], false);
assert!(result.is_ok()); assert!(result.is_ok());
assert_eq!(result.unwrap(), "😀"); assert_eq!(String::from(result.unwrap()), "😀");
}
/// Regression test: a vocabulary token whose raw bytes are EF BF BD (the valid
/// UTF-8 encoding of U+FFFD) must decode as `Complete`, not `Partial`. Before the
/// from_utf8 fast-path fix, from_utf8_lossy + the trailing-FFFD heuristic would
/// misclassify this as Partial, causing the incremental decoder to suppress it.
#[test]
fn test_decode_legitimate_replacement_char_token_is_complete() {
let dir = tempfile::tempdir().unwrap();
let tokenizer = create_byte_token_tokenizer(dir.path());
let result = tokenizer.decode(&[300], false);
assert!(result.is_ok());
let decode_result = result.unwrap();
assert!(
decode_result.is_complete(),
"legitimate U+FFFD vocab token must be Complete, got: {:?}",
decode_result
);
assert_eq!(decode_result.as_str(), "\u{FFFD}");
} }
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode(). /// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
...@@ -573,7 +612,7 @@ mod tests { ...@@ -573,7 +612,7 @@ mod tests {
let result = tokenizer.decode(&[200], false); let result = tokenizer.decode(&[200], false);
assert!(result.is_ok()); assert!(result.is_ok());
assert!(result.unwrap().contains('\u{FFFD}')); assert!(result.unwrap().is_partial());
} }
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode(). /// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
...@@ -584,16 +623,17 @@ mod tests { ...@@ -584,16 +623,17 @@ mod tests {
let result = tokenizer.decode(&[5, 100], false); let result = tokenizer.decode(&[5, 100], false);
assert!(result.is_ok()); assert!(result.is_ok());
let text = result.unwrap(); let decode_result = result.unwrap();
assert!(
decode_result.is_partial(),
"trailing incomplete byte should produce DecodeResult::Partial"
);
let text: String = decode_result.into();
assert!( assert!(
text.starts_with("hello"), text.starts_with("hello"),
"should start with 'hello', got: {:?}", "should start with 'hello', got: {:?}",
text text
); );
assert!(
text.contains('\u{FFFD}'),
"trailing incomplete byte should produce U+FFFD"
);
} }
/// End-to-end incremental detokenization: DecodeStream buffers partial bytes, /// End-to-end incremental detokenization: DecodeStream buffers partial bytes,
...@@ -656,7 +696,10 @@ mod tests { ...@@ -656,7 +696,10 @@ mod tests {
assert_eq!(encodings.len(), 2); assert_eq!(encodings.len(), 2);
for (encoding, input) in encodings.iter().zip(inputs.iter()) { for (encoding, input) in encodings.iter().zip(inputs.iter()) {
let decoded = tokenizer.decode(encoding.token_ids(), false).unwrap(); let decoded: String = tokenizer
.decode(encoding.token_ids(), false)
.unwrap()
.into();
assert_eq!(decoded, *input); assert_eq!(decoded, *input);
} }
} }
......
...@@ -25,8 +25,8 @@ impl tokenizer_traits::Encoder for TestTokenizer { ...@@ -25,8 +25,8 @@ impl tokenizer_traits::Encoder for TestTokenizer {
} }
impl tokenizer_traits::Decoder for TestTokenizer { impl tokenizer_traits::Decoder for TestTokenizer {
fn decode(&self, ids: &[u32], skip_special: bool) -> Result<String> { fn decode(&self, ids: &[u32], skip_special: bool) -> Result<tokenizer_traits::DecodeResult> {
Ok(ids let text: String = ids
.iter() .iter()
.filter_map(|&id| match id { .filter_map(|&id| match id {
EOS if skip_special => None, EOS if skip_special => None,
...@@ -36,7 +36,8 @@ impl tokenizer_traits::Decoder for TestTokenizer { ...@@ -36,7 +36,8 @@ impl tokenizer_traits::Decoder for TestTokenizer {
EOS => Some("</s>"), EOS => Some("</s>"),
_ => Some("?"), _ => Some("?"),
}) })
.collect()) .collect();
Ok(text.into())
} }
} }
......
...@@ -113,9 +113,10 @@ fn test_encode_decode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) { ...@@ -113,9 +113,10 @@ fn test_encode_decode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
.unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}")); .unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}"));
assert!(!encoding.token_ids().is_empty()); assert!(!encoding.token_ids().is_empty());
let decoded = tokenizer let decoded: String = tokenizer
.decode(encoding.token_ids(), false) .decode(encoding.token_ids(), false)
.unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}")); .unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}"))
.into();
assert_eq!(decoded, text, "Roundtrip failed for: '{text}'"); assert_eq!(decoded, text, "Roundtrip failed for: '{text}'");
} }
} }
...@@ -129,9 +130,10 @@ fn test_encode_decode_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>) ...@@ -129,9 +130,10 @@ fn test_encode_decode_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>)
.encode(text) .encode(text)
.unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}")); .unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}"));
let decoded = tokenizer let decoded: String = tokenizer
.decode(encoding.token_ids(), false) .decode(encoding.token_ids(), false)
.unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}")); .unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}"))
.into();
assert_eq!(decoded, text, "Roundtrip failed for: '{text}'"); assert_eq!(decoded, text, "Roundtrip failed for: '{text}'");
} }
} }
...@@ -147,9 +149,10 @@ fn test_batch_encode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) { ...@@ -147,9 +149,10 @@ fn test_batch_encode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
assert_eq!(encodings.len(), inputs.len()); assert_eq!(encodings.len(), inputs.len());
for (encoding, &input) in encodings.iter().zip(inputs.iter()) { for (encoding, &input) in encodings.iter().zip(inputs.iter()) {
let decoded = tokenizer let decoded: String = tokenizer
.decode(encoding.token_ids(), false) .decode(encoding.token_ids(), false)
.expect("Failed to decode"); .expect("Failed to decode")
.into();
assert_eq!(decoded, input); assert_eq!(decoded, input);
} }
} }
...@@ -354,14 +357,16 @@ fn test_decode_with_skip_special_tokens() { ...@@ -354,14 +357,16 @@ fn test_decode_with_skip_special_tokens() {
token_ids.push(2); // </s> token_ids.push(2); // </s>
// Decode with skip_special_tokens = false (should keep special tokens) // Decode with skip_special_tokens = false (should keep special tokens)
let decoded_with_special = tokenizer let decoded_with_special: String = tokenizer
.decode(&token_ids, false) .decode(&token_ids, false)
.expect("Failed to decode with skip_special_tokens=false"); .expect("Failed to decode with skip_special_tokens=false")
.into();
// Decode with skip_special_tokens = true (should remove special tokens) // Decode with skip_special_tokens = true (should remove special tokens)
let decoded_without_special = tokenizer let decoded_without_special: String = tokenizer
.decode(&token_ids, true) .decode(&token_ids, true)
.expect("Failed to decode with skip_special_tokens=true"); .expect("Failed to decode with skip_special_tokens=true")
.into();
// Validate exact matches on the entire decoded strings // Validate exact matches on the entire decoded strings
assert_eq!(decoded_with_special, "<s> Hello world</s>"); assert_eq!(decoded_with_special, "<s> Hello world</s>");
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment