Commit d08663ee (unverified)
Authored Aug 17, 2025 by Simo Lin; committed by GitHub on Aug 17, 2025

[router] tokenizer factory, hf tokenizer, and stop sequence detector (#9293)

Co-authored-by: Chang Su <chang.s.su@oracle.com>

Parent: 716e6827

Showing 5 changed files with 935 additions and 5 deletions:
sgl-router/Cargo.toml                        +5    -1
sgl-router/src/tokenizer/factory.rs          +228  -0
sgl-router/src/tokenizer/huggingface.rs      +189  -0
sgl-router/src/tokenizer/mod.rs              +14   -4
sgl-router/src/tokenizer/stop.rs             +499  -0
sgl-router/Cargo.toml

@@ -3,6 +3,10 @@ name = "sglang_router_rs"
 version = "0.0.0"
 edition = "2021"
 
+[features]
+default = ["huggingface"]
+huggingface = ["tokenizers"]
+
 [lib]
 name = "sglang_router_rs"
 # Pure Rust library: Just omit crate-type (defaults to rlib)
@@ -44,7 +48,7 @@ thiserror = "2.0.12"
 url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
-tokenizers = "0.21.4"
+tokenizers = { version = "0.21.4", optional = true }
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
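Note (not part of the diff): with this change the `tokenizers` crate becomes an optional dependency pulled in by the default `huggingface` feature, so the HuggingFace code paths can be compiled out. For example, a plain `cargo build` keeps them, while `cargo build --no-default-features` drops the dependency, in which case the factory's JSON path returns a "HuggingFace support not enabled" error instead.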
sgl-router/src/tokenizer/factory.rs (new file, mode 100644)

use super::traits;
use anyhow::{Error, Result};
use std::fs::File;
use std::io::Read;
use std::path::Path;
use std::sync::Arc;

#[cfg(feature = "huggingface")]
use super::huggingface::HuggingFaceTokenizer;

/// Represents the type of tokenizer being used
#[derive(Debug, Clone)]
pub enum TokenizerType {
    HuggingFace(String),
    Mock,
    // Future: SentencePiece, GGUF, Tiktoken
}

/// Create a tokenizer from a file path to a tokenizer file.
/// The file extension is used to determine the tokenizer type.
/// Supported file types are:
/// - json: HuggingFace tokenizer
/// - For testing: can return mock tokenizer
pub fn create_tokenizer_from_file(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Special case for testing
    if file_path == "mock" || file_path == "test" {
        return Ok(Arc::new(super::mock::MockTokenizer::new()));
    }

    let path = Path::new(file_path);

    // Check if file exists
    if !path.exists() {
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    // Try to determine tokenizer type from extension
    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    match extension.as_deref() {
        Some("json") => {
            #[cfg(feature = "huggingface")]
            {
                let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
                Ok(Arc::new(tokenizer))
            }
            #[cfg(not(feature = "huggingface"))]
            {
                Err(Error::msg(
                    "HuggingFace support not enabled. Enable the 'huggingface' feature.",
                ))
            }
        }
        Some("model") => {
            // SentencePiece model file
            Err(Error::msg("SentencePiece models not yet supported"))
        }
        Some("gguf") => {
            // GGUF format
            Err(Error::msg("GGUF format not yet supported"))
        }
        _ => {
            // Try to auto-detect by reading file content
            auto_detect_tokenizer(file_path)
        }
    }
}

/// Auto-detect tokenizer type by examining file content
fn auto_detect_tokenizer(file_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    let mut file = File::open(file_path)?;
    let mut buffer = vec![0u8; 512]; // Read first 512 bytes for detection
    let bytes_read = file.read(&mut buffer)?;
    buffer.truncate(bytes_read);

    // Check for JSON (HuggingFace format)
    if is_likely_json(&buffer) {
        #[cfg(feature = "huggingface")]
        {
            let tokenizer = HuggingFaceTokenizer::from_file(file_path)?;
            return Ok(Arc::new(tokenizer));
        }
        #[cfg(not(feature = "huggingface"))]
        {
            return Err(Error::msg(
                "File appears to be JSON (HuggingFace) format, but HuggingFace support is not enabled",
            ));
        }
    }

    // Check for GGUF magic number
    if buffer.len() >= 4 && &buffer[0..4] == b"GGUF" {
        return Err(Error::msg("GGUF format detected but not yet supported"));
    }

    // Check for SentencePiece model
    if is_likely_sentencepiece(&buffer) {
        return Err(Error::msg(
            "SentencePiece model detected but not yet supported",
        ));
    }

    Err(Error::msg(format!(
        "Unable to determine tokenizer type for file: {}",
        file_path
    )))
}

/// Check if the buffer likely contains JSON data
fn is_likely_json(buffer: &[u8]) -> bool {
    // Skip UTF-8 BOM if present
    let content = if buffer.len() >= 3 && buffer[0..3] == [0xEF, 0xBB, 0xBF] {
        &buffer[3..]
    } else {
        buffer
    };

    // Find first non-whitespace character without allocation
    if let Some(first_byte) = content.iter().find(|&&b| !b.is_ascii_whitespace()) {
        *first_byte == b'{' || *first_byte == b'['
    } else {
        false
    }
}

/// Check if the buffer likely contains a SentencePiece model
fn is_likely_sentencepiece(buffer: &[u8]) -> bool {
    // SentencePiece models often start with specific patterns
    // This is a simplified check
    buffer.len() >= 12
        && (buffer.starts_with(b"\x0a\x09")
            || buffer.starts_with(b"\x08\x00")
            || buffer.windows(4).any(|w| w == b"<unk")
            || buffer.windows(4).any(|w| w == b"<s>")
            || buffer.windows(4).any(|w| w == b"</s>"))
}

/// Factory function to create tokenizer from a model name or path
pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Tokenizer>> {
    // Check if it's a file path
    let path = Path::new(model_name_or_path);
    if path.exists() {
        return create_tokenizer_from_file(model_name_or_path);
    }

    // Otherwise, try to load from HuggingFace Hub
    #[cfg(feature = "huggingface")]
    {
        // This would download from HF Hub - not implemented yet
        Err(Error::msg(
            "Loading from HuggingFace Hub not yet implemented",
        ))
    }
    #[cfg(not(feature = "huggingface"))]
    {
        Err(Error::msg(format!(
            "Model '{}' not found locally and HuggingFace support is not enabled",
            model_name_or_path
        )))
    }
}

/// Get information about a tokenizer file
pub fn get_tokenizer_info(file_path: &str) -> Result<TokenizerType> {
    let path = Path::new(file_path);

    if !path.exists() {
        return Err(Error::msg(format!("File not found: {}", file_path)));
    }

    let extension = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(|s| s.to_lowercase());

    match extension.as_deref() {
        Some("json") => Ok(TokenizerType::HuggingFace(file_path.to_string())),
        _ => {
            // Try auto-detection
            use std::fs::File;
            use std::io::Read;

            let mut file = File::open(file_path)?;
            let mut buffer = vec![0u8; 512];
            let bytes_read = file.read(&mut buffer)?;
            buffer.truncate(bytes_read);

            if is_likely_json(&buffer) {
                Ok(TokenizerType::HuggingFace(file_path.to_string()))
            } else {
                Err(Error::msg("Unknown tokenizer type"))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_detection() {
        assert!(is_likely_json(b"{\"test\": \"value\"}"));
        assert!(is_likely_json(b"\n\t{\"test\": \"value\"}"));
        assert!(is_likely_json(b"[1, 2, 3]"));
        assert!(!is_likely_json(b"not json"));
        assert!(!is_likely_json(b""));
    }

    #[test]
    fn test_mock_tokenizer_creation() {
        let tokenizer = create_tokenizer_from_file("mock").unwrap();
        assert_eq!(tokenizer.vocab_size(), 8); // Mock tokenizer has 8 tokens
    }

    #[test]
    fn test_file_not_found() {
        let result = create_tokenizer_from_file("/nonexistent/file.json");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("File not found"));
        }
    }
}
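A minimal usage sketch for the factory (not part of the diff). The crate and module paths and the tokenizer.json location are assumptions; `get_tokenizer_info` is not re-exported at the `tokenizer` level, so it is pulled from the `factory` module directly:

use sglang_router_rs::tokenizer::factory::{
    create_tokenizer_from_file, get_tokenizer_info, TokenizerType,
};

fn main() -> anyhow::Result<()> {
    // "mock" (or "test") short-circuits to the built-in MockTokenizer, handy for wiring tests.
    let mock = create_tokenizer_from_file("mock")?;
    assert_eq!(mock.vocab_size(), 8);

    // A real tokenizer.json is routed to the HuggingFace backend by extension or JSON sniffing.
    match get_tokenizer_info("path/to/tokenizer.json") {
        Ok(TokenizerType::HuggingFace(path)) => println!("HuggingFace tokenizer: {}", path),
        Ok(other) => println!("detected: {:?}", other),
        Err(err) => eprintln!("could not identify tokenizer: {}", err),
    }
    Ok(())
}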
sgl-router/src/tokenizer/huggingface.rs (new file, mode 100644)

use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::{Error, Result};
use std::collections::HashMap;
use tokenizers::tokenizer::Tokenizer as HfTokenizer;

/// HuggingFace tokenizer wrapper
pub struct HuggingFaceTokenizer {
    tokenizer: HfTokenizer,
    special_tokens: SpecialTokens,
    vocab: HashMap<String, u32>,
    reverse_vocab: HashMap<u32, String>,
}

impl HuggingFaceTokenizer {
    /// Create a tokenizer from a HuggingFace tokenizer JSON file
    pub fn from_file(file_path: &str) -> Result<Self> {
        let tokenizer = HfTokenizer::from_file(file_path)
            .map_err(|e| Error::msg(format!("Failed to load tokenizer: {}", e)))?;

        // Extract special tokens
        let special_tokens = Self::extract_special_tokens(&tokenizer);

        // Build vocab mappings
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<u32, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        Ok(HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
        })
    }

    /// Create from an existing HuggingFace tokenizer
    pub fn from_tokenizer(tokenizer: HfTokenizer) -> Self {
        let special_tokens = Self::extract_special_tokens(&tokenizer);
        let vocab = tokenizer.get_vocab(false);
        let reverse_vocab: HashMap<u32, String> = vocab
            .iter()
            .map(|(token, &id)| (id, token.clone()))
            .collect();

        HuggingFaceTokenizer {
            tokenizer,
            special_tokens,
            vocab,
            reverse_vocab,
        }
    }

    /// Extract special tokens from the tokenizer
    fn extract_special_tokens(tokenizer: &HfTokenizer) -> SpecialTokens {
        // Try to get special tokens from the tokenizer
        // This is a simplified version - actual implementation would need to handle various formats
        let vocab = tokenizer.get_vocab(true);

        let find_token = |patterns: &[&str]| -> Option<String> {
            for pattern in patterns {
                if vocab.contains_key(*pattern) {
                    return Some(pattern.to_string());
                }
            }
            None
        };

        SpecialTokens {
            bos_token: find_token(&["<s>", "<|startoftext|>", "<BOS>", "[CLS]"]),
            eos_token: find_token(&["</s>", "<|endoftext|>", "<EOS>", "[SEP]"]),
            unk_token: find_token(&["<unk>", "<UNK>", "[UNK]"]),
            sep_token: find_token(&["[SEP]", "<sep>", "<SEP>"]),
            pad_token: find_token(&["<pad>", "<PAD>", "[PAD]"]),
            cls_token: find_token(&["[CLS]", "<cls>", "<CLS>"]),
            mask_token: find_token(&["[MASK]", "<mask>", "<MASK>"]),
            additional_special_tokens: vec![],
        }
    }

    /// Apply chat template if available
    pub fn apply_chat_template(&self, messages: &[ChatMessage]) -> Result<String> {
        // This is a placeholder - actual implementation would handle templates
        let mut result = String::new();
        for msg in messages {
            result.push_str(&format!("{}: {}\n", msg.role, msg.content));
        }
        Ok(result)
    }
}

impl Encoder for HuggingFaceTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let encoding = self
            .tokenizer
            .encode(input, false)
            .map_err(|e| Error::msg(format!("Encoding failed: {}", e)))?;
        Ok(Encoding::Hf(Box::new(encoding)))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        let encodings = self
            .tokenizer
            .encode_batch(inputs.to_vec(), false)
            .map_err(|e| Error::msg(format!("Batch encoding failed: {}", e)))?;
        Ok(encodings
            .into_iter()
            .map(|e| Encoding::Hf(Box::new(e)))
            .collect())
    }
}

impl Decoder for HuggingFaceTokenizer {
    fn decode(&self, token_ids: &[u32], skip_special_tokens: bool) -> Result<String> {
        self.tokenizer
            .decode(token_ids, skip_special_tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for HuggingFaceTokenizer {
    fn vocab_size(&self) -> usize {
        self.tokenizer.get_vocab_size(false)
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, token: &str) -> Option<u32> {
        self.vocab.get(token).copied()
    }

    fn id_to_token(&self, id: u32) -> Option<String> {
        self.reverse_vocab.get(&id).cloned()
    }
}

/// Represents a chat message for template application
#[derive(Debug, Clone)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

impl ChatMessage {
    pub fn new(role: impl Into<String>, content: impl Into<String>) -> Self {
        ChatMessage {
            role: role.into(),
            content: content.into(),
        }
    }

    pub fn system(content: impl Into<String>) -> Self {
        Self::new("system", content)
    }

    pub fn user(content: impl Into<String>) -> Self {
        Self::new("user", content)
    }

    pub fn assistant(content: impl Into<String>) -> Self {
        Self::new("assistant", content)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chat_message_creation() {
        let msg = ChatMessage::system("You are a helpful assistant");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "You are a helpful assistant");

        let user_msg = ChatMessage::user("Hello!");
        assert_eq!(user_msg.role, "user");

        let assistant_msg = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant_msg.role, "assistant");
    }

    // Note: Actual tokenizer tests would require a real tokenizer file
    // These would be integration tests rather than unit tests
}
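A quick sketch of how this wrapper might be driven (not part of the diff; the import path and the tokenizer.json location are assumptions, and the token IDs passed to decode are arbitrary examples):

use sglang_router_rs::tokenizer::{ChatMessage, Decoder, Encoder, HuggingFaceTokenizer};

fn main() -> anyhow::Result<()> {
    // Load any HuggingFace-format tokenizer.json (placeholder path).
    let tok = HuggingFaceTokenizer::from_file("path/to/tokenizer.json")?;

    // Encoder / Decoder trait impls from above.
    let _encoding = tok.encode("Hello world")?;
    let text = tok.decode(&[0, 1, 2], true)?;
    println!("decoded: {}", text);

    // The placeholder chat template simply joins "role: content" lines.
    let prompt = tok.apply_chat_template(&[
        ChatMessage::system("You are a helpful assistant"),
        ChatMessage::user("Hello!"),
    ])?;
    println!("{}", prompt);
    Ok(())
}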
sgl-router/src/tokenizer/mod.rs

@@ -2,26 +2,36 @@ use anyhow::Result;
 use std::ops::Deref;
 use std::sync::Arc;
 
+pub mod factory;
 pub mod mock;
+pub mod stop;
 pub mod stream;
 pub mod traits;
 
+// Feature-gated modules
+#[cfg(feature = "huggingface")]
+pub mod huggingface;
+
 #[cfg(test)]
 mod tests;
 
+// Re-exports
+pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
 
+#[cfg(feature = "huggingface")]
+pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
+
 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
 #[derive(Clone)]
 pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
 
 impl Tokenizer {
     /// Create a tokenizer from a file path
-    /// Will be implemented in Phase 3 with factory pattern
-    pub fn from_file(_file_path: &str) -> Result<Tokenizer> {
-        // TODO: Implement factory pattern in Phase 3
-        unimplemented!("Factory pattern will be implemented in Phase 3")
+    pub fn from_file(file_path: &str) -> Result<Tokenizer> {
+        Ok(Tokenizer(factory::create_tokenizer_from_file(file_path)?))
     }
 
     /// Create a tokenizer from an Arc<dyn Tokenizer>
 ...
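With this change the top-level `Tokenizer` wrapper delegates to the factory instead of panicking with `unimplemented!`. A small sketch (crate path assumed; the unchanged remainder of mod.rs is assumed to expose the underlying trait methods, e.g. via its `Deref` use):

use sglang_router_rs::tokenizer::Tokenizer;

fn main() -> anyhow::Result<()> {
    // "mock" exercises the factory's test path without touching the filesystem.
    let tok = Tokenizer::from_file("mock")?;
    println!("vocab size: {}", tok.vocab_size());
    Ok(())
}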
sgl-router/src/tokenizer/stop.rs (new file, mode 100644)

use super::traits;
use anyhow::Result;
use std::collections::HashSet;
use std::sync::Arc;

/// Output from the sequence decoder
#[derive(Debug, Clone, PartialEq)]
pub enum SequenceDecoderOutput {
    /// Normal text output
    Text(String),
    /// Text is being held due to partial stop sequence match
    Held,
    /// Stop sequence matched (hidden - not included in output)
    Stopped,
    /// Stop sequence matched with text (visible - included in output)
    StoppedWithText(String),
}

/// Configuration for stop sequences
#[derive(Debug, Clone, Default)]
pub struct StopSequenceConfig {
    /// Token IDs that trigger a stop
    pub stop_tokens: HashSet<u32>,
    /// String sequences that trigger a stop
    pub stop_sequences: Vec<String>,
    /// Token IDs for visible stops (included in output)
    pub visible_stop_tokens: HashSet<u32>,
    /// String sequences for visible stops (included in output)
    pub visible_stop_sequences: Vec<String>,
}

impl StopSequenceConfig {
    /// Builder pattern - add a stop token
    pub fn with_stop_token(mut self, token_id: u32) -> Self {
        self.stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a stop sequence
    pub fn with_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.stop_sequences.push(sequence.into());
        self
    }

    /// Builder pattern - add a visible stop token
    pub fn with_visible_stop_token(mut self, token_id: u32) -> Self {
        self.visible_stop_tokens.insert(token_id);
        self
    }

    /// Builder pattern - add a visible stop sequence
    pub fn with_visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.visible_stop_sequences.push(sequence.into());
        self
    }
}

/// Decoder that handles stop sequences
pub struct StopSequenceDecoder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    /// Buffer for partial matches (the "jail")
    jail_buffer: String,
    /// Accumulated tokens
    token_buffer: Vec<u32>,
    /// Offset where the prefix text starts (for context)
    prefix_offset: usize,
    /// Offset marking the end of previously decoded text
    read_offset: usize,
    /// Whether we've stopped
    stopped: bool,
    skip_special_tokens: bool,
}

impl StopSequenceDecoder {
    /// Create a new stop sequence decoder
    pub fn new(
        tokenizer: Arc<dyn traits::Tokenizer>,
        config: StopSequenceConfig,
        skip_special_tokens: bool,
    ) -> Self {
        StopSequenceDecoder {
            tokenizer,
            config,
            jail_buffer: String::new(),
            token_buffer: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
            stopped: false,
            skip_special_tokens,
        }
    }

    /// Process a single token
    pub fn process_token(&mut self, token_id: u32) -> Result<SequenceDecoderOutput> {
        if self.stopped {
            return Ok(SequenceDecoderOutput::Stopped);
        }

        // Check for token-level stops first
        if self.config.stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Flush any jailed text before stopping
            if !self.jail_buffer.is_empty() {
                let output = self.jail_buffer.clone();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
            return Ok(SequenceDecoderOutput::Stopped);
        }

        if self.config.visible_stop_tokens.contains(&token_id) {
            self.stopped = true;
            // Include jailed text plus the stop token
            let stop_text = self.tokenizer.decode(&[token_id], self.skip_special_tokens)?;
            let output = format!("{}{}", self.jail_buffer, stop_text);
            self.jail_buffer.clear();
            return Ok(SequenceDecoderOutput::StoppedWithText(output));
        }

        // Add token to buffer
        self.token_buffer.push(token_id);

        // Use incremental decoding like DecodeStream
        // First decode the previous context (what we've already output)
        let prefix_text = if self.read_offset > self.prefix_offset {
            self.tokenizer.decode(
                &self.token_buffer[self.prefix_offset..self.read_offset],
                self.skip_special_tokens,
            )?
        } else {
            String::new()
        };

        // Now decode from prefix to current position
        let new_full_text = self.tokenizer.decode(
            &self.token_buffer[self.prefix_offset..],
            self.skip_special_tokens,
        )?;

        // Check for incomplete UTF-8 sequence
        if new_full_text.ends_with("�") {
            // Wait for more tokens to complete the sequence
            return Ok(SequenceDecoderOutput::Held);
        }

        // Calculate only the NEW text since last successful decode
        let new_text = if new_full_text.len() > prefix_text.len() {
            &new_full_text[prefix_text.len()..]
        } else {
            // No new text produced (can happen with special tokens)
            return Ok(SequenceDecoderOutput::Held);
        };

        // Combine jail buffer with new text for checking
        let check_text = format!("{}{}", self.jail_buffer, new_text);

        // Check for complete stop sequences
        for stop_seq in &self.config.stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Output text before the stop sequence
                let output = check_text[..pos].to_string();
                self.jail_buffer.clear();
                return Ok(if output.is_empty() {
                    SequenceDecoderOutput::Stopped
                } else {
                    SequenceDecoderOutput::StoppedWithText(output)
                });
            }
        }

        // Check for visible stop sequences
        for stop_seq in &self.config.visible_stop_sequences {
            if let Some(pos) = check_text.find(stop_seq) {
                self.stopped = true;
                // Include the stop sequence in output
                let end_pos = pos + stop_seq.len();
                let output = check_text[..end_pos].to_string();
                self.jail_buffer.clear();
                return Ok(SequenceDecoderOutput::StoppedWithText(output));
            }
        }

        // Check for partial matches at the end of check_text
        let mut partial_match_len = 0;
        for stop_seq in self
            .config
            .stop_sequences
            .iter()
            .chain(&self.config.visible_stop_sequences)
        {
            // Check all possible suffixes that could be a prefix of stop_seq
            for i in 1..=check_text.len().min(stop_seq.len() - 1) {
                let suffix = &check_text[check_text.len() - i..];
                if stop_seq.starts_with(suffix) {
                    partial_match_len = partial_match_len.max(i);
                }
            }
        }

        if partial_match_len > 0 {
            // Split: output safe text, jail the potential match
            let safe_end = check_text.len() - partial_match_len;
            let safe_text = &check_text[..safe_end];
            self.jail_buffer = check_text[safe_end..].to_string();

            // Update offsets for next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            if safe_text.is_empty() {
                Ok(SequenceDecoderOutput::Held)
            } else {
                Ok(SequenceDecoderOutput::Text(safe_text.to_string()))
            }
        } else {
            // No partial matches - output everything
            self.jail_buffer.clear();

            // Update offsets for next iteration
            self.prefix_offset = self.read_offset;
            self.read_offset = self.token_buffer.len();

            Ok(SequenceDecoderOutput::Text(check_text))
        }
    }

    /// Process multiple tokens
    pub fn process_tokens(&mut self, token_ids: &[u32]) -> Result<Vec<SequenceDecoderOutput>> {
        let mut outputs = Vec::new();
        for &token_id in token_ids {
            outputs.push(self.process_token(token_id)?);
        }
        Ok(outputs)
    }

    /// Flush any held text
    pub fn flush(&mut self) -> SequenceDecoderOutput {
        if !self.jail_buffer.is_empty() {
            let output = self.jail_buffer.clone();
            self.jail_buffer.clear();
            SequenceDecoderOutput::Text(output)
        } else {
            SequenceDecoderOutput::Text(String::new())
        }
    }

    /// Check if decoding has stopped
    pub fn is_stopped(&self) -> bool {
        self.stopped
    }

    /// Reset the decoder state
    pub fn reset(&mut self) {
        self.jail_buffer.clear();
        self.token_buffer.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
        self.stopped = false;
    }
}

/// Builder for StopSequenceDecoder
pub struct StopSequenceDecoderBuilder {
    tokenizer: Arc<dyn traits::Tokenizer>,
    config: StopSequenceConfig,
    skip_special_tokens: bool,
}

impl StopSequenceDecoderBuilder {
    pub fn new(tokenizer: Arc<dyn traits::Tokenizer>) -> Self {
        StopSequenceDecoderBuilder {
            tokenizer,
            config: StopSequenceConfig::default(),
            skip_special_tokens: true,
        }
    }

    pub fn stop_token(mut self, token_id: u32) -> Self {
        self.config.stop_tokens.insert(token_id);
        self
    }

    pub fn stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.stop_sequences.push(sequence.into());
        self
    }

    pub fn visible_stop_token(mut self, token_id: u32) -> Self {
        self.config.visible_stop_tokens.insert(token_id);
        self
    }

    pub fn visible_stop_sequence(mut self, sequence: impl Into<String>) -> Self {
        self.config.visible_stop_sequences.push(sequence.into());
        self
    }

    pub fn skip_special_tokens(mut self, skip: bool) -> Self {
        self.skip_special_tokens = skip;
        self
    }

    pub fn build(self) -> StopSequenceDecoder {
        StopSequenceDecoder::new(self.tokenizer, self.config, self.skip_special_tokens)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_stop_token_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999); // <eos> token
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens before stop
        let result = decoder.process_token(1).unwrap(); // "Hello"
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));

        // Process stop token
        let result = decoder.process_token(999).unwrap(); // <eos>
        assert_eq!(result, SequenceDecoderOutput::Stopped);

        // Further tokens should also return Stopped
        let result = decoder.process_token(2).unwrap();
        assert_eq!(result, SequenceDecoderOutput::Stopped);
    }

    #[test]
    fn test_visible_stop_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        let result = decoder.process_token(999).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::StoppedWithText(_)));
    }

    #[test]
    fn test_builder_pattern() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let decoder = StopSequenceDecoderBuilder::new(tokenizer)
            .stop_token(999)
            .stop_sequence("STOP")
            .visible_stop_token(1000)
            .skip_special_tokens(true)
            .build();

        assert!(!decoder.is_stopped());
    }

    #[test]
    fn test_incremental_decoding_no_repetition() {
        // This test verifies the critical fix: no repeated output
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process tokens one by one and collect outputs
        let mut outputs = Vec::new();

        // Token 1: "Hello"
        let result = decoder.process_token(1).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 2: "world"
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // Token 3: "test"
        let result = decoder.process_token(3).unwrap();
        if let SequenceDecoderOutput::Text(text) = result {
            outputs.push(text.clone());
        }

        // CRITICAL: Each output should be unique (no accumulation)
        // The fix ensures we only output NEW text, not accumulated text
        assert_eq!(outputs.len(), 3);

        // Verify no text is repeated
        for i in 0..outputs.len() {
            for j in i + 1..outputs.len() {
                // No output should contain another (no accumulation)
                assert!(!outputs[j].contains(&outputs[i]));
            }
        }
    }

    #[test]
    fn test_stop_sequence_detection() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("test");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello world"
        decoder.process_token(1).unwrap(); // "Hello"
        decoder.process_token(2).unwrap(); // "world"

        // Process "test" which should trigger stop
        let result = decoder.process_token(3).unwrap(); // "test"

        // Should stop when we hit "test"
        assert!(matches!(
            result,
            SequenceDecoderOutput::Stopped | SequenceDecoderOutput::StoppedWithText(_)
        ));
    }

    #[test]
    fn test_flush_after_partial() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_sequence("NEVER_MATCH");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process a token
        decoder.process_token(1).unwrap(); // "Hello"

        // Flush should return any remaining text in jail
        let result = decoder.flush();
        // After processing, flush should work
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_reset_functionality() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_stop_token(999);
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process and stop
        decoder.process_token(1).unwrap();
        decoder.process_token(999).unwrap();
        assert!(decoder.is_stopped());

        // Reset should clear everything
        decoder.reset();
        assert!(!decoder.is_stopped());

        // Should be able to process again
        let result = decoder.process_token(2).unwrap();
        assert!(matches!(result, SequenceDecoderOutput::Text(_)));
    }

    #[test]
    fn test_visible_stop_sequence() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default().with_visible_stop_sequence("world");
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process "Hello"
        decoder.process_token(1).unwrap();

        // Process "world" - should include it in output
        let result = decoder.process_token(2).unwrap();
        if let SequenceDecoderOutput::StoppedWithText(text) = result {
            // Should include "world" in the output
            assert!(text.contains("world"));
        } else {
            panic!("Expected StoppedWithText with visible stop sequence");
        }
    }

    #[test]
    fn test_multiple_tokens_processing() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let config = StopSequenceConfig::default();
        let mut decoder = StopSequenceDecoder::new(tokenizer, config, false);

        // Process multiple tokens at once
        let results = decoder.process_tokens(&[1, 2, 3]).unwrap();

        // Should get results for each token
        assert_eq!(results.len(), 3);

        // Each result should be Text (no stops configured)
        for result in results {
            assert!(matches!(
                result,
                SequenceDecoderOutput::Text(_) | SequenceDecoderOutput::Held
            ));
        }
    }
}
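A streaming-style usage sketch for the stop-sequence decoder (not part of the diff; the crate paths are assumptions, and the token-to-text mapping follows the MockTokenizer comments in the tests above). Text before a stop is emitted as it becomes safe, partial matches are jailed, and flush drains whatever is still held at end of stream:

use std::sync::Arc;
use sglang_router_rs::tokenizer::mock::MockTokenizer;
use sglang_router_rs::tokenizer::stop::{SequenceDecoderOutput, StopSequenceDecoderBuilder};

fn main() -> anyhow::Result<()> {
    // Per the tests, the mock maps 1 -> "Hello", 2 -> "world", 999 -> <eos>;
    // any Arc<dyn Tokenizer> produced by the factory can be dropped in instead.
    let tokenizer = Arc::new(MockTokenizer::new());

    let mut decoder = StopSequenceDecoderBuilder::new(tokenizer)
        .stop_token(999)       // hidden stop: not included in the output
        .stop_sequence("###")  // string stop, matched across token boundaries
        .skip_special_tokens(false)
        .build();

    for token_id in [1u32, 2, 999] {
        match decoder.process_token(token_id)? {
            SequenceDecoderOutput::Text(t) => print!("{}", t),
            SequenceDecoderOutput::Held => {} // partial stop match is jailed until resolved
            SequenceDecoderOutput::StoppedWithText(t) => {
                print!("{}", t);
                break;
            }
            SequenceDecoderOutput::Stopped => break,
        }
    }

    // At end of stream, emit anything still held in the jail buffer.
    if let SequenceDecoderOutput::Text(rest) = decoder.flush() {
        print!("{}", rest);
    }
    Ok(())
}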