Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
dynamo
Commits
2cabf441
Unverified
Commit
2cabf441
authored
Apr 13, 2026
by
Biswa Panda
Committed by
GitHub
Apr 13, 2026
Browse files
feat: Decoder clean-up for handling incomplete multi-byte sequence (#8022)
parent
20ce329b
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
182 additions
and
69 deletions
+182
-69
lib/llm/src/backend.rs
lib/llm/src/backend.rs
+1
-1
lib/llm/src/tokenizers.rs
lib/llm/src/tokenizers.rs
+83
-19
lib/llm/src/tokenizers/fastokens.rs
lib/llm/src/tokenizers/fastokens.rs
+5
-5
lib/llm/src/tokenizers/hf.rs
lib/llm/src/tokenizers/hf.rs
+3
-3
lib/llm/src/tokenizers/tiktoken.rs
lib/llm/src/tokenizers/tiktoken.rs
+71
-28
lib/llm/tests/test_stop_behavior.rs
lib/llm/tests/test_stop_behavior.rs
+4
-3
lib/llm/tests/tokenizers.rs
lib/llm/tests/tokenizers.rs
+15
-10
No files found.
lib/llm/src/backend.rs
View file @
2cabf441
...
@@ -638,7 +638,7 @@ mod tests {
...
@@ -638,7 +638,7 @@ mod tests {
&
self
,
&
self
,
_
token_ids
:
&
[
TokenIdType
],
_
token_ids
:
&
[
TokenIdType
],
_
skip_special_tokens
:
bool
,
_
skip_special_tokens
:
bool
,
)
->
anyhow
::
Result
<
S
tr
ing
>
{
)
->
anyhow
::
Result
<
tr
aits
::
DecodeResult
>
{
Err
(
anyhow
::
anyhow!
(
Err
(
anyhow
::
anyhow!
(
"Unable to decode into a valid UTF-8 string: incomplete utf-8 byte sequence from index 6"
"Unable to decode into a valid UTF-8 string: incomplete utf-8 byte sequence from index 6"
))
))
...
...
lib/llm/src/tokenizers.rs
View file @
2cabf441
...
@@ -19,6 +19,7 @@ pub use anyhow::{Error, Result};
...
@@ -19,6 +19,7 @@ pub use anyhow::{Error, Result};
pub
use
fastokens
::
FastTokenizer
;
pub
use
fastokens
::
FastTokenizer
;
pub
use
hf
::
HuggingFaceTokenizer
;
pub
use
hf
::
HuggingFaceTokenizer
;
pub
use
tiktoken
::
TikTokenTokenizer
;
pub
use
tiktoken
::
TikTokenTokenizer
;
pub
use
traits
::
DecodeResult
;
/// Represents the type of tokenizer being used
/// Represents the type of tokenizer being used
#[derive(Debug)]
#[derive(Debug)]
...
@@ -62,12 +63,66 @@ pub mod traits {
...
@@ -62,12 +63,66 @@ pub mod traits {
fn
encode_batch
(
&
self
,
inputs
:
&
[
&
str
])
->
Result
<
Vec
<
Encoding
>>
;
fn
encode_batch
(
&
self
,
inputs
:
&
[
&
str
])
->
Result
<
Vec
<
Encoding
>>
;
}
}
/// Implementations **must** use lossy UTF-8 conversion (e.g. `String::from_utf8_lossy`)
/// Result of decoding token IDs to text.
/// so that partial multi-byte sequences produce U+FFFD (`�`) rather than returning `Err`.
///
/// `DecodeStream::step()` relies on the replacement character to detect incomplete
/// Distinguishes between fully valid UTF-8 output and output that contains
/// sequences and buffer tokens until the full character arrives.
/// trailing incomplete multi-byte sequences (represented as U+FFFD).
/// This lets callers like `DecodeStream::step()` decide whether to emit or
/// buffer without resorting to hardcoded replacement-character string checks.
#[derive(Debug,
Clone,
PartialEq,
Eq,
strum::EnumIs)]
pub
enum
DecodeResult
{
/// No trailing incomplete multi-byte sequences (text does not end with U+FFFD).
/// Note: the string may still contain *interior* U+FFFD characters from
/// mid-stream invalid byte sequences; only trailing status is tracked here.
Complete
(
String
),
/// The decoded string ends with U+FFFD, indicating incomplete trailing
/// multi-byte bytes that may be completed by subsequent tokens.
Partial
(
String
),
}
impl
DecodeResult
{
/// Returns a reference to the inner string.
pub
fn
as_str
(
&
self
)
->
&
str
{
match
self
{
DecodeResult
::
Complete
(
s
)
|
DecodeResult
::
Partial
(
s
)
=>
s
,
}
}
/// Construct from a decoded string: `Partial` if it ends with U+FFFD, else `Complete`.
pub
fn
from_decoded
(
text
:
String
)
->
Self
{
if
text
.ends_with
(
'\
u
{
FFFD
}
'
)
{
DecodeResult
::
Partial
(
text
)
}
else
{
DecodeResult
::
Complete
(
text
)
}
}
}
impl
From
<
String
>
for
DecodeResult
{
fn
from
(
text
:
String
)
->
Self
{
DecodeResult
::
from_decoded
(
text
)
}
}
impl
From
<
DecodeResult
>
for
String
{
fn
from
(
result
:
DecodeResult
)
->
Self
{
match
result
{
DecodeResult
::
Complete
(
s
)
|
DecodeResult
::
Partial
(
s
)
=>
s
,
}
}
}
/// Implementations must ensure that partial multi-byte sequences produce U+FFFD
/// (`\u{FFFD}`) in the output rather than returning `Err`. This is commonly achieved
/// via `String::from_utf8_lossy` (tiktoken) or library-internal byte-fallback handling
/// (HuggingFace). `DecodeStream::step()` relies on `DecodeResult::Partial` to detect
/// incomplete sequences and buffer tokens until the full character arrives.
pub
trait
Decoder
:
Send
+
Sync
{
pub
trait
Decoder
:
Send
+
Sync
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
;
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
,
)
->
Result
<
DecodeResult
>
;
}
}
pub
trait
Tokenizer
:
Encoder
+
Decoder
{
pub
trait
Tokenizer
:
Encoder
+
Decoder
{
...
@@ -219,23 +274,27 @@ impl DecodeStream {
...
@@ -219,23 +274,27 @@ impl DecodeStream {
pub
fn
step
(
&
mut
self
,
id
:
u32
)
->
Result
<
Option
<
String
>>
{
pub
fn
step
(
&
mut
self
,
id
:
u32
)
->
Result
<
Option
<
String
>>
{
self
.all_token_ids
.push
(
id
);
self
.all_token_ids
.push
(
id
);
let
prefix_text
=
self
.tokenizer
.decode
(
let
prefix_text
:
String
=
self
&
self
.all_token_ids
[
self
.prefix_offset
..
self
.read_offset
],
.tokenizer
self
.skip_special_tokens
,
.decode
(
)
?
;
&
self
.all_token_ids
[
self
.prefix_offset
..
self
.read_offset
],
self
.skip_special_tokens
,
)
?
.into
();
let
new_
tex
t
=
self
.tokenizer
.decode
(
let
new_
resul
t
=
self
.tokenizer
.decode
(
&
self
.all_token_ids
[
self
.prefix_offset
..
],
&
self
.all_token_ids
[
self
.prefix_offset
..
],
self
.skip_special_tokens
,
self
.skip_special_tokens
,
)
?
;
)
?
;
if
new_text
.len
()
>
prefix_text
.len
()
&&
!
new_text
.ends_with
(
"�"
)
{
let
new_text
=
new_result
.as_str
();
let
new_text
=
new_text
[
prefix_text
.len
()
..
]
.to_string
();
if
new_text
.len
()
>
prefix_text
.len
()
&&
!
new_result
.is_partial
()
{
let
emitted
=
new_text
[
prefix_text
.len
()
..
]
.to_string
();
self
.prefix_offset
=
self
.read_offset
;
self
.prefix_offset
=
self
.read_offset
;
self
.read_offset
=
self
.all_token_ids
.len
();
self
.read_offset
=
self
.all_token_ids
.len
();
Ok
(
Some
(
new_text
))
Ok
(
Some
(
emitted
))
}
else
{
}
else
{
Ok
(
None
)
Ok
(
None
)
}
}
...
@@ -322,14 +381,17 @@ impl Sequence {
...
@@ -322,14 +381,17 @@ impl Sequence {
self
.token_ids
.push
(
token_id
);
self
.token_ids
.push
(
token_id
);
// log::trace!("pushed token_id: {}", token_id);
// log::trace!("pushed token_id: {}", token_id);
let
prefix_text
=
self
let
prefix_text
:
String
=
self
.tokenizer
.tokenizer
.decode
(
&
self
.token_ids
[
self
.prefix_offset
..
self
.read_offset
],
false
)
?
;
.decode
(
&
self
.token_ids
[
self
.prefix_offset
..
self
.read_offset
],
false
)
?
.into
();
let
new_
tex
t
=
self
let
new_
resul
t
=
self
.tokenizer
.tokenizer
.decode
(
&
self
.token_ids
[
self
.prefix_offset
..
],
false
)
?
;
.decode
(
&
self
.token_ids
[
self
.prefix_offset
..
],
false
)
?
;
let
new_text
=
new_result
.as_str
();
// if the end character of the previous returned sequence is a multi-byte character
// if the end character of the previous returned sequence is a multi-byte character
// then we can not split the text on that byte offset, so we roll back to the byte offset
// then we can not split the text on that byte offset, so we roll back to the byte offset
// of the start of that character
// of the start of that character
...
@@ -340,11 +402,13 @@ impl Sequence {
...
@@ -340,11 +402,13 @@ impl Sequence {
let
prefix_text_len
=
prefix_text_len
;
let
prefix_text_len
=
prefix_text_len
;
if
new_text
.len
()
>
prefix_text
.len
()
{
if
new_text
.len
()
>
prefix_text
.len
()
{
if
new_
text
.ends_with
(
"�"
)
{
if
new_
result
.is_partial
(
)
{
return
Ok
(
""
.to_string
());
return
Ok
(
""
.to_string
());
}
else
{
}
else
{
// shift and update the state
// shift and update the state
let
new_text
=
new_text
[
prefix_text_len
..
]
.to_string
()
.replace
(
"�"
,
""
);
let
new_text
=
new_text
[
prefix_text_len
..
]
.to_string
()
.replace
(
'\
u
{
FFFD
}
'
,
""
);
self
.prefix_offset
=
self
.read_offset
;
self
.prefix_offset
=
self
.read_offset
;
self
.read_offset
=
self
.token_ids
.len
();
self
.read_offset
=
self
.token_ids
.len
();
return
Ok
(
new_text
);
return
Ok
(
new_text
);
...
@@ -366,7 +430,7 @@ impl Sequence {
...
@@ -366,7 +430,7 @@ impl Sequence {
// let tokenizer = self.tokenizer.read().map_err(|err| {
// let tokenizer = self.tokenizer.read().map_err(|err| {
// Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
// Error::msg(format!("Failed to acquire read lock on tokenizer: {}", err))
// })?;
// })?;
self
.tokenizer
.decode
(
&
self
.token_ids
,
false
)
Ok
(
self
.tokenizer
.decode
(
&
self
.token_ids
,
false
)
?
.into
())
}
}
}
}
...
...
lib/llm/src/tokenizers/fastokens.rs
View file @
2cabf441
...
@@ -14,7 +14,7 @@ use rayon::prelude::*;
...
@@ -14,7 +14,7 @@ use rayon::prelude::*;
use
super
::{
use
super
::{
Encoding
,
Error
,
Result
,
TokenIdType
,
Encoding
,
Error
,
Result
,
TokenIdType
,
hf
::
HuggingFaceTokenizer
,
hf
::
HuggingFaceTokenizer
,
traits
::{
Decoder
,
Encoder
,
Tokenizer
},
traits
::{
DecodeResult
,
Decoder
,
Encoder
,
Tokenizer
},
};
};
/// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace.
/// Hybrid tokenizer: fast BPE encoding via `fastokens`, decoding via HuggingFace.
...
@@ -52,7 +52,7 @@ impl Encoder for FastTokenizer {
...
@@ -52,7 +52,7 @@ impl Encoder for FastTokenizer {
}
}
impl
Decoder
for
FastTokenizer
{
impl
Decoder
for
FastTokenizer
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
DecodeResult
>
{
self
.hf_decoder
.decode
(
token_ids
,
skip_special_tokens
)
self
.hf_decoder
.decode
(
token_ids
,
skip_special_tokens
)
}
}
}
}
...
@@ -81,7 +81,7 @@ mod tests {
...
@@ -81,7 +81,7 @@ mod tests {
let
text
=
"Hello, world!"
;
let
text
=
"Hello, world!"
;
let
encoding
=
tokenizer
.encode
(
text
)
.unwrap
();
let
encoding
=
tokenizer
.encode
(
text
)
.unwrap
();
assert
!
(
!
encoding
.token_ids
()
.is_empty
());
assert
!
(
!
encoding
.token_ids
()
.is_empty
());
let
decoded
=
tokenizer
.decode
(
encoding
.token_ids
(),
true
)
.unwrap
();
let
decoded
:
String
=
tokenizer
.decode
(
encoding
.token_ids
(),
true
)
.unwrap
()
.into
()
;
assert
!
(
!
decoded
.is_empty
());
assert
!
(
!
decoded
.is_empty
());
// The decoded text should contain the same non-space characters
// The decoded text should contain the same non-space characters
let
enc_chars
:
String
=
text
.chars
()
.filter
(|
c
|
!
c
.is_whitespace
())
.collect
();
let
enc_chars
:
String
=
text
.chars
()
.filter
(|
c
|
!
c
.is_whitespace
())
.collect
();
...
@@ -149,8 +149,8 @@ mod tests {
...
@@ -149,8 +149,8 @@ mod tests {
// decode(continuation) which lacks the surrounding context.
// decode(continuation) which lacks the surrounding context.
let
mut
all_ids
=
prompt_ids
.clone
();
let
mut
all_ids
=
prompt_ids
.clone
();
all_ids
.extend_from_slice
(
&
cont_ids
);
all_ids
.extend_from_slice
(
&
cont_ids
);
let
full_text
=
wrapper
.decode
(
&
all_ids
,
true
)
.unwrap
();
let
full_text
:
String
=
wrapper
.decode
(
&
all_ids
,
true
)
.unwrap
()
.into
()
;
let
prompt_text
=
wrapper
.decode
(
&
prompt_ids
,
true
)
.unwrap
();
let
prompt_text
:
String
=
wrapper
.decode
(
&
prompt_ids
,
true
)
.unwrap
()
.into
()
;
let
expected
=
&
full_text
[
prompt_text
.len
()
..
];
let
expected
=
&
full_text
[
prompt_text
.len
()
..
];
assert_eq!
(
assert_eq!
(
accumulated
,
expected
,
accumulated
,
expected
,
...
...
lib/llm/src/tokenizers/hf.rs
View file @
2cabf441
...
@@ -5,7 +5,7 @@ use tokenizers::tokenizer::Tokenizer as HfTokenizer;
...
@@ -5,7 +5,7 @@ use tokenizers::tokenizer::Tokenizer as HfTokenizer;
use
super
::{
use
super
::{
Encoding
,
Error
,
Result
,
TokenIdType
,
Encoding
,
Error
,
Result
,
TokenIdType
,
traits
::{
Decoder
,
Encoder
,
Tokenizer
},
traits
::{
DecodeResult
,
Decoder
,
Encoder
,
Tokenizer
},
};
};
pub
struct
HuggingFaceTokenizer
{
pub
struct
HuggingFaceTokenizer
{
...
@@ -52,14 +52,14 @@ impl Encoder for HuggingFaceTokenizer {
...
@@ -52,14 +52,14 @@ impl Encoder for HuggingFaceTokenizer {
}
}
impl
Decoder
for
HuggingFaceTokenizer
{
impl
Decoder
for
HuggingFaceTokenizer
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
DecodeResult
>
{
// This calls into the library
// This calls into the library
let
text
=
self
let
text
=
self
.tokenizer
.tokenizer
.decode
(
token_ids
,
skip_special_tokens
)
.decode
(
token_ids
,
skip_special_tokens
)
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error de-tokenizing input: {err}"
)))
?
;
.map_err
(|
err
|
Error
::
msg
(
format!
(
"Error de-tokenizing input: {err}"
)))
?
;
Ok
(
text
)
Ok
(
text
.into
()
)
}
}
}
}
...
...
lib/llm/src/tokenizers/tiktoken.rs
View file @
2cabf441
...
@@ -11,7 +11,7 @@ use tiktoken_rs::CoreBPE;
...
@@ -11,7 +11,7 @@ use tiktoken_rs::CoreBPE;
use
super
::{
use
super
::{
Encoding
,
Error
,
Result
,
TokenIdType
,
Encoding
,
Error
,
Result
,
TokenIdType
,
traits
::{
Decoder
,
Encoder
,
Tokenizer
},
traits
::{
DecodeResult
,
Decoder
,
Encoder
,
Tokenizer
},
};
};
/// Number of reserved special-token slots to generate when filling gaps in the vocabulary.
/// Number of reserved special-token slots to generate when filling gaps in the vocabulary.
...
@@ -89,7 +89,7 @@ impl Encoder for TikTokenTokenizer {
...
@@ -89,7 +89,7 @@ impl Encoder for TikTokenTokenizer {
}
}
impl
Decoder
for
TikTokenTokenizer
{
impl
Decoder
for
TikTokenTokenizer
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
String
>
{
fn
decode
(
&
self
,
token_ids
:
&
[
TokenIdType
],
skip_special_tokens
:
bool
)
->
Result
<
DecodeResult
>
{
let
ids
:
Vec
<
u32
>
=
if
skip_special_tokens
{
let
ids
:
Vec
<
u32
>
=
if
skip_special_tokens
{
token_ids
token_ids
.iter
()
.iter
()
...
@@ -100,12 +100,22 @@ impl Decoder for TikTokenTokenizer {
...
@@ -100,12 +100,22 @@ impl Decoder for TikTokenTokenizer {
token_ids
.to_vec
()
token_ids
.to_vec
()
};
};
// Use lossy UTF-8 conversion so that partial multi-byte sequences become U+FFFD (�).
// Try strict UTF-8 first: valid bytes get `Complete` with zero extra allocation
// This is critical for incremental detokenization: DecodeStream::step() relies on
// (takes ownership of the Vec). This correctly handles vocabulary tokens whose
// the replacement character to detect incomplete sequences and buffer tokens until
// raw bytes are EF BF BD (legitimate U+FFFD) -- they are valid UTF-8 and must
// a complete character arrives. CoreBPE::decode() would error on invalid UTF-8 instead.
// not be confused with incomplete multi-byte sequences.
//
// On failure, fall back to lossy conversion so partial multi-byte sequences
// become U+FFFD, then classify via the trailing-FFFD heuristic. This path is
// only hit during incremental detokenization of byte-fallback tokens.
let
bytes
:
Vec
<
u8
>
=
self
.bpe
._decode_native_and_split
(
ids
)
.flatten
()
.collect
();
let
bytes
:
Vec
<
u8
>
=
self
.bpe
._decode_native_and_split
(
ids
)
.flatten
()
.collect
();
Ok
(
String
::
from_utf8_lossy
(
&
bytes
)
.into_owned
())
match
String
::
from_utf8
(
bytes
)
{
Ok
(
text
)
=>
Ok
(
DecodeResult
::
Complete
(
text
)),
Err
(
e
)
=>
{
let
text
=
String
::
from_utf8_lossy
(
e
.as_bytes
())
.into_owned
();
Ok
(
DecodeResult
::
from_decoded
(
text
))
}
}
}
}
}
}
...
@@ -351,7 +361,7 @@ mod tests {
...
@@ -351,7 +361,7 @@ mod tests {
assert
!
(
!
ids
.is_empty
());
assert
!
(
!
ids
.is_empty
());
// Test decode roundtrip
// Test decode roundtrip
let
decoded
=
tokenizer
.decode
(
ids
,
false
)
.unwrap
();
let
decoded
:
String
=
tokenizer
.decode
(
ids
,
false
)
.unwrap
()
.into
()
;
assert_eq!
(
decoded
,
"hello world"
);
assert_eq!
(
decoded
,
"hello world"
);
}
}
...
@@ -393,11 +403,11 @@ mod tests {
...
@@ -393,11 +403,11 @@ mod tests {
ids
.push
(
22
);
// [EOS]
ids
.push
(
22
);
// [EOS]
// Decode with skip_special_tokens=true should strip special tokens
// Decode with skip_special_tokens=true should strip special tokens
let
decoded_skip
=
tokenizer
.decode
(
&
ids
,
true
)
.unwrap
();
let
decoded_skip
:
String
=
tokenizer
.decode
(
&
ids
,
true
)
.unwrap
()
.into
()
;
assert_eq!
(
decoded_skip
,
"hello"
);
assert_eq!
(
decoded_skip
,
"hello"
);
// Decode with skip_special_tokens=false should include them
// Decode with skip_special_tokens=false should include them
let
decoded_all
=
tokenizer
.decode
(
&
ids
,
false
)
.unwrap
();
let
decoded_all
:
String
=
tokenizer
.decode
(
&
ids
,
false
)
.unwrap
()
.into
()
;
assert
!
(
decoded_all
.contains
(
"hello"
));
assert
!
(
decoded_all
.contains
(
"hello"
));
}
}
...
@@ -416,7 +426,7 @@ mod tests {
...
@@ -416,7 +426,7 @@ mod tests {
let
ids
=
encoding
.token_ids
();
let
ids
=
encoding
.token_ids
();
assert
!
(
!
ids
.is_empty
());
assert
!
(
!
ids
.is_empty
());
let
decoded
=
tokenizer
.decode
(
ids
,
false
)
.unwrap
();
let
decoded
:
String
=
tokenizer
.decode
(
ids
,
false
)
.unwrap
()
.into
()
;
assert_eq!
(
decoded
,
"hello world"
);
assert_eq!
(
decoded
,
"hello world"
);
}
}
...
@@ -490,6 +500,15 @@ mod tests {
...
@@ -490,6 +500,15 @@ mod tests {
content
.push_str
(
&
format!
(
"{encoded} {rank}
\n
"
));
content
.push_str
(
&
format!
(
"{encoded} {rank}
\n
"
));
}
}
// Legitimate U+FFFD token: valid UTF-8 bytes EF BF BD (replacement character
// as an actual vocabulary entry, not an artifact of lossy conversion)
let
fffd_token
:
Vec
<
(
Vec
<
u8
>
,
u32
)
>
=
vec!
[(
vec!
[
0xEF
,
0xBF
,
0xBD
],
300
)];
for
(
token
,
rank
)
in
&
fffd_token
{
let
encoded
=
engine
.encode
(
token
);
content
.push_str
(
&
format!
(
"{encoded} {rank}
\n
"
));
}
let
file_path
=
dir
.join
(
"tiktoken.model"
);
let
file_path
=
dir
.join
(
"tiktoken.model"
);
let
mut
file
=
std
::
fs
::
File
::
create
(
&
file_path
)
.unwrap
();
let
mut
file
=
std
::
fs
::
File
::
create
(
&
file_path
)
.unwrap
();
file
.write_all
(
content
.as_bytes
())
.unwrap
();
file
.write_all
(
content
.as_bytes
())
.unwrap
();
...
@@ -516,11 +535,11 @@ mod tests {
...
@@ -516,11 +535,11 @@ mod tests {
result
.is_ok
(),
result
.is_ok
(),
"decode() should not error on incomplete UTF-8 bytes"
"decode() should not error on incomplete UTF-8 bytes"
);
);
let
tex
t
=
result
.unwrap
();
let
decode_resul
t
=
result
.unwrap
();
assert
!
(
assert
!
(
text
.contains
(
'\
u
{
FFFD
}
'
),
decode_result
.is_partial
(
),
"incomplete UTF-8 byte should produce
replacement character
, got: {:?}"
,
"incomplete UTF-8 byte should produce
DecodeResult::Partial
, got: {:?}"
,
tex
t
decode_resul
t
);
);
}
}
...
@@ -532,11 +551,11 @@ mod tests {
...
@@ -532,11 +551,11 @@ mod tests {
let
result
=
tokenizer
.decode
(
&
[
100
,
101
],
false
);
let
result
=
tokenizer
.decode
(
&
[
100
,
101
],
false
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
let
tex
t
=
result
.unwrap
();
let
decode_resul
t
=
result
.unwrap
();
assert
!
(
assert
!
(
text
.contains
(
'\
u
{
FFFD
}
'
),
decode_result
.is_partial
(
),
"incomplete 2-of-3 UTF-8 bytes should produce
replacement character
, got: {:?}"
,
"incomplete 2-of-3 UTF-8 bytes should produce
DecodeResult::Partial
, got: {:?}"
,
tex
t
decode_resul
t
);
);
}
}
...
@@ -550,7 +569,7 @@ mod tests {
...
@@ -550,7 +569,7 @@ mod tests {
let
result
=
tokenizer
.decode
(
&
[
100
,
101
,
102
],
false
);
let
result
=
tokenizer
.decode
(
&
[
100
,
101
,
102
],
false
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
assert_eq!
(
result
.unwrap
(),
"你"
);
assert_eq!
(
String
::
from
(
result
.unwrap
()
)
,
"你"
);
}
}
/// All 4 emoji bytes together form valid UTF-8, so this passes both before and after
/// All 4 emoji bytes together form valid UTF-8, so this passes both before and after
...
@@ -562,7 +581,27 @@ mod tests {
...
@@ -562,7 +581,27 @@ mod tests {
let
result
=
tokenizer
.decode
(
&
[
200
,
201
,
202
,
203
],
false
);
let
result
=
tokenizer
.decode
(
&
[
200
,
201
,
202
,
203
],
false
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
assert_eq!
(
result
.unwrap
(),
"😀"
);
assert_eq!
(
String
::
from
(
result
.unwrap
()),
"😀"
);
}
/// Regression test: a vocabulary token whose raw bytes are EF BF BD (the valid
/// UTF-8 encoding of U+FFFD) must decode as `Complete`, not `Partial`. Before the
/// from_utf8 fast-path fix, from_utf8_lossy + the trailing-FFFD heuristic would
/// misclassify this as Partial, causing the incremental decoder to suppress it.
#[test]
fn
test_decode_legitimate_replacement_char_token_is_complete
()
{
let
dir
=
tempfile
::
tempdir
()
.unwrap
();
let
tokenizer
=
create_byte_token_tokenizer
(
dir
.path
());
let
result
=
tokenizer
.decode
(
&
[
300
],
false
);
assert
!
(
result
.is_ok
());
let
decode_result
=
result
.unwrap
();
assert
!
(
decode_result
.is_complete
(),
"legitimate U+FFFD vocab token must be Complete, got: {:?}"
,
decode_result
);
assert_eq!
(
decode_result
.as_str
(),
"
\
u{FFFD}"
);
}
}
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
...
@@ -573,7 +612,7 @@ mod tests {
...
@@ -573,7 +612,7 @@ mod tests {
let
result
=
tokenizer
.decode
(
&
[
200
],
false
);
let
result
=
tokenizer
.decode
(
&
[
200
],
false
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
assert
!
(
result
.unwrap
()
.
contains
(
'\
u
{
FFFD
}
'
));
assert
!
(
result
.unwrap
()
.
is_partial
(
));
}
}
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
/// Without the fix, fails with "incomplete utf-8 byte sequence" from CoreBPE::decode().
...
@@ -584,16 +623,17 @@ mod tests {
...
@@ -584,16 +623,17 @@ mod tests {
let
result
=
tokenizer
.decode
(
&
[
5
,
100
],
false
);
let
result
=
tokenizer
.decode
(
&
[
5
,
100
],
false
);
assert
!
(
result
.is_ok
());
assert
!
(
result
.is_ok
());
let
text
=
result
.unwrap
();
let
decode_result
=
result
.unwrap
();
assert
!
(
decode_result
.is_partial
(),
"trailing incomplete byte should produce DecodeResult::Partial"
);
let
text
:
String
=
decode_result
.into
();
assert
!
(
assert
!
(
text
.starts_with
(
"hello"
),
text
.starts_with
(
"hello"
),
"should start with 'hello', got: {:?}"
,
"should start with 'hello', got: {:?}"
,
text
text
);
);
assert
!
(
text
.contains
(
'\
u
{
FFFD
}
'
),
"trailing incomplete byte should produce U+FFFD"
);
}
}
/// End-to-end incremental detokenization: DecodeStream buffers partial bytes,
/// End-to-end incremental detokenization: DecodeStream buffers partial bytes,
...
@@ -656,7 +696,10 @@ mod tests {
...
@@ -656,7 +696,10 @@ mod tests {
assert_eq!
(
encodings
.len
(),
2
);
assert_eq!
(
encodings
.len
(),
2
);
for
(
encoding
,
input
)
in
encodings
.iter
()
.zip
(
inputs
.iter
())
{
for
(
encoding
,
input
)
in
encodings
.iter
()
.zip
(
inputs
.iter
())
{
let
decoded
=
tokenizer
.decode
(
encoding
.token_ids
(),
false
)
.unwrap
();
let
decoded
:
String
=
tokenizer
.decode
(
encoding
.token_ids
(),
false
)
.unwrap
()
.into
();
assert_eq!
(
decoded
,
*
input
);
assert_eq!
(
decoded
,
*
input
);
}
}
}
}
...
...
lib/llm/tests/test_stop_behavior.rs
View file @
2cabf441
...
@@ -25,8 +25,8 @@ impl tokenizer_traits::Encoder for TestTokenizer {
...
@@ -25,8 +25,8 @@ impl tokenizer_traits::Encoder for TestTokenizer {
}
}
impl
tokenizer_traits
::
Decoder
for
TestTokenizer
{
impl
tokenizer_traits
::
Decoder
for
TestTokenizer
{
fn
decode
(
&
self
,
ids
:
&
[
u32
],
skip_special
:
bool
)
->
Result
<
String
>
{
fn
decode
(
&
self
,
ids
:
&
[
u32
],
skip_special
:
bool
)
->
Result
<
tokenizer_traits
::
DecodeResult
>
{
Ok
(
ids
let
text
:
String
=
ids
.iter
()
.iter
()
.filter_map
(|
&
id
|
match
id
{
.filter_map
(|
&
id
|
match
id
{
EOS
if
skip_special
=>
None
,
EOS
if
skip_special
=>
None
,
...
@@ -36,7 +36,8 @@ impl tokenizer_traits::Decoder for TestTokenizer {
...
@@ -36,7 +36,8 @@ impl tokenizer_traits::Decoder for TestTokenizer {
EOS
=>
Some
(
"</s>"
),
EOS
=>
Some
(
"</s>"
),
_
=>
Some
(
"?"
),
_
=>
Some
(
"?"
),
})
})
.collect
())
.collect
();
Ok
(
text
.into
())
}
}
}
}
...
...
lib/llm/tests/tokenizers.rs
View file @
2cabf441
...
@@ -113,9 +113,10 @@ fn test_encode_decode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
...
@@ -113,9 +113,10 @@ fn test_encode_decode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
.unwrap_or_else
(|
e
|
panic!
(
"Failed to encode '{text}': {e}"
));
.unwrap_or_else
(|
e
|
panic!
(
"Failed to encode '{text}': {e}"
));
assert
!
(
!
encoding
.token_ids
()
.is_empty
());
assert
!
(
!
encoding
.token_ids
()
.is_empty
());
let
decoded
=
tokenizer
let
decoded
:
String
=
tokenizer
.decode
(
encoding
.token_ids
(),
false
)
.decode
(
encoding
.token_ids
(),
false
)
.unwrap_or_else
(|
e
|
panic!
(
"Failed to decode '{text}': {e}"
));
.unwrap_or_else
(|
e
|
panic!
(
"Failed to decode '{text}': {e}"
))
.into
();
assert_eq!
(
decoded
,
text
,
"Roundtrip failed for: '{text}'"
);
assert_eq!
(
decoded
,
text
,
"Roundtrip failed for: '{text}'"
);
}
}
}
}
...
@@ -129,9 +130,10 @@ fn test_encode_decode_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>)
...
@@ -129,9 +130,10 @@ fn test_encode_decode_roundtrip_multibyte(#[case] tokenizer: Arc<dyn Tokenizer>)
.encode
(
text
)
.encode
(
text
)
.unwrap_or_else
(|
e
|
panic!
(
"Failed to encode '{text}': {e}"
));
.unwrap_or_else
(|
e
|
panic!
(
"Failed to encode '{text}': {e}"
));
let
decoded
=
tokenizer
let
decoded
:
String
=
tokenizer
.decode
(
encoding
.token_ids
(),
false
)
.decode
(
encoding
.token_ids
(),
false
)
.unwrap_or_else
(|
e
|
panic!
(
"Failed to decode '{text}': {e}"
));
.unwrap_or_else
(|
e
|
panic!
(
"Failed to decode '{text}': {e}"
))
.into
();
assert_eq!
(
decoded
,
text
,
"Roundtrip failed for: '{text}'"
);
assert_eq!
(
decoded
,
text
,
"Roundtrip failed for: '{text}'"
);
}
}
}
}
...
@@ -147,9 +149,10 @@ fn test_batch_encode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
...
@@ -147,9 +149,10 @@ fn test_batch_encode_roundtrip(#[case] tokenizer: Arc<dyn Tokenizer>) {
assert_eq!
(
encodings
.len
(),
inputs
.len
());
assert_eq!
(
encodings
.len
(),
inputs
.len
());
for
(
encoding
,
&
input
)
in
encodings
.iter
()
.zip
(
inputs
.iter
())
{
for
(
encoding
,
&
input
)
in
encodings
.iter
()
.zip
(
inputs
.iter
())
{
let
decoded
=
tokenizer
let
decoded
:
String
=
tokenizer
.decode
(
encoding
.token_ids
(),
false
)
.decode
(
encoding
.token_ids
(),
false
)
.expect
(
"Failed to decode"
);
.expect
(
"Failed to decode"
)
.into
();
assert_eq!
(
decoded
,
input
);
assert_eq!
(
decoded
,
input
);
}
}
}
}
...
@@ -354,14 +357,16 @@ fn test_decode_with_skip_special_tokens() {
...
@@ -354,14 +357,16 @@ fn test_decode_with_skip_special_tokens() {
token_ids
.push
(
2
);
// </s>
token_ids
.push
(
2
);
// </s>
// Decode with skip_special_tokens = false (should keep special tokens)
// Decode with skip_special_tokens = false (should keep special tokens)
let
decoded_with_special
=
tokenizer
let
decoded_with_special
:
String
=
tokenizer
.decode
(
&
token_ids
,
false
)
.decode
(
&
token_ids
,
false
)
.expect
(
"Failed to decode with skip_special_tokens=false"
);
.expect
(
"Failed to decode with skip_special_tokens=false"
)
.into
();
// Decode with skip_special_tokens = true (should remove special tokens)
// Decode with skip_special_tokens = true (should remove special tokens)
let
decoded_without_special
=
tokenizer
let
decoded_without_special
:
String
=
tokenizer
.decode
(
&
token_ids
,
true
)
.decode
(
&
token_ids
,
true
)
.expect
(
"Failed to decode with skip_special_tokens=true"
);
.expect
(
"Failed to decode with skip_special_tokens=true"
)
.into
();
// Validate exact matches on the entire decoded strings
// Validate exact matches on the entire decoded strings
assert_eq!
(
decoded_with_special
,
"<s> Hello world</s>"
);
assert_eq!
(
decoded_with_special
,
"<s> Hello world</s>"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment