sglang · commit 0b95a01a (unverified)

[router] add tiktokenizer and sequence in router (#9354)

Authored Aug 19, 2025 by Simo Lin; committed by GitHub on Aug 19, 2025.
Co-authored-by: Chang Su <chang.s.su@oracle.com>
Parent: a3b810eb

Showing 7 changed files with 578 additions and 4 deletions (+578 / -4):
- sgl-router/Cargo.toml (+2 / -0)
- sgl-router/src/tokenizer/factory.rs (+35 / -2)
- sgl-router/src/tokenizer/mod.rs (+8 / -0)
- sgl-router/src/tokenizer/sequence.rs (+238 / -0)
- sgl-router/src/tokenizer/tests.rs (+3 / -1)
- sgl-router/src/tokenizer/tiktoken.rs (+276 / -0)
- sgl-router/src/tokenizer/traits.rs (+16 / -1)
sgl-router/Cargo.toml

```diff
@@ -6,6 +6,7 @@ edition = "2021"
 [features]
 default = ["huggingface"]
 huggingface = ["tokenizers"]
+tiktoken = ["tiktoken-rs"]
 
 [lib]
 name = "sglang_router_rs"
@@ -49,6 +50,7 @@ url = "2.5.4"
 tokio-stream = { version = "0.1", features = ["sync"] }
 anyhow = "1.0"
 tokenizers = { version = "0.21.4", optional = true }
+tiktoken-rs = { version = "0.5", optional = true }
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
```
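The new backend is strictly opt-in: `tiktoken-rs` is an optional dependency activated only by the new `tiktoken` feature (e.g. `cargo build --features tiktoken`), so the default feature set still builds only the HuggingFace `tokenizers` backend.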
sgl-router/src/tokenizer/factory.rs

```diff
-use super::{traits, TokenizerTrait};
+use super::traits::{self, Tokenizer as TokenizerTrait};
 use crate::metrics::TokenizerMetrics;
 use anyhow::{Error, Result};
 use std::fs::File;
@@ -15,7 +15,9 @@ use super::huggingface::HuggingFaceTokenizer;
 pub enum TokenizerType {
     HuggingFace(String),
     Mock,
-    // Future: SentencePiece, GGUF, Tiktoken
+    #[cfg(feature = "tiktoken")]
+    Tiktoken(String),
+    // Future: SentencePiece, GGUF
 }
 
 /// Create a tokenizer from a file path to a tokenizer file.
@@ -166,6 +168,23 @@ pub fn create_tokenizer(model_name_or_path: &str) -> Result<Arc<dyn traits::Toke
         return create_tokenizer_from_file(model_name_or_path);
     }
 
+    // Check if it's a GPT model name that should use Tiktoken
+    #[cfg(feature = "tiktoken")]
+    {
+        if model_name_or_path.contains("gpt-")
+            || model_name_or_path.contains("davinci")
+            || model_name_or_path.contains("curie")
+            || model_name_or_path.contains("babbage")
+            || model_name_or_path.contains("ada")
+        {
+            use super::tiktoken::TiktokenTokenizer;
+            let tokenizer = TiktokenTokenizer::from_model_name(model_name_or_path)?;
+            TokenizerMetrics::record_factory_load("tiktoken");
+            TokenizerMetrics::set_vocab_size("tiktoken", tokenizer.vocab_size());
+            return Ok(Arc::new(tokenizer));
+        }
+    }
+
     // Otherwise, try to load from HuggingFace Hub
     #[cfg(feature = "huggingface")]
     {
@@ -245,4 +264,18 @@ mod tests {
             assert!(e.to_string().contains("File not found"));
         }
     }
+
+    #[cfg(feature = "tiktoken")]
+    #[test]
+    fn test_create_tiktoken_tokenizer() {
+        // Test creating tokenizer for GPT models
+        let tokenizer = create_tokenizer("gpt-4").unwrap();
+        assert!(tokenizer.vocab_size() > 0);
+
+        // Test encoding and decoding
+        let text = "Hello, world!";
+        let encoding = tokenizer.encode(text).unwrap();
+        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
+        assert_eq!(decoded, text);
+    }
 }
```
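To make the routing concrete, here is a sketch of a downstream caller, not part of the commit; it assumes the crate is consumed as `sglang_router_rs` (per the `[lib]` name above) with the tokenizer module at `sglang_router_rs::tokenizer`, and a build with `--features tiktoken`:

```rust
use sglang_router_rs::tokenizer::create_tokenizer;

fn main() -> anyhow::Result<()> {
    // "gpt-4" matches the "gpt-" substring check in create_tokenizer,
    // so this takes the new Tiktoken branch instead of querying the
    // HuggingFace Hub.
    let tokenizer = create_tokenizer("gpt-4")?;
    assert!(tokenizer.vocab_size() > 0);

    // A local path is still handled first by create_tokenizer_from_file,
    // and any non-GPT name falls through to the HuggingFace loader.
    Ok(())
}
```

Note that the matching is substring-based, so any model name containing "gpt-", "davinci", "curie", "babbage", or "ada" is routed to Tiktoken when the feature is enabled.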
sgl-router/src/tokenizer/mod.rs

```diff
@@ -4,6 +4,7 @@ use std::sync::Arc;
 pub mod factory;
 pub mod mock;
+pub mod sequence;
 pub mod stop;
 pub mod stream;
 pub mod traits;
@@ -12,11 +13,15 @@ pub mod traits;
 #[cfg(feature = "huggingface")]
 pub mod huggingface;
 
+#[cfg(feature = "tiktoken")]
+pub mod tiktoken;
+
 #[cfg(test)]
 mod tests;
 
 // Re-exports
 pub use factory::{create_tokenizer, create_tokenizer_from_file, TokenizerType};
+pub use sequence::Sequence;
 pub use stop::{SequenceDecoderOutput, StopSequenceConfig, StopSequenceDecoder};
 pub use stream::DecodeStream;
 pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
@@ -24,6 +29,9 @@ pub use traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as Tokeniz
 #[cfg(feature = "huggingface")]
 pub use huggingface::{ChatMessage, HuggingFaceTokenizer};
 
+#[cfg(feature = "tiktoken")]
+pub use tiktoken::{TiktokenModel, TiktokenTokenizer};
+
 /// Main tokenizer wrapper that provides a unified interface for different tokenizer implementations
 #[derive(Clone)]
 pub struct Tokenizer(Arc<dyn traits::Tokenizer>);
```
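With these re-exports, downstream code can reach the new types from the module root — `tokenizer::Sequence` unconditionally, and `tokenizer::{TiktokenModel, TiktokenTokenizer}` behind the same `tiktoken` feature gate — without naming the submodules.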
sgl-router/src/tokenizer/sequence.rs (new file, 0 → 100644)

```rust
use super::traits::Tokenizer as TokenizerTrait;
use anyhow::Result;
use std::sync::Arc;

/// Maintains state for an ongoing sequence of tokens and their decoded text
/// This provides a cleaner abstraction for managing token sequences
pub struct Sequence {
    /// The tokenizer used for encoding/decoding
    tokenizer: Arc<dyn TokenizerTrait>,
    /// The current sequence of token ids
    token_ids: Vec<u32>,
    /// The position in the current sequence the last decoded token completed
    prefix_offset: usize,
    /// Current position in the sequence
    read_offset: usize,
}

impl std::fmt::Debug for Sequence {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Sequence")
            .field("tokenizer", &"Arc<dyn Tokenizer>")
            .field(
                "token_ids",
                &format_args!("{}", {
                    let token_ids = self.token_ids();
                    if token_ids.len() <= 20 {
                        format!("{:?}", token_ids)
                    } else {
                        let first_ten = &token_ids[..10];
                        let last_ten = &token_ids[token_ids.len() - 10..];
                        format!("{:?} ... {:?}", first_ten, last_ten)
                    }
                }),
            )
            .field("prefix_offset", &self.prefix_offset)
            .field("read_offset", &self.read_offset)
            .field("token count", &self.token_ids.len())
            .finish()
    }
}

impl Sequence {
    /// Create a new empty sequence
    pub fn new(tokenizer: Arc<dyn TokenizerTrait>) -> Self {
        Self {
            tokenizer,
            token_ids: Vec::new(),
            prefix_offset: 0,
            read_offset: 0,
        }
    }

    /// Create a sequence with initial tokens
    pub fn with_tokens(tokenizer: Arc<dyn TokenizerTrait>, token_ids: Vec<u32>) -> Self {
        let len = token_ids.len();
        Self {
            tokenizer,
            token_ids,
            prefix_offset: 0,
            read_offset: len,
        }
    }

    /// Check if the sequence is empty
    pub fn is_empty(&self) -> bool {
        self.token_ids.is_empty()
    }

    /// Get the length of the sequence
    pub fn len(&self) -> usize {
        self.token_ids.len()
    }

    /// Clear the sequence
    pub fn clear(&mut self) {
        self.token_ids.clear();
        self.prefix_offset = 0;
        self.read_offset = 0;
    }

    /// Append text to the sequence by encoding it
    pub fn append_text(&mut self, input: &str) -> Result<()> {
        let encoding = self.tokenizer.encode(input)?;
        self.token_ids.extend(encoding.token_ids());
        Ok(())
    }

    /// Append a single token to the sequence and return newly decoded text
    /// Based on HuggingFace TGI incremental decoding
    pub fn append_token(&mut self, token_id: u32) -> Result<String> {
        // Store the old read offset before adding the new token
        let old_read_offset = self.read_offset;
        self.token_ids.push(token_id);
        self.read_offset = self.token_ids.len();

        // If this is the first token or we're at the beginning, decode everything
        if self.prefix_offset == 0 && old_read_offset == 0 {
            let text = self.tokenizer.decode(&self.token_ids, false)?;
            if text.ends_with("�") {
                // Incomplete UTF-8 sequence, wait for more tokens
                return Ok(String::new());
            }
            self.prefix_offset = 0;
            return Ok(text);
        }

        // Decode the text up to the previous position
        let prefix_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..old_read_offset], false)?;

        // Decode the text including the new token
        let new_text = self
            .tokenizer
            .decode(&self.token_ids[self.prefix_offset..], false)?;

        // Handle multi-byte character boundaries
        let mut prefix_text_len = prefix_text.len();
        while !new_text.is_char_boundary(prefix_text_len) && prefix_text_len > 0 {
            prefix_text_len -= 1;
        }

        if new_text.len() > prefix_text.len() {
            if new_text.ends_with("�") {
                // Incomplete UTF-8 sequence, wait for more tokens
                return Ok(String::new());
            } else {
                // Return the new text portion
                let incremental_text = new_text[prefix_text_len..].to_string().replace("�", "");
                self.prefix_offset = old_read_offset;
                return Ok(incremental_text);
            }
        }

        Ok(String::new())
    }

    /// Get a reference to the tokenizer
    pub fn tokenizer(&self) -> &Arc<dyn TokenizerTrait> {
        &self.tokenizer
    }

    /// Get the current token ids
    pub fn token_ids(&self) -> &[u32] {
        &self.token_ids
    }

    /// Decode the entire sequence to text
    pub fn text(&self) -> Result<String> {
        self.tokenizer.decode(&self.token_ids, false)
    }

    /// Get the prefix offset
    pub fn prefix_offset(&self) -> usize {
        self.prefix_offset
    }

    /// Get the read offset
    pub fn read_offset(&self) -> usize {
        self.read_offset
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::tokenizer::mock::MockTokenizer;

    #[test]
    fn test_sequence_new() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let seq = Sequence::new(tokenizer);
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
    }

    #[test]
    fn test_sequence_append_text() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);
        seq.append_text("Hello").unwrap();
        assert!(!seq.is_empty());
        let text = seq.text().unwrap();
        assert_eq!(text, "Hello");
    }

    #[test]
    fn test_sequence_append_token() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer.clone());

        // Start with an empty sequence and append token 1 ("Hello")
        let text1 = seq.append_token(1).unwrap();
        assert_eq!(text1, "Hello");

        // Now append token 2 ("world")
        // The mock tokenizer will decode [1, 2] as "Hello world" (with a space)
        let text2 = seq.append_token(2).unwrap();
        // The incremental text should be " world" (with the space that the mock tokenizer adds)
        assert_eq!(text2, " world");

        // Verify the full text
        assert_eq!(seq.text().unwrap(), "Hello world");
    }

    #[test]
    fn test_sequence_clear() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);
        seq.append_text("Hello world").unwrap();
        assert!(!seq.is_empty());
        seq.clear();
        assert!(seq.is_empty());
        assert_eq!(seq.len(), 0);
        assert_eq!(seq.prefix_offset(), 0);
        assert_eq!(seq.read_offset(), 0);
    }

    #[test]
    fn test_sequence_debug() {
        let tokenizer = Arc::new(MockTokenizer::new());
        let mut seq = Sequence::new(tokenizer);
        seq.append_text("Test").unwrap();
        let debug_str = format!("{:?}", seq);
        assert!(debug_str.contains("Sequence"));
        assert!(debug_str.contains("token count"));
    }
}
```
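The incremental-decoding contract is easiest to see end to end. Below is a sketch, not part of the commit, that drives `Sequence` the way a generation stream would; it reuses `MockTokenizer` from the tests above (which decodes token 1 as "Hello" and [1, 2] as "Hello world") and assumes the same crate paths as the earlier example:

```rust
use std::sync::Arc;

use sglang_router_rs::tokenizer::{mock::MockTokenizer, Sequence};

fn main() -> anyhow::Result<()> {
    let mut seq = Sequence::new(Arc::new(MockTokenizer::new()));

    // Tokens arrive one at a time; append_token returns only the text the
    // new token completed (or "" while a UTF-8 sequence is still partial).
    let mut out = String::new();
    for token_id in [1u32, 2] {
        out.push_str(&seq.append_token(token_id)?);
    }

    assert_eq!(out, "Hello world"); // "Hello", then " world"
    assert_eq!(seq.text()?, "Hello world");
    Ok(())
}
```

The prefix_offset/read_offset bookkeeping is what lets the second call emit " world" with its leading space: each call re-decodes a window starting at prefix_offset and returns only the suffix beyond the previously decoded text.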
sgl-router/src/tokenizer/tests.rs

```diff
@@ -129,7 +129,9 @@ fn test_thread_safety() {
         thread::spawn(move || {
             let text = "Hello test".to_string();
             let encoding = tokenizer_clone.encode(&text).unwrap();
-            let decoded = tokenizer_clone.decode(encoding.token_ids(), false).unwrap();
+            let decoded = tokenizer_clone
+                .decode(&encoding.token_ids(), false)
+                .unwrap();
             assert!(decoded.contains("Hello") || decoded.contains("test"));
             i
         })
```
sgl-router/src/tokenizer/tiktoken.rs (new file, 0 → 100644)

```rust
use super::traits::{Decoder, Encoder, Encoding, SpecialTokens, Tokenizer as TokenizerTrait};
use anyhow::{Error, Result};
use tiktoken_rs::{cl100k_base, p50k_base, p50k_edit, r50k_base, CoreBPE};

/// Tiktoken tokenizer wrapper for OpenAI GPT models
pub struct TiktokenTokenizer {
    tokenizer: CoreBPE,
    #[allow(dead_code)]
    model: TiktokenModel,
    special_tokens: SpecialTokens,
    vocab_size: usize,
}

/// Supported Tiktoken models
#[derive(Debug, Clone, Copy)]
pub enum TiktokenModel {
    /// GPT-4, GPT-3.5-turbo, text-embedding-ada-002
    Cl100kBase,
    /// Codex models, text-davinci-002, text-davinci-003
    P50kBase,
    /// Use for edit models like text-davinci-edit-001, code-davinci-edit-001
    P50kEdit,
    /// GPT-3 models like davinci
    R50kBase,
}

impl TiktokenTokenizer {
    /// Create a new Tiktoken tokenizer for the specified model
    pub fn new(model: TiktokenModel) -> Result<Self> {
        let tokenizer = match model {
            TiktokenModel::Cl100kBase => cl100k_base()
                .map_err(|e| Error::msg(format!("Failed to load cl100k_base: {}", e)))?,
            TiktokenModel::P50kBase => p50k_base()
                .map_err(|e| Error::msg(format!("Failed to load p50k_base: {}", e)))?,
            TiktokenModel::P50kEdit => p50k_edit()
                .map_err(|e| Error::msg(format!("Failed to load p50k_edit: {}", e)))?,
            TiktokenModel::R50kBase => r50k_base()
                .map_err(|e| Error::msg(format!("Failed to load r50k_base: {}", e)))?,
        };

        // Extract special tokens (tiktoken-rs doesn't expose them directly)
        // We'll use common ones for GPT models
        let special_tokens = Self::get_special_tokens_for_model(model);

        // Get vocabulary size (this is an approximation)
        let vocab_size = match model {
            TiktokenModel::Cl100kBase => 100256, // cl100k has ~100k tokens
            TiktokenModel::P50kBase | TiktokenModel::P50kEdit => 50281, // p50k has ~50k tokens
            TiktokenModel::R50kBase => 50257,    // r50k has ~50k tokens
        };

        Ok(TiktokenTokenizer {
            tokenizer,
            model,
            special_tokens,
            vocab_size,
        })
    }

    /// Create a tokenizer from a model string (e.g., "gpt-4", "gpt-3.5-turbo")
    pub fn from_model_name(model_name: &str) -> Result<Self> {
        let model = Self::model_from_name(model_name)?;
        Self::new(model)
    }

    /// Determine the appropriate model from a model name
    fn model_from_name(model_name: &str) -> Result<TiktokenModel> {
        // Based on OpenAI's model-to-encoding mapping
        if model_name.contains("gpt-4")
            || model_name.contains("gpt-3.5")
            || model_name.contains("turbo")
        {
            Ok(TiktokenModel::Cl100kBase)
        } else if model_name.contains("davinci-002")
            || model_name.contains("davinci-003")
            || model_name.contains("codex")
        {
            Ok(TiktokenModel::P50kBase)
        } else if model_name.contains("edit") {
            Ok(TiktokenModel::P50kEdit)
        } else if model_name.contains("davinci")
            || model_name.contains("curie")
            || model_name.contains("babbage")
            || model_name.contains("ada")
        {
            Ok(TiktokenModel::R50kBase)
        } else {
            // Return an error for unrecognized model names to prevent silent failures
            Err(anyhow::anyhow!(
                "Unrecognized OpenAI model name: '{}'. Expected GPT-3, GPT-3.5, GPT-4, or related model names",
                model_name
            ))
        }
    }

    /// Get special tokens for a specific model
    fn get_special_tokens_for_model(model: TiktokenModel) -> SpecialTokens {
        // These are common special tokens for GPT models
        // The actual token IDs might vary by model
        match model {
            TiktokenModel::Cl100kBase => SpecialTokens {
                bos_token: Some("<|endoftext|>".to_string()),
                eos_token: Some("<|endoftext|>".to_string()),
                unk_token: None,
                sep_token: None,
                pad_token: Some("<|endoftext|>".to_string()),
                cls_token: None,
                mask_token: None,
                additional_special_tokens: vec![
                    "<|fim_prefix|>".to_string(),
                    "<|fim_middle|>".to_string(),
                    "<|fim_suffix|>".to_string(),
                    "<|endofprompt|>".to_string(),
                ],
            },
            _ => SpecialTokens {
                bos_token: Some("<|endoftext|>".to_string()),
                eos_token: Some("<|endoftext|>".to_string()),
                unk_token: None,
                sep_token: None,
                pad_token: Some("<|endoftext|>".to_string()),
                cls_token: None,
                mask_token: None,
                additional_special_tokens: vec![],
            },
        }
    }
}

impl Encoder for TiktokenTokenizer {
    fn encode(&self, input: &str) -> Result<Encoding> {
        let tokens = self.tokenizer.encode_ordinary(input);
        Ok(Encoding::Tiktoken(tokens))
    }

    fn encode_batch(&self, inputs: &[&str]) -> Result<Vec<Encoding>> {
        inputs.iter().map(|input| self.encode(input)).collect()
    }
}

impl Decoder for TiktokenTokenizer {
    fn decode(&self, token_ids: &[u32], _skip_special_tokens: bool) -> Result<String> {
        // Convert u32 to usize for tiktoken-rs
        let tokens: Vec<usize> = token_ids.iter().map(|&id| id as usize).collect();
        self.tokenizer
            .decode(tokens)
            .map_err(|e| Error::msg(format!("Decoding failed: {}", e)))
    }
}

impl TokenizerTrait for TiktokenTokenizer {
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    fn get_special_tokens(&self) -> &SpecialTokens {
        &self.special_tokens
    }

    fn token_to_id(&self, _token: &str) -> Option<u32> {
        // Tiktoken doesn't provide direct token-to-id mapping
        // We'd need to encode the token and check if it produces a single ID
        None
    }

    fn id_to_token(&self, _id: u32) -> Option<String> {
        // Tiktoken doesn't provide direct id-to-token mapping
        // We can only decode IDs to text
        None
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tiktoken_creation() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        assert_eq!(tokenizer.vocab_size(), 100256);
    }

    #[test]
    fn test_model_from_name() {
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-4").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("gpt-3.5-turbo").unwrap(),
            TiktokenModel::Cl100kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-003").unwrap(),
            TiktokenModel::P50kBase
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("text-davinci-edit-001").unwrap(),
            TiktokenModel::P50kEdit
        ));
        assert!(matches!(
            TiktokenTokenizer::model_from_name("davinci").unwrap(),
            TiktokenModel::R50kBase
        ));
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let text = "Hello, world!";
        let encoding = tokenizer.encode(text).unwrap();
        let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
        assert_eq!(decoded, text);
    }

    #[test]
    fn test_batch_encode() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let texts = vec!["Hello", "World", "Test"];
        let encodings = tokenizer.encode_batch(&texts).unwrap();
        assert_eq!(encodings.len(), 3);
        for (i, encoding) in encodings.iter().enumerate() {
            let decoded = tokenizer.decode(&encoding.token_ids(), false).unwrap();
            assert_eq!(decoded, texts[i]);
        }
    }

    #[test]
    fn test_special_tokens() {
        let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase).unwrap();
        let special_tokens = tokenizer.get_special_tokens();
        assert!(special_tokens.eos_token.is_some());
        assert_eq!(special_tokens.eos_token.as_ref().unwrap(), "<|endoftext|>");
    }

    #[test]
    fn test_unrecognized_model_name_returns_error() {
        // Test that unrecognized model names return an error
        let result = TiktokenTokenizer::from_model_name("distilgpt-2");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("bert-base-uncased");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }

        let result = TiktokenTokenizer::from_model_name("llama-7b");
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("Unrecognized OpenAI model name"));
        }
    }

    #[test]
    fn test_recognized_model_names() {
        // Test that recognized model names work correctly
        assert!(TiktokenTokenizer::from_model_name("gpt-4").is_ok());
        assert!(TiktokenTokenizer::from_model_name("gpt-3.5-turbo").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-davinci-003").is_ok());
        assert!(TiktokenTokenizer::from_model_name("code-davinci-002").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-curie-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-babbage-001").is_ok());
        assert!(TiktokenTokenizer::from_model_name("text-ada-001").is_ok());
    }
}
```
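A direct round trip through the wrapper mirrors test_encode_decode. The sketch below (same assumed crate paths as before, not part of the commit) also highlights two properties visible in the code above: encode uses encode_ordinary, so special-token text is tokenized as plain text, and vocab_size() reports a hardcoded per-encoding constant rather than a value read from the BPE:

```rust
use sglang_router_rs::tokenizer::{Decoder, Encoder, TiktokenModel, TiktokenTokenizer};

fn main() -> anyhow::Result<()> {
    let tokenizer = TiktokenTokenizer::new(TiktokenModel::Cl100kBase)?;

    let encoding = tokenizer.encode("Hello, world!")?;
    // token_ids() widens Tiktoken's usize ids into an owned Vec<u32>.
    let decoded = tokenizer.decode(&encoding.token_ids(), false)?;
    assert_eq!(decoded, "Hello, world!");

    // The skip_special_tokens flag is accepted but ignored by this backend.
    Ok(())
}
```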
sgl-router/src/tokenizer/traits.rs

```diff
@@ -26,13 +26,28 @@ pub enum Encoding {
     Hf(Box<tokenizers::tokenizer::Encoding>),
     /// Sentence Piece
     Sp(Vec<u32>),
+    /// Tiktoken (for GPT models)
+    Tiktoken(Vec<usize>),
 }
 
 impl Encoding {
-    pub fn token_ids(&self) -> &[u32] {
+    pub fn token_ids(&self) -> Vec<u32> {
+        match self {
+            Encoding::Hf(inner) => inner.get_ids().to_vec(),
+            Encoding::Sp(inner) => inner.clone(),
+            Encoding::Tiktoken(inner) => inner.iter().map(|&id| id as u32).collect(),
+        }
+    }
+
+    pub fn token_ids_ref(&self) -> &[u32] {
         match self {
             Encoding::Hf(inner) => inner.get_ids(),
             Encoding::Sp(inner) => inner,
+            Encoding::Tiktoken(_) => {
+                // Tiktoken uses usize, we can't return a reference to u32
+                // This is a limitation - callers should use token_ids() for Tiktoken
+                &[]
+            }
         }
     }
 }
```
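The split between the two accessors is the one observable API change for existing callers: token_ids() now allocates, while token_ids_ref() keeps the old zero-copy behavior except for the Tiktoken variant. A small sketch of the contract (assumed paths as above, `inspect` is a hypothetical helper):

```rust
use sglang_router_rs::tokenizer::Encoding;

fn inspect(encoding: &Encoding) {
    // Always valid: copies/widens into an owned Vec<u32> for every variant.
    let owned: Vec<u32> = encoding.token_ids();

    // Zero-copy, but deliberately empty for Tiktoken, whose ids are usize.
    let borrowed: &[u32] = encoding.token_ids_ref();

    match encoding {
        Encoding::Tiktoken(_) => assert!(borrowed.is_empty()),
        _ => assert_eq!(borrowed, owned.as_slice()),
    }
}
```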