sglang · Commit 86d10d22 (unverified)
Authored Aug 23, 2025 by Lianmin Zheng; committed via GitHub on Aug 23, 2025
Parent: 83871aa1

Update grok.py and tiktoken tokenizer (#9532)

Showing 10 changed files with 732 additions and 64 deletions (+732, -64)
Changed files:

  python/sglang/srt/constrained/xgrammar_backend.py                  +10   -6
  python/sglang/srt/hf_transformers_utils.py                          +5   -0
  python/sglang/srt/layers/attention/triton_backend.py               +16   -2
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +31   -0
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py  +18   -0
  python/sglang/srt/layers/elementwise.py                            +94   -0
  python/sglang/srt/layers/moe/router.py                             +15   -9
  python/sglang/srt/layers/radix_attention.py                         +6   -0
  python/sglang/srt/models/grok.py                                  +376  -47
  python/sglang/srt/tokenizer/tiktoken_tokenizer.py                 +161   -0
python/sglang/srt/constrained/xgrammar_backend.py

@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-        # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
-        # This ensures consistency between what the model considers EOS and what XGrammar uses
-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
-        )
-        override_stop_tokens = None
+        if hasattr(tokenizer, "init_xgrammar"):
+            # For special tokenizer
+            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+        else:
+            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
+            # This ensures consistency between what the model considers EOS and what XGrammar uses
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
+            )
+            override_stop_tokens = None
 
         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
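The practical effect is a duck-typed hook: any tokenizer object that exposes an init_xgrammar() method now supplies its own TokenizerInfo and stop tokens, instead of going through TokenizerInfo.from_huggingface. A minimal sketch of such a tokenizer with a toy vocabulary (the class name and token ids below are hypothetical; the real provider is the TiktokenTokenizer added later in this commit):

    from xgrammar import TokenizerInfo

    class ToyXgTokenizer:
        # Hypothetical duck-typed tokenizer: exposing init_xgrammar() is all
        # XGrammarGrammarBackend now requires to take the special-tokenizer path.
        def init_xgrammar(self):
            encoded_vocab = ["<|pad|>", "<|separator|>", "<|eos|>", "a", "b"]
            override_stop_tokens = [2]  # id of "<|eos|>" in this toy vocabulary
            tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
            return tokenizer_info, override_stop_tokens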
python/sglang/srt/hf_transformers_utils.py

@@ -263,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
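A hedged usage sketch (the path below is hypothetical): any tokenizer name ending in ".json" is now routed to the tiktoken-based wrapper instead of the Hugging Face AutoTokenizer path.

    from sglang.srt.hf_transformers_utils import get_tokenizer

    # Hypothetical path to an xAI-style tokenizer JSON file.
    tok = get_tokenizer("/models/grok/tokenizer.json")
    ids = tok.encode("Hello world")
    print(tok.decode(ids))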
python/sglang/srt/layers/attention/triton_backend.py

@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor

@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         causal = True
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False

@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
             window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o

@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v

@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
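A self-contained restatement of the new helper's behavior (the error message is mine; the diff raises a bare ValueError): only the "tanh" method is recognized, and a positive logit_cap is passed through unchanged as the logits_soft_cap consumed by the kernels.

    def logit_capping_mod(logit_capping_method, logit_cap):
        # Restated from the hunk above: "tanh" is the only supported capping method;
        # the positive cap value itself is what the Triton kernels receive.
        if logit_capping_method == "tanh":
            return logit_cap
        raise ValueError(f"Unsupported logit capping method: {logit_capping_method}")

    assert logit_capping_mod("tanh", 30.0) == 30.0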
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
     kv_len_per_split = (

@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg
+
         qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
 
         offs_buf_v = (

@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx

@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
         BLOCK_N=BLOCK,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=num_warps,
         num_stages=2,
         Lk=Lk,

@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):

@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
 
     if BLOCK_DPE > 0:

@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(
             mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
         )

@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]

@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_H=BLOCK_H,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=4,
         num_stages=num_stages,
         Lk=Lk,

@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_att_m_fwd(
         q,

@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_grouped_att_m_fwd(
         q,

@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -702,6 +730,7 @@ def decode_attention_fwd(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -725,6 +754,7 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
     else:
         # GQA/MQA/MLA

@@ -742,4 +772,5 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
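The xai_temperature term attenuates attention logits for queries past a threshold position; in decode the query position is the last position in the sequence (cur_batch_seq_len - 1). A minimal non-Triton reference of the per-query factor computed in these kernels:

    import math

    def xai_temperature_factor(query_pos: int, xai_temperature_len: int) -> float:
        # Reference sketch of xai_temperature_reg: positions past xai_temperature_len
        # scale their logits by log2(pos) / log2(xai_temperature_len); earlier
        # positions, and the disabled case xai_temperature_len <= 0, use 1.0.
        if xai_temperature_len <= 0 or query_pos <= xai_temperature_len:
            return 1.0
        return math.log2(query_pos) / math.log2(xai_temperature_len)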
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,

@@ -109,6 +110,15 @@ def _fwd_kernel(
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs

@@ -203,6 +213,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)

@@ -306,6 +319,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)

@@ -373,6 +389,7 @@ def extend_attention_fwd(
     sliding_window_size=-1,
     sinks=None,
     window_kv_offsets=None,
+    xai_temperature_len=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -477,6 +494,7 @@ def extend_attention_fwd(
         v_buffer.stride(1),
         SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
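In the extend (prefill) kernel the query position is the absolute position cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m, so the same factor is applied per query row. A quick worked example of the magnitude involved:

    import math

    # With xai_temperature_len = 1024, a query at absolute position 2048 has its logits
    # scaled by log2(2048) / log2(1024) = 11 / 10 = 1.1; positions at or below 1024
    # keep a factor of 1.0.
    assert math.isclose(math.log2(2048) / math.log2(1024), 1.1)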
python/sglang/srt/layers/elementwise.py

@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
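For readers checking the kernel's semantics, a plain-PyTorch sketch of the unquantized path (equivalent up to intermediate-precision differences; quantize is assumed to be None, since the quantized branch currently raises NotImplementedError):

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
        # SiLU of the first half of the last dimension, computed in float32, cast back
        # to the input dtype, then multiplied by the second half.
        x1, x3 = hidden_states.chunk(2, dim=-1)
        silu_x1 = F.silu(x1.float()).to(hidden_states.dtype)
        return x3 * silu_x1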
python/sglang/srt/layers/moe/router.py

@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
 
     # logit softcap
-    logits_scaled = logits / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    top = exped - 1
-    bottom = exped + 1
-    logits_softcapped = top / bottom * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
+        logits_scaled = logits / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        top = exped - 1
+        bottom = exped + 1
+        logits_softcapped = top / bottom * moe_softcapping
 
     # Add bias after softcapping
     if is_correction_bias:

@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K
 
     # 4. logit softcap
-    logits_scaled = acc / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
+        logits_scaled = acc / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]

@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts) and (
+        cond_top2 = (arange_block_size_n < num_experts) & (
             arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(
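Two notes on the router changes. The kernel's (exp(2z) - 1) / (exp(2z) + 1) * cap, with z = logits / cap, is algebraically cap * tanh(logits / cap), and the new moe_softcapping == 0 branch skips the transform instead of dividing by zero. Separately, the top-2 mask now combines its two conditions with the elementwise & rather than Python's and, which does not combine tensor values elementwise. A scalar reference of the softcap:

    import math

    def softcap_reference(logit: float, moe_softcapping: float) -> float:
        # Scalar sketch of the kernel's softcapping: identical to
        # moe_softcapping * tanh(logit / moe_softcapping), bypassed when the cap is 0.
        if moe_softcapping == 0:
            return logit
        z = logit / moe_softcapping
        return (math.exp(2 * z) - 1) / (math.exp(2 * z) + 1) * moe_softcapping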
python/sglang/srt/layers/radix_attention.py

@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,

@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)
         self.attn_type = attn_type
+
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
 
     def forward(
         self,
         q,
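A small sketch of how the Triton backend consumes the new attributes (the layer object here is a stand-in, not a real RadixAttention; construction arguments are omitted):

    from types import SimpleNamespace

    # Stand-in for a RadixAttention layer with the defaults added by this commit.
    layer = SimpleNamespace(
        logit_cap=30.0,
        logit_capping_method="tanh",  # resolved via logit_capping_mod in triton_backend.py
        pos_encoding_mode="NONE",
        xai_temperature_len=-1,       # a model (presumably grok.py, whose diff is collapsed
                                      # above) can set this > 0 to enable the scaling
    )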
python/sglang/srt/models/grok.py

This diff is collapsed in the web view (+376, -47) and is not reproduced here.
python/sglang/srt/tokenizer/tiktoken_tokenizer.py (new file, 0 → 100644)

import functools
import json
from typing import AbstractSet, Collection, List, Literal, Union


class TiktokenProcessor:
    def __init__(self, name: str):
        self.tokenizer = TiktokenTokenizer(name)

    def image_processor(self, image):
        return {"pixel_values": [image]}


RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"
DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

# default + separate each single digit
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

        # Read the JSON
        with open(tokenizer_path, "rb") as fin:
            xtok_dict = json.load(fin)

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in xtok_dict["special_tokens"]
        }
        if xtok_dict["word_split"] == "V1":
            pad_str = PAT_STR_B
        else:
            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
        pad_str = xtok_dict.get("pat_str", pad_str)

        kwargs = {
            "name": tokenizer_path,
            "pat_str": pad_str,
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in xtok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in xtok_dict["default_allowed_special"]
                ]
            )
        if "vocab_size" in xtok_dict:
            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
        default_allowed_special = None
        control_tokens = DEFAULT_CONTROL_TOKENS
        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._control_tokens = control_tokens

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> List[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=(),
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Allow more tokens to prevent crash
        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
        tokenizer._default_allowed_special |= set(
            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
        )

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.bos_token_id = None
        self.eos_token_id = tokenizer._special_tokens[EOS]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        self.chat_template_jinja = Template(self.chat_template)
        self.additional_stop_token_ids = None

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x, *args, **kwargs):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if len(batch) > 0 and isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt, tools=None):
        ret = self.chat_template_jinja.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

    def __call__(self, text, **kwargs):
        return {
            "input_ids": self.encode(text),
        }

    def init_xgrammar(self):
        from xgrammar import TokenizerInfo

        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"

        enc = self.tokenizer
        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
        encoded_vocab = [
            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
        ]
        override_stop_tokens = [2]  # eos

        # These are treated as special tokens in xgrammar; we want to avoid them
        # For now, xgrammar treats anything starting with b'\x00' as a special token
        xgrammar_special_token_ids = []
        for i, token in enumerate(encoded_vocab):
            if isinstance(token, bytes) and token.startswith(b"\x00"):
                xgrammar_special_token_ids.append(i)

        for i, id in enumerate(xgrammar_special_token_ids):
            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)

        tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
        assert len(tokenizer_info.special_token_ids) == 0

        return tokenizer_info, override_stop_tokens
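A hedged usage sketch (the file path is hypothetical): the wrapper mimics the small slice of the Hugging Face tokenizer interface that sglang touches, plus the init_xgrammar() hook consumed by XGrammarGrammarBackend above.

    tok = TiktokenTokenizer("/models/grok/tokenizer.json")

    ids = tok.encode("Hello world")
    text = tok.decode(ids)

    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )

    tokenizer_info, stop_ids = tok.init_xgrammar()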