Unverified commit d71c0c32, authored by Biswa Panda and committed by GitHub
Browse files

fix: Kimi K2.5 - tiktoken incomplete multi-byte sequence handling with regression tests (#6996)

parent fc2d01d1
......@@ -182,7 +182,20 @@ impl
let data = output.data.as_ref().unwrap();
let result = state.decoder.process_token_ids(&data.token_ids).unwrap();
let result = match state.decoder.process_token_ids(&data.token_ids) {
Ok(result) => result,
Err(e) => {
tracing::error!("Failed to process token_ids: {e}");
state.stream.context().stop_generating();
state.finished = true;
let mut output = output;
if let Some(data) = &mut output.data {
data.finish_reason =
Some(FinishReason::Error(format!("decode error: {e}")));
}
return Some((output, state));
}
};
// NOTE: the `finish_reason` is computed from the generated `token_ids` alone.
// The `data` field can have a `finish_reason` set, coming from the underlying
......@@ -583,11 +596,14 @@ impl Decoder {
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizers::traits;
use std::sync::Arc;
#[test]
fn test_char_boundary_drain() {
use super::Decoder;
let mut s = String::from("helloñworld"); // 12 bytes total ñ is 2 bytes
let max_bytes = 6; // 12 - 6 = 6 which is inside ñ
assert!(!s.is_char_boundary(s.len() - max_bytes)); // initially we are not on a char boundary
......@@ -595,4 +611,88 @@ mod tests {
assert!(s.is_char_boundary(0)); // front of jail string on valid char boundary
assert_eq!(s, "ñworld");
}
/// Test double whose `decode()` fails unconditionally.
/// It exists to drive the error-propagation path of `Decoder::process_token_ids()`.
struct FailingDecoder;

impl traits::Encoder for FailingDecoder {
    /// Ignores the input and returns an empty `Encoding::Sp`.
    fn encode(&self, _input: &str) -> anyhow::Result<crate::tokenizers::Encoding> {
        let no_tokens = Vec::new();
        Ok(crate::tokenizers::Encoding::Sp(no_tokens))
    }

    /// Ignores the inputs and returns an empty batch.
    fn encode_batch(
        &self,
        _inputs: &[&str],
    ) -> anyhow::Result<Vec<crate::tokenizers::Encoding>> {
        Ok(Vec::new())
    }
}

impl traits::Decoder for FailingDecoder {
    /// Always fails with a fixed "incomplete utf-8 byte sequence" message.
    fn decode(
        &self,
        _token_ids: &[TokenIdType],
        _skip_special_tokens: bool,
    ) -> anyhow::Result<String> {
        anyhow::bail!(
            "Unable to decode into a valid UTF-8 string: incomplete utf-8 byte sequence from index 6"
        )
    }
}

impl traits::Tokenizer for FailingDecoder {}
/// A failing tokenizer `decode()` must surface as an `Err` from
/// `Decoder::process_token_ids()`, carrying the original message. The backend
/// unfold closure is what later turns that error into `FinishReason::Error`.
#[test]
fn test_decoder_process_token_ids_propagates_decode_error() {
    let failing: Arc<dyn traits::Tokenizer> = Arc::new(FailingDecoder);
    let mut decoder = Decoder::new(
        crate::tokenizers::DecodeStream::new(failing, &[], false),
        StopConditions::default(),
        false,
        None,
    );

    // Any token id will do: the mock fails regardless of its input.
    let err_msg = match decoder.process_token_ids(&[42]) {
        Ok(_) => panic!("process_token_ids should propagate decode errors"),
        Err(e) => e.to_string(),
    };

    assert!(
        err_msg.contains("incomplete utf-8 byte sequence"),
        "error should contain the original decode error message, got: {err_msg}"
    );
}
/// Checks the text that ends up inside `FinishReason::Error` when a decode
/// error is wrapped the same way the backend unfold closure wraps it.
#[test]
fn test_decoder_error_message_format_for_finish_reason() {
    let failing: Arc<dyn traits::Tokenizer> = Arc::new(FailingDecoder);
    let stream = crate::tokenizers::DecodeStream::new(failing, &[], false);
    let mut decoder = Decoder::new(stream, StopConditions::default(), false, None);

    let err = decoder
        .process_token_ids(&[42])
        .err()
        .expect("should be Err");

    // Mirror of the wrapping performed in the backend unfold closure.
    let finish_reason = FinishReason::Error(format!("decode error: {err}"));

    if let FinishReason::Error(msg) = &finish_reason {
        assert!(
            msg.starts_with("decode error:"),
            "FinishReason::Error should have 'decode error:' prefix, got: {msg}"
        );
        assert!(
            msg.contains("incomplete utf-8 byte sequence"),
            "FinishReason::Error should contain original error, got: {msg}"
        );
    } else {
        panic!("Expected FinishReason::Error, got: {:?}", finish_reason);
    }
}
}
......@@ -60,6 +60,10 @@ pub mod traits {
fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>>;
}
/// Turns a sequence of token IDs back into text.
///
/// Implementations **must** use lossy UTF-8 conversion (e.g. `String::from_utf8_lossy`)
/// so that partial multi-byte sequences produce U+FFFD (`�`) rather than returning `Err`.
/// `DecodeStream::step()` relies on the replacement character to detect incomplete
/// sequences and buffer tokens until the full character arrives.
pub trait Decoder: Send + Sync {
/// Decodes `token_ids` into a `String`.
///
/// NOTE(review): `skip_special_tokens` presumably asks the implementation to leave
/// special tokens out of the returned text — confirm against the implementations.
fn decode(&self, token_ids: &[TokenIdType], skip_special_tokens: bool) -> Result<String>;
}
......
......@@ -100,9 +100,12 @@ impl Decoder for TikTokenTokenizer {
token_ids.to_vec()
};
self.bpe
.decode(ids)
.map_err(|err| Error::msg(format!("Error decoding tiktoken tokens: {err}")))
// Use lossy UTF-8 conversion so that partial multi-byte sequences become U+FFFD (�).
// This is critical for incremental detokenization: DecodeStream::step() relies on
// the replacement character to detect incomplete sequences and buffer tokens until
// a complete character arrives. CoreBPE::decode() would error on invalid UTF-8 instead.
let bytes: Vec<u8> = self.bpe._decode_native_and_split(ids).flatten().collect();
Ok(String::from_utf8_lossy(&bytes).into_owned())
}
}
......@@ -236,7 +239,9 @@ fn load_special_tokens(directory: &Path, num_base_tokens: usize) -> Result<FxHas
#[cfg(test)]
mod tests {
use super::*;
use crate::tokenizers::DecodeStream;
use std::io::Write;
use std::sync::Arc;
fn create_test_tiktoken_file(dir: &Path) -> String {
let engine = base64::engine::general_purpose::STANDARD;
......@@ -443,6 +448,199 @@ mod tests {
assert!(tokens.len() > 2);
}
/// Helper: writes `tiktoken.model` into `dir` with a vocabulary that mixes a few
/// ASCII tokens with raw single-byte tokens (byte-fallback tokens), and returns
/// the path of the written file.
///
/// Each line of the file is `<base64(token bytes)> <rank>`.
fn create_test_tiktoken_file_with_byte_tokens(dir: &Path) -> String {
    let engine = base64::engine::general_purpose::STANDARD;

    let vocab: Vec<(&[u8], u32)> = vec![
        // Plain ASCII tokens.
        (b"h", 0),
        (b"e", 1),
        (b"l", 2),
        (b"o", 3),
        (b" ", 4),
        (b"hello", 5),
        // Byte-fallback tokens: the individual bytes of CJK character "你" (U+4F60),
        // whose UTF-8 encoding is 0xE4 0xBD 0xA0.
        (b"\xE4", 100),
        (b"\xBD", 101),
        (b"\xA0", 102),
        // The individual bytes of emoji "😀" (U+1F600) — 4-byte UTF-8: 0xF0 0x9F 0x98 0x80.
        (b"\xF0", 200),
        (b"\x9F", 201),
        (b"\x98", 202),
        (b"\x80", 203),
    ];

    let content: String = vocab
        .iter()
        .map(|(bytes, rank)| format!("{} {rank}\n", engine.encode(bytes)))
        .collect();

    let file_path = dir.join("tiktoken.model");
    std::fs::File::create(&file_path)
        .unwrap()
        .write_all(content.as_bytes())
        .unwrap();
    file_path.to_str().unwrap().to_string()
}
/// Builds a `TikTokenTokenizer` from the byte-token vocabulary file written into `dir`,
/// using an empty special-token map.
fn create_byte_token_tokenizer(dir: &Path) -> TikTokenTokenizer {
    let vocab_path = create_test_tiktoken_file_with_byte_tokens(dir);
    TikTokenTokenizer::from_file(&vocab_path, r"[\w]+|[^\w\s]+|\s+", FxHashMap::default())
        .unwrap()
}
/// Regression test for the original failure: a single byte-fallback token that is only
/// the first byte of a multi-byte UTF-8 character. Before the fix, CoreBPE::decode()
/// ran String::from_utf8() on [0xE4] and failed with "incomplete utf-8 byte sequence".
#[test]
fn test_decode_single_incomplete_utf8_byte_does_not_error() {
    let dir = tempfile::tempdir().unwrap();
    let tokenizer = create_byte_token_tokenizer(dir.path());

    let text = match tokenizer.decode(&[100], false) {
        Ok(text) => text,
        Err(_) => panic!("decode() should not error on incomplete UTF-8 bytes"),
    };

    let has_replacement = text.contains('\u{FFFD}');
    assert!(
        has_replacement,
        "incomplete UTF-8 byte should produce replacement character, got: {text:?}"
    );
}
/// Two of the three bytes of "你" are still an incomplete sequence.
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
#[test]
fn test_decode_two_of_three_utf8_bytes_does_not_error() {
    let temp = tempfile::tempdir().unwrap();
    let result = create_byte_token_tokenizer(temp.path()).decode(&[100, 101], false);
    assert!(result.is_ok());

    let decoded = result.unwrap();
    let has_replacement = decoded.contains('\u{FFFD}');
    assert!(
        has_replacement,
        "incomplete 2-of-3 UTF-8 bytes should produce replacement character, got: {decoded:?}"
    );
}
/// With every byte of the character present the concatenation is valid UTF-8, so this
/// passes both before and after the fix. It guards against the lossy conversion
/// corrupting complete characters.
#[test]
fn test_decode_complete_multibyte_utf8_produces_correct_char() {
    let temp = tempfile::tempdir().unwrap();
    let result = create_byte_token_tokenizer(temp.path()).decode(&[100, 101, 102], false);
    assert!(result.is_ok());

    let decoded = result.unwrap();
    assert_eq!(decoded, "你");
}
/// The four emoji bytes together are valid UTF-8, so this passes both before and after
/// the fix. It checks that lossy conversion leaves complete multi-byte sequences intact.
#[test]
fn test_decode_complete_4byte_emoji_from_byte_tokens() {
    let temp = tempfile::tempdir().unwrap();
    let result = create_byte_token_tokenizer(temp.path()).decode(&[200, 201, 202, 203], false);
    assert!(result.is_ok());

    let decoded = result.unwrap();
    assert_eq!(decoded, "😀");
}
/// Only the first byte of a 4-byte emoji is decoded.
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
#[test]
fn test_decode_partial_emoji_does_not_error() {
    let temp = tempfile::tempdir().unwrap();
    let result = create_byte_token_tokenizer(temp.path()).decode(&[200], false);
    assert!(result.is_ok());

    let decoded = result.unwrap();
    assert!(decoded.contains('\u{FFFD}'));
}
/// A complete ASCII token followed by a dangling lead byte.
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
#[test]
fn test_decode_mixed_ascii_and_incomplete_bytes() {
    let temp = tempfile::tempdir().unwrap();
    let result = create_byte_token_tokenizer(temp.path()).decode(&[5, 100], false);
    assert!(result.is_ok());

    let decoded = result.unwrap();
    let keeps_ascii_prefix = decoded.starts_with("hello");
    assert!(
        keeps_ascii_prefix,
        "should start with 'hello', got: {decoded:?}"
    );
    let has_replacement = decoded.contains('\u{FFFD}');
    assert!(
        has_replacement,
        "trailing incomplete byte should produce U+FFFD"
    );
}
/// End-to-end incremental detokenization: `DecodeStream` holds back partial bytes and
/// emits the character only once its last byte has arrived.
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
#[test]
fn test_decode_stream_incremental_multibyte_reassembly() {
    let dir = tempfile::tempdir().unwrap();
    let shared: Arc<dyn crate::tokenizers::traits::Tokenizer> =
        Arc::new(create_byte_token_tokenizer(dir.path()));
    let mut stream = DecodeStream::new(shared, &[5], false);

    // The first two bytes of "你" must not produce any output yet.
    let buffered_steps = [
        (100, "first byte of 3-byte char should be buffered"),
        (101, "second byte of 3-byte char should be buffered"),
    ];
    for (token_id, expectation) in buffered_steps {
        let emitted = stream.step(token_id).unwrap();
        assert_eq!(emitted, None, "{expectation}");
    }

    let completed = stream.step(102).unwrap();
    assert!(
        completed.is_some(),
        "third byte should complete the character"
    );
    assert_eq!(completed.unwrap(), "你");
}
/// Same incremental reassembly check for a 4-byte emoji fed one byte token at a time.
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
#[test]
fn test_decode_stream_incremental_emoji_reassembly() {
    let dir = tempfile::tempdir().unwrap();
    let shared: Arc<dyn crate::tokenizers::traits::Tokenizer> =
        Arc::new(create_byte_token_tokenizer(dir.path()));
    let mut stream = DecodeStream::new(shared, &[5], false);

    // Bytes 1..=3 of "😀" must be held back.
    let buffered_steps = [
        (200, "byte 1/4 of emoji should be buffered"),
        (201, "byte 2/4 of emoji should be buffered"),
        (202, "byte 3/4 of emoji should be buffered"),
    ];
    for (token_id, expectation) in buffered_steps {
        let emitted = stream.step(token_id).unwrap();
        assert_eq!(emitted, None, "{expectation}");
    }

    let completed = stream.step(203).unwrap();
    assert!(completed.is_some(), "byte 4/4 should complete the emoji");
    assert_eq!(completed.unwrap(), "😀");
}
#[test]
fn test_tiktoken_encode_batch() {
let dir = tempfile::tempdir().unwrap();
......
......@@ -13,12 +13,17 @@
//! in a hashmap. We will then use these hashes to test that the tokenizer is working correctly. This
//! will detect if upstream dependency changes result in different/new behavior.
use dynamo_llm::tokenizers::traits::{Decoder, Encoder};
use dynamo_llm::tokenizers::traits::{Decoder, Encoder, Tokenizer};
use dynamo_llm::tokenizers::*;
use rstest::rstest;
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
// ---------------------------------------------------------------------------
// Test data
// ---------------------------------------------------------------------------
const TEST_PROMPTS: [&str; 4] = [
"deep learning is",
"Deep learning is",
......@@ -42,137 +47,201 @@ const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
];
const TINYLLAMA_TOKENIZER_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1/tokenizer.json";
/// Input strings for the roundtrip tests, mixing plain ASCII with text whose
/// characters need 2, 3 and 4 UTF-8 bytes, plus the empty string.
const MULTIBYTE_TEST_CASES: [&str; 14] = [
"hello world",
"deep learning is awesome",
"The quick brown fox jumps over the lazy dog.",
"line1\nline2\nline3",
"你好世界", // CJK: 3-byte UTF-8 chars
"😀😃😄😁", // Emoji: 4-byte UTF-8 chars
"hello 你好 world 🌍", // Mixed ASCII + CJK + emoji
"café résumé naïve", // Latin with diacritics (2-byte UTF-8)
"こんにちは", // Japanese hiragana
"Привет мир", // Cyrillic
"مرحبا", // Arabic (RTL)
"🧑‍💻👨‍👩‍👧‍👦", // Emoji ZWJ sequences (complex multi-codepoint)
"a你b😀c", // Interleaved single-byte and multi-byte
"", // Empty string
];
const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
/// `(prompt, output)` pairs for the `DecodeStream` tests: the prompt is encoded and
/// passed to `DecodeStream::new`, then the output is encoded and fed through `step()`
/// one token at a time.
const STREAM_TEST_CASES: [(&str, &str); 8] = [
("hello world", "deep learning is great"),
("summarize:", "The quick brown fox jumps over the lazy dog."),
("hello world", "你好世界"),
("prompt:", "😀😃😄😁"),
("translate this:", "hello 你好 world 🌍"),
("text:", "café résumé naïve"),
("say:", "こんにちは"),
("input:", "🧑‍💻👨‍👩‍👧‍👦"),
];
const HASHES: [(&str, [u64; 4]); 1] = [(
TINYLLAMA_TOKENIZER_PATH,
[
1209591529327510910,
4181375434596349981,
6245658446118930933,
5097285695902185237,
],
)];
// ---------------------------------------------------------------------------
// Tokenizer paths
// ---------------------------------------------------------------------------
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
prompts
.iter()
.map(|&prompt| {
tokenizer
.encode(prompt)
.expect("Failed to encode prompt")
.get_hash()
// Assuming `get_hash` returns a `u64`
})
.collect()
}
const TINYLLAMA_TOKENIZER_PATH: &str = "tests/data/sample-models/TinyLlama_v1.1/tokenizer.json";
const MOCK_TIKTOKEN_DIR: &str = "tests/data/sample-models/mock-tiktoken";
#[test]
fn compute_hashes_hf() {
let hash_map: HashMap<&str, [u64; 4]> = HASHES.iter().cloned().collect();
/// Loads the TinyLlama HuggingFace tokenizer fixture and returns it as a shared
/// `Tokenizer` trait object. Panics if the fixture cannot be loaded.
fn tinyllama_tokenizer() -> Arc<dyn Tokenizer> {
    let hf = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
        .expect("Failed to load HuggingFace tokenizer");
    Arc::new(hf)
}
for &tokenizer_name in HF_TOKENIZERS_LOCAL.iter() {
let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_name)
.expect("Failed to load HuggingFace tokenizer");
fn mock_tiktoken_tokenizer() -> Arc<dyn Tokenizer> {
let path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join(MOCK_TIKTOKEN_DIR)
.join("tiktoken.model");
Arc::new(
TikTokenTokenizer::from_file_auto(path.to_str().unwrap())
.expect("Failed to load tiktoken tokenizer"),
)
}
let prompt_hashes = compute_hashes_for_tokenizer(&tokenizer, &TEST_PROMPTS);
// ---------------------------------------------------------------------------
// Parameterized scenario tests — every tokenizer must pass all of these
// ---------------------------------------------------------------------------
println!(
"HF Tokenizer: {:?} Hashes: {:?}",
tokenizer_name, prompt_hashes
);
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_encode_decode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
for &text in TEST_PROMPTS.iter() {
let encoding = tokenizer
.encode(text)
.unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}"));
assert!(!encoding.token_ids().is_empty());
assert_eq!(prompt_hashes, hash_map[tokenizer_name]);
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}"));
assert_eq!(decoded, text, "Roundtrip failed for: '{text}'");
}
}
#[test]
fn test_hf_lifecycle() {
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
.expect("Failed to load remote HuggingFace tokenizer");
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_encode_decode_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>) {
for &text in MULTIBYTE_TEST_CASES.iter() {
let encoding = tokenizer
.encode(TEST_PROMPTS[0])
.expect("Failed to encode prompt");
.encode(text)
.unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}"));
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode token_ids");
assert_eq!(decoded, TEST_PROMPTS[0]);
.unwrap_or_else(|e| panic!("Failed to decode '{text}': {e}"));
assert_eq!(decoded, text, "Roundtrip failed for: '{text}'");
}
}
#[test]
fn test_sequence() {
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
.expect("Failed to load remote HuggingFace tokenizer");
let shared_tokenizer = Arc::new(tokenizer);
// let tokenizer = shared_tokenizer.read().unwrap();
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_batch_encode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
let inputs = &["hello", "world", "deep learning"];
let encodings = tokenizer
.encode_batch(inputs)
.expect("Failed to batch encode");
assert_eq!(encodings.len(), inputs.len());
let encoding = shared_tokenizer
.encode(TEST_PROMPTS[0])
.expect("Failed to encode prompt");
for (encoding, &input) in encodings.iter().zip(inputs.iter()) {
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode");
assert_eq!(decoded, input);
}
}
let mut sequence = Sequence::new(shared_tokenizer.clone().into());
sequence
.append_text(TEST_PROMPTS[0])
.expect("Failed to append prompt");
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_sequence_append_and_decode(#[case] tokenizer: Arc<dyn Tokenizer>) {
let text = TEST_PROMPTS[0];
let encoding = tokenizer.encode(text).expect("Failed to encode prompt");
// Append text and verify token count matches
let mut sequence = Sequence::new(tokenizer.clone().into());
sequence.append_text(text).expect("Failed to append prompt");
assert_eq!(sequence.len(), encoding.token_ids().len());
let mut decoder = Sequence::new(shared_tokenizer.clone().into());
// Incremental token-by-token decode via Sequence::append_token_id
let mut decoder = Sequence::new(tokenizer.clone().into());
let mut output = String::new();
for token_id in encoding.token_ids() {
let text = decoder
.append_token_id(*token_id)
for &token_id in encoding.token_ids() {
let chunk = decoder
.append_token_id(token_id)
.expect("Failed to decode token_id");
output.push_str(text.as_str());
output.push_str(&chunk);
}
assert_eq!(decoder.len(), sequence.len());
assert_eq!(decoder.token_ids(), sequence.token_ids());
assert_eq!(output, TEST_PROMPTS[0]);
assert_eq!(output, text);
}
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), &[], false);
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_sequence_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>) {
// Skip empty string — Sequence doesn't produce output for zero tokens
for &text in MULTIBYTE_TEST_CASES.iter().filter(|t| !t.is_empty()) {
let encoding = tokenizer
.encode(text)
.unwrap_or_else(|e| panic!("Failed to encode '{text}': {e}"));
let mut sequence = Sequence::new(tokenizer.clone().into());
let mut output = String::new();
for token_id in encoding.token_ids() {
let text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(text) = text {
output.push_str(text.as_str());
for &token_id in encoding.token_ids() {
let chunk = sequence
.append_token_id(token_id)
.unwrap_or_else(|e| panic!("append_token_id failed for '{text}': {e}"));
output.push_str(&chunk);
}
assert_eq!(output, text, "Sequence roundtrip failed for: '{text}'");
}
assert_eq!(output, TEST_PROMPTS[0]);
}
#[test]
fn test_long_sequence_incremental_decode_with_prefill() {
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
.expect("Failed to load remote HuggingFace tokenizer");
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_decode_stream_basic(#[case] tokenizer: Arc<dyn Tokenizer>) {
let text = TEST_PROMPTS[0];
let encoding = tokenizer.encode(text).expect("Failed to encode prompt");
let shared_tokenizer = Arc::new(tokenizer);
let mut stream = DecodeStream::new(tokenizer.clone(), &[], false);
let mut output = String::new();
for &token_id in encoding.token_ids() {
if let Some(chunk) = stream.step(token_id).expect("Failed to decode token_id") {
output.push_str(&chunk);
}
}
assert_eq!(output, text);
}
for (input_text, output_text) in LONG_TEST_PROMPTS.iter() {
let input_encoding = shared_tokenizer
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_decode_stream_with_prefill(#[case] tokenizer: Arc<dyn Tokenizer>) {
for &(input_text, output_text) in LONG_TEST_PROMPTS.iter() {
let input_encoding = tokenizer
.encode(input_text)
.expect("Failed to encode prompt");
.unwrap_or_else(|e| panic!("Failed to encode prompt '{input_text}': {e}"));
let output_encoding = shared_tokenizer
let output_encoding = tokenizer
.encode(output_text)
.expect("Failed to encode prompt");
.unwrap_or_else(|e| panic!("Failed to encode output '{output_text}': {e}"));
let mut decoder =
DecodeStream::new(shared_tokenizer.clone(), input_encoding.token_ids(), false);
let mut stream = DecodeStream::new(tokenizer.clone(), input_encoding.token_ids(), false);
let mut output = String::new();
for token_id in output_encoding.token_ids() {
let text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(text) = text {
output.push_str(text.as_str());
for &token_id in output_encoding.token_ids() {
if let Some(chunk) = stream
.step(token_id)
.unwrap_or_else(|e| panic!("DecodeStream::step failed for '{output_text}': {e}"))
{
output.push_str(&chunk);
}
}
......@@ -180,6 +249,97 @@ fn test_long_sequence_incremental_decode_with_prefill() {
}
}
/// Feeds every `STREAM_TEST_CASES` output through a `DecodeStream` primed with the
/// matching prompt and checks that the emitted chunks, once joined and trimmed,
/// reproduce the output text.
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_decode_stream_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>) {
    for (prompt, output_text) in STREAM_TEST_CASES.iter().copied() {
        let prompt_encoding = match tokenizer.encode(prompt) {
            Ok(encoding) => encoding,
            Err(e) => panic!("Failed to encode prompt '{prompt}': {e}"),
        };
        let output_encoding = match tokenizer.encode(output_text) {
            Ok(encoding) => encoding,
            Err(e) => panic!("Failed to encode output '{output_text}': {e}"),
        };

        let mut stream = DecodeStream::new(tokenizer.clone(), prompt_encoding.token_ids(), false);

        // Steps that return `None` (buffered tokens) contribute nothing to the output.
        let reassembled: String = output_encoding
            .token_ids()
            .iter()
            .filter_map(|&token_id| match stream.step(token_id) {
                Ok(chunk) => chunk,
                Err(e) => panic!("DecodeStream::step failed for '{output_text}': {e}"),
            })
            .collect();

        assert_eq!(
            reassembled.trim(),
            output_text,
            "DecodeStream roundtrip failed for: '{output_text}'"
        );
    }
}
/// Hashing the same prompts twice must give the same non-zero hashes.
#[rstest]
#[case::huggingface(tinyllama_tokenizer())]
#[case::tiktoken(mock_tiktoken_tokenizer())]
fn test_hash_determinism(#[case] tokenizer: Arc<dyn Tokenizer>) {
    let prompts = &["hello world", "deep learning", "another prompt"];

    let first_pass = compute_hashes_for_tokenizer(tokenizer.as_ref(), prompts);
    let second_pass = compute_hashes_for_tokenizer(tokenizer.as_ref(), prompts);

    assert_eq!(first_pass, second_pass, "Hashes should be deterministic");
    assert!(!first_pass.contains(&0), "Hashes should be non-zero");
}
// ---------------------------------------------------------------------------
// Tokenizer-specific tests (not parameterized)
// ---------------------------------------------------------------------------
/// Encodes each prompt with `tokenizer` and returns the hash of every encoding,
/// in the same order as `prompts`.
///
/// Panics if any prompt fails to encode.
fn compute_hashes_for_tokenizer<E: Encoder + ?Sized>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
    let mut hashes = Vec::with_capacity(prompts.len());
    for &prompt in prompts {
        let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
        hashes.push(encoding.get_hash());
    }
    hashes
}
/// Paths of the HuggingFace tokenizers checked by `compute_hashes_hf`.
const HF_TOKENIZERS_LOCAL: [&str; 1] = [TINYLLAMA_TOKENIZER_PATH];
/// Expected encoding hashes per tokenizer path; `compute_hashes_hf` compares them
/// against the hashes computed for `TEST_PROMPTS`.
const HASHES: [(&str, [u64; 4]); 1] = [(
TINYLLAMA_TOKENIZER_PATH,
[
1209591529327510910,
4181375434596349981,
6245658446118930933,
5097285695902185237,
],
)];
/// Recomputes the `TEST_PROMPTS` hashes for every local HuggingFace tokenizer and
/// compares them with the values recorded in `HASHES`.
#[test]
fn compute_hashes_hf() {
    let expected: HashMap<&str, [u64; 4]> = HASHES.iter().copied().collect();

    for tokenizer_name in HF_TOKENIZERS_LOCAL.iter().copied() {
        let tokenizer = match HuggingFaceTokenizer::from_file(tokenizer_name) {
            Ok(tokenizer) => tokenizer,
            Err(_) => panic!("Failed to load HuggingFace tokenizer"),
        };
        let prompt_hashes = compute_hashes_for_tokenizer(&tokenizer, &TEST_PROMPTS);

        // Printed so the values are visible when running with `--nocapture`.
        println!("HF Tokenizer: {tokenizer_name:?} Hashes: {prompt_hashes:?}");

        assert_eq!(prompt_hashes, expected[tokenizer_name]);
    }
}
#[test]
fn test_decode_with_skip_special_tokens() {
let tokenizer = HuggingFaceTokenizer::from_file(TINYLLAMA_TOKENIZER_PATH)
......@@ -208,87 +368,13 @@ fn test_decode_with_skip_special_tokens() {
assert_eq!(decoded_without_special, "Hello world");
}
// --- tiktoken tests ---
const MOCK_TIKTOKEN_DIR: &str = "tests/data/sample-models/mock-tiktoken";
fn mock_tiktoken_model_path() -> String {
Path::new(env!("CARGO_MANIFEST_DIR"))
.join(MOCK_TIKTOKEN_DIR)
.join("tiktoken.model")
.to_str()
.unwrap()
.to_string()
}
#[test]
fn test_tiktoken_lifecycle() {
let path = mock_tiktoken_model_path();
let tokenizer =
TikTokenTokenizer::from_file_auto(&path).expect("Failed to load tiktoken tokenizer");
// Test simple encode/decode roundtrip
let text = "hello world";
let encoding = tokenizer.encode(text).expect("Failed to encode");
let ids = encoding.token_ids();
assert!(!ids.is_empty(), "Token IDs should not be empty");
// Verify Sp variant
match &encoding {
Encoding::Sp(_) => {}
other => panic!("Expected Encoding::Sp, got {:?}", other),
}
let decoded = tokenizer
.decode(ids, false)
.expect("Failed to decode token_ids");
assert_eq!(decoded, text);
}
#[test]
fn test_tiktoken_decode_stream() {
let path = mock_tiktoken_model_path();
let tokenizer =
TikTokenTokenizer::from_file_auto(&path).expect("Failed to load tiktoken tokenizer");
let shared_tokenizer: Arc<dyn dynamo_llm::tokenizers::traits::Tokenizer> = Arc::new(tokenizer);
let text = "hello world";
let encoding = shared_tokenizer
.encode(text)
.expect("Failed to encode prompt");
let mut decoder = DecodeStream::new(shared_tokenizer.clone(), &[], false);
let mut output = String::new();
for token_id in encoding.token_ids() {
let step_text = decoder.step(*token_id).expect("Failed to decode token_id");
if let Some(t) = step_text {
output.push_str(&t);
}
}
assert_eq!(output, text);
}
#[test]
fn compute_hashes_tiktoken() {
let path = mock_tiktoken_model_path();
let tokenizer =
TikTokenTokenizer::from_file_auto(&path).expect("Failed to load tiktoken tokenizer");
let simple_prompts = &["hello world", "hello", "world"];
let hashes = compute_hashes_for_tokenizer(&tokenizer, simple_prompts);
// Just verify we get consistent hashes (non-zero, deterministic)
let hashes2 = compute_hashes_for_tokenizer(&tokenizer, simple_prompts);
assert_eq!(hashes, hashes2, "Hashes should be deterministic");
assert!(hashes.iter().all(|&h| h != 0), "Hashes should be non-zero");
}
#[test]
fn test_tiktoken_create_from_file() {
let path = mock_tiktoken_model_path();
// Test the factory function used by the Tokenizer wrapper
let tokenizer = create_tokenizer_from_file(&path).expect("Failed to create tokenizer");
let path = Path::new(env!("CARGO_MANIFEST_DIR"))
.join(MOCK_TIKTOKEN_DIR)
.join("tiktoken.model");
let tokenizer =
create_tokenizer_from_file(path.to_str().unwrap()).expect("Failed to create tokenizer");
let encoding = tokenizer
.encode("hello")
......@@ -297,21 +383,11 @@ fn test_tiktoken_create_from_file() {
}
#[test]
fn test_tiktoken_batch_encode() {
let path = mock_tiktoken_model_path();
let tokenizer =
TikTokenTokenizer::from_file_auto(&path).expect("Failed to load tiktoken tokenizer");
let inputs = &["hello", "world"];
let encodings = tokenizer
.encode_batch(inputs)
.expect("Failed to batch encode");
assert_eq!(encodings.len(), 2);
for (encoding, input) in encodings.iter().zip(inputs.iter()) {
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode");
assert_eq!(decoded, *input);
fn test_tiktoken_encoding_variant_is_sp() {
let tokenizer = mock_tiktoken_tokenizer();
let encoding = tokenizer.encode("hello world").expect("Failed to encode");
match &encoding {
Encoding::Sp(_) => {}
other => panic!("Expected Encoding::Sp, got {:?}", other),
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment