sglang / Commits / 86d10d22

Unverified commit 86d10d22, authored Aug 23, 2025 by Lianmin Zheng, committed by GitHub on Aug 23, 2025.
Parent: 83871aa1

Update grok.py and tiktoken tokenizer (#9532)

Showing 10 changed files with 732 additions and 64 deletions (+732 -64)
python/sglang/srt/constrained/xgrammar_backend.py                    +10   -6
python/sglang/srt/hf_transformers_utils.py                            +5   -0
python/sglang/srt/layers/attention/triton_backend.py                 +16   -2
python/sglang/srt/layers/attention/triton_ops/decode_attention.py    +31   -0
python/sglang/srt/layers/attention/triton_ops/extend_attention.py    +18   -0
python/sglang/srt/layers/elementwise.py                              +94   -0
python/sglang/srt/layers/moe/router.py                               +15   -9
python/sglang/srt/layers/radix_attention.py                           +6   -0
python/sglang/srt/models/grok.py                                    +376  -47
python/sglang/srt/tokenizer/tiktoken_tokenizer.py (new file)        +161   -0
python/sglang/srt/constrained/xgrammar_backend.py

@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()
 
-        # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
-        # This ensures consistency between what the model considers EOS and what XGrammar uses
-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
-        )
-        override_stop_tokens = None
+        if hasattr(tokenizer, "init_xgrammar"):
+            # For special tokenizer
+            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+        else:
+            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
+            # This ensures consistency between what the model considers EOS and what XGrammar uses
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
+            )
+            override_stop_tokens = None
 
         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
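The new branch duck-types on the tokenizer: any tokenizer object that exposes an init_xgrammar() method returning a (TokenizerInfo, override_stop_tokens) pair now bypasses the Hugging Face conversion path. A minimal sketch of such a hook, with purely illustrative names (the real implementation added by this commit is TiktokenTokenizer.init_xgrammar, shown further below):

    # Sketch only, not part of the commit: the shape of the hook that
    # XGrammarGrammarBackend now looks for via hasattr(tokenizer, "init_xgrammar").
    from xgrammar import TokenizerInfo

    class CustomTokenizer:  # hypothetical tokenizer
        def __init__(self, encoded_vocab, eos_token_id):
            self.encoded_vocab = encoded_vocab  # token strings, ordered by token id
            self.eos_token_id = eos_token_id

        def init_xgrammar(self):
            # Return the TokenizerInfo plus the stop tokens the backend should enforce.
            tokenizer_info = TokenizerInfo(
                self.encoded_vocab, stop_token_ids=[self.eos_token_id]
            )
            return tokenizer_info, [self.eos_token_id]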
python/sglang/srt/hf_transformers_utils.py

@@ -263,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
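Callers are unchanged; dispatch happens purely on the file suffix of the tokenizer name. A small usage sketch, with placeholder paths and model names:

    # Sketch only: a tokenizer argument ending in ".json" is now routed to the
    # tiktoken wrapper instead of the Hugging Face AutoTokenizer path.
    from sglang.srt.hf_transformers_utils import get_tokenizer

    hf_tok = get_tokenizer("some-org/some-hf-model")          # Hugging Face tokenizer
    xai_tok = get_tokenizer("/path/to/grok_tokenizer.json")   # TiktokenTokenizer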
python/sglang/srt/layers/attention/triton_backend.py

@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor

@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         causal = True
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False

@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
             window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o

@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v

@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
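logit_capping_mod currently recognizes only the "tanh" method and passes layer.logit_cap through unchanged (raising for anything else), so the backend still hands the kernels a scalar cap, now via the logit_cap= keyword. For reference, the capping the kernels apply when that scalar is positive is the usual tanh soft cap; a host-side sketch, illustration only and not code from the commit:

    import torch

    def soft_cap(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
        # Mirrors `qk = logit_cap * tanh(qk / logit_cap)` inside the Triton kernels.
        if logit_cap > 0:
            return logit_cap * torch.tanh(qk / logit_cap)
        return qk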
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
 
     kv_len_per_split = (

@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg
+
             qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))
 
             offs_buf_v = (

@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx

@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
         BLOCK_N=BLOCK,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=num_warps,
         num_stages=2,
         Lk=Lk,

@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):

@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
 
     if BLOCK_DPE > 0:

@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)
 
+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )

@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]

@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_H=BLOCK_H,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=4,
         num_stages=num_stages,
         Lk=Lk,

@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_att_m_fwd(
         q,

@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_grouped_att_m_fwd(
         q,

@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -702,6 +730,7 @@ def decode_attention_fwd(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -725,6 +754,7 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
     else:
         # GQA/MQA/MLA

@@ -742,4 +772,5 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
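Every decode entry point now threads an xai_temperature_len argument (default -1, i.e. disabled) down to the stage-1 kernels, which scale the attention logits of query positions beyond that length by log2(position) / log2(xai_temperature_len). A host-side sketch of that factor, assuming 0-based positions as in the kernels (illustration only, not code from the commit):

    import torch

    def xai_temperature_factor(pos: torch.Tensor, xai_temperature_len: int) -> torch.Tensor:
        # pos: 0-based query positions (cur_batch_seq_len - 1 in the decode kernels)
        scale = 1.0 / torch.log2(torch.tensor(float(xai_temperature_len)))
        factor = torch.log2(pos.float()) * scale
        return torch.where(pos > xai_temperature_len, factor, torch.ones_like(factor))

    # qk is multiplied by this factor before the softmax, so logits are effectively
    # rescaled by log base xai_temperature_len of the position for long contexts.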
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,

@@ -109,6 +110,15 @@ def _fwd_kernel(
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv
 
+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs

@@ -203,6 +213,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)

@@ -306,6 +319,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)
 
+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))
 
         row_max = tl.max(qk, 1)

@@ -373,6 +389,7 @@ def extend_attention_fwd(
     sliding_window_size=-1,
     sinks=None,
     window_kv_offsets=None,
+    xai_temperature_len=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -477,6 +494,7 @@ def extend_attention_fwd(
         v_buffer.stride(1),
         SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
python/sglang/srt/layers/elementwise.py

@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
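silu_and_mul_triton mirrors the existing gelu_and_mul_triton helper: the input has shape (bs, 2 * hidden_dim), SiLU is applied to the first half and multiplied into the second half, and the quantization path is still a stub (a non-None quant_max raises NotImplementedError). A plain PyTorch reference for the non-quantized path, for illustration only:

    import torch
    import torch.nn.functional as F

    def silu_and_mul_reference(hidden_states: torch.Tensor) -> torch.Tensor:
        # First half of the last dim is the gate, second half is the up projection.
        x1, x3 = hidden_states.chunk(2, dim=-1)
        # The kernel computes silu in fp32 and casts down before the multiply.
        silu_x1 = F.silu(x1.float()).to(hidden_states.dtype)
        return (x3.float() * silu_x1.float()).to(hidden_states.dtype)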
python/sglang/srt/layers/moe/router.py

@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
 
     # logit softcap
-    logits_scaled = logits / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    top = exped - 1
-    bottom = exped + 1
-    logits_softcapped = top / bottom * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
+        logits_scaled = logits / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        top = exped - 1
+        bottom = exped + 1
+        logits_softcapped = top / bottom * moe_softcapping
 
     # Add bias after softcapping
     if is_correction_bias:

@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K
 
     # 4. logit softcap
-    logits_scaled = acc / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
+        logits_scaled = acc / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]

@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts) and (
+        cond_top2 = (arange_block_size_n < num_experts) & (
             arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(
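Two behavioral notes on this file: the router soft cap is the identity softcap(x) = c * tanh(x / c) with c = moe_softcapping, computed as (exp(2z) - 1) / (exp(2z) + 1) for z = x / c, and the new moe_softcapping == 0 branch disables capping instead of dividing by zero; separately, replacing `and` with `&` in the topk == 2 mask makes the combination elementwise, which is what the tensor expression requires. A short reference sketch of the soft cap, illustration only:

    import torch

    def moe_softcap_reference(logits: torch.Tensor, moe_softcapping: float) -> torch.Tensor:
        # moe_softcapping == 0 means "no capping", matching the new kernel branch.
        if moe_softcapping == 0:
            return logits
        z = logits / moe_softcapping
        exped = torch.exp(2 * z)
        return (exped - 1) / (exped + 1) * moe_softcapping  # == moe_softcapping * tanh(z)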
python/sglang/srt/layers/radix_attention.py

@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,

@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
         self.quant_method.create_weights(self)
 
         self.attn_type = attn_type
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
python/sglang/srt/models/grok.py

(Diff collapsed; not shown on this page.)
python/sglang/srt/tokenizer/tiktoken_tokenizer.py (new file, mode 100644)

import functools
import json
from typing import AbstractSet, Collection, List, Literal, Union


class TiktokenProcessor:
    def __init__(self, name: str):
        self.tokenizer = TiktokenTokenizer(name)

    def image_processor(self, image):
        return {"pixel_values": [image]}


RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"

DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

# default + separate each single digit
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

        # Read the JSON
        with open(tokenizer_path, "rb") as fin:
            xtok_dict = json.load(fin)

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in xtok_dict["special_tokens"]
        }
        if xtok_dict["word_split"] == "V1":
            pad_str = PAT_STR_B
        else:
            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
        pad_str = xtok_dict.get("pat_str", pad_str)

        kwargs = {
            "name": tokenizer_path,
            "pat_str": pad_str,
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in xtok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in xtok_dict["default_allowed_special"]
                ]
            )
        if "vocab_size" in xtok_dict:
            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
        default_allowed_special = None
        control_tokens = DEFAULT_CONTROL_TOKENS
        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._control_tokens = control_tokens

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[
                Literal["all"], AbstractSet[str]
            ] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> List[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=(),
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Allow more tokens to prevent crash
        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
        tokenizer._default_allowed_special |= set(
            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
        )

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.bos_token_id = None
        self.eos_token_id = tokenizer._special_tokens[EOS]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        self.chat_template_jinja = Template(self.chat_template)
        self.additional_stop_token_ids = None

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x, *args, **kwargs):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if len(batch) > 0 and isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt, tools=None):
        ret = self.chat_template_jinja.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

    def __call__(self, text, **kwargs):
        return {
            "input_ids": self.encode(text),
        }

    def init_xgrammar(self):
        from xgrammar import TokenizerInfo

        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"

        enc = self.tokenizer
        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
        encoded_vocab = [
            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
        ]
        override_stop_tokens = [2]  # eos

        # These are treated as special tokens in xgrammar; we want to avoid them
        # For now, xgrammar treats anything starting with b'\x00' as a special token
        xgrammar_special_token_ids = []
        for i, token in enumerate(encoded_vocab):
            if isinstance(token, bytes) and token.startswith(b"\x00"):
                xgrammar_special_token_ids.append(i)
        for i, id in enumerate(xgrammar_special_token_ids):
            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)

        tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
        assert len(tokenizer_info.special_token_ids) == 0

        return tokenizer_info, override_stop_tokens
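The wrapper exposes just enough of the Hugging Face tokenizer surface for SGLang (encode, decode, batch_decode, apply_chat_template, __call__, plus init_xgrammar for grammar-constrained decoding). A usage sketch, where the tokenizer path is a placeholder and must point to a JSON file in the expected xtok format (with "regular_tokens", "special_tokens", and "word_split" fields):

    from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer

    tok = TiktokenTokenizer("/path/to/tokenizer.json")  # placeholder path
    ids = tok.encode("Hello world")
    text = tok.decode(ids)
    prompt = tok.apply_chat_template(
        [{"role": "user", "content": "Hello"}],
        tokenize=False,
        add_generation_prompt=True,
    )
    # get_tokenizer() in hf_transformers_utils.py returns this class automatically
    # when the tokenizer name ends with ".json".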