change / sglang / Commits / 86d10d22

Unverified commit 86d10d22, authored Aug 23, 2025 by Lianmin Zheng, committed by GitHub on Aug 23, 2025

Update grok.py and tiktoken tokenizer (#9532)

parent 83871aa1

Changes: 10 changed files with 732 additions and 64 deletions (+732 -64)
python/sglang/srt/constrained/xgrammar_backend.py  +10 -6
python/sglang/srt/hf_transformers_utils.py  +5 -0
python/sglang/srt/layers/attention/triton_backend.py  +16 -2
python/sglang/srt/layers/attention/triton_ops/decode_attention.py  +31 -0
python/sglang/srt/layers/attention/triton_ops/extend_attention.py  +18 -0
python/sglang/srt/layers/elementwise.py  +94 -0
python/sglang/srt/layers/moe/router.py  +15 -9
python/sglang/srt/layers/radix_attention.py  +6 -0
python/sglang/srt/models/grok.py  +376 -47
python/sglang/srt/tokenizer/tiktoken_tokenizer.py  +161 -0
python/sglang/srt/constrained/xgrammar_backend.py

@@ -162,12 +162,16 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     ):
         super().__init__()

-        # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
-        # This ensures consistency between what the model considers EOS and what XGrammar uses
-        tokenizer_info = TokenizerInfo.from_huggingface(
-            tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
-        )
-        override_stop_tokens = None
+        if hasattr(tokenizer, "init_xgrammar"):
+            # For special tokenizer
+            tokenizer_info, override_stop_tokens = tokenizer.init_xgrammar()
+        else:
+            # Create TokenizerInfo with model's EOS tokens as the authoritative stop tokens
+            # This ensures consistency between what the model considers EOS and what XGrammar uses
+            tokenizer_info = TokenizerInfo.from_huggingface(
+                tokenizer, vocab_size=vocab_size, stop_token_ids=model_eos_token_ids
+            )
+            override_stop_tokens = None

         self.grammar_compiler = GrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size
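Context for the new branch above: a tokenizer can now opt into grammar-constrained decoding by exposing an init_xgrammar() hook. A minimal sketch of that duck-typed contract, modeled on the TiktokenTokenizer.init_xgrammar() added later in this commit; the helper sorted_vocab_strings() is hypothetical and only stands in for "the vocab in token-id order".

# Sketch only: a tokenizer exposing init_xgrammar() should return a prebuilt
# (xgrammar.TokenizerInfo, override_stop_tokens) pair, which the backend then uses
# instead of calling TokenizerInfo.from_huggingface(...).
class MyJsonTokenizer:
    eos_token_id = 2

    def sorted_vocab_strings(self):  # hypothetical helper
        ...

    def init_xgrammar(self):
        from xgrammar import TokenizerInfo

        encoded_vocab = self.sorted_vocab_strings()
        override_stop_tokens = [self.eos_token_id]
        tokenizer_info = TokenizerInfo(encoded_vocab, stop_token_ids=override_stop_tokens)
        return tokenizer_info, override_stop_tokens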
python/sglang/srt/hf_transformers_utils.py

@@ -263,6 +263,11 @@ def get_tokenizer(
     **kwargs,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
     """Gets a tokenizer for the given model name via Huggingface."""
+    if tokenizer_name.endswith(".json"):
+        from sglang.srt.tokenizer.tiktoken_tokenizer import TiktokenTokenizer
+
+        return TiktokenTokenizer(tokenizer_name)
+
     if tokenizer_mode == "slow":
         if kwargs.get("use_fast", False):
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
python/sglang/srt/layers/attention/triton_backend.py

@@ -20,6 +20,14 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput


+def logit_capping_mod(logit_capping_method, logit_cap):
+    # positive logit_cap -> tanh cap
+    if logit_capping_method == "tanh":
+        return logit_cap
+    else:
+        raise ValueError()
+
+
 @dataclass
 class ForwardMetadata:
     attn_logits: torch.Tensor

@@ -718,6 +726,8 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         causal = True
         if layer.attn_type == AttentionType.ENCODER_ONLY:
             causal = False

@@ -750,10 +760,11 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sliding_window_size=sliding_window_size,
             sinks=sinks,
             window_kv_offsets=window_kv_offsets,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o

@@ -777,6 +788,8 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)

+        logits_soft_cap = logit_capping_mod(layer.logit_capping_method, layer.logit_cap)
+
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
                 layer, forward_batch.out_cache_loc, k, v

@@ -801,8 +814,9 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.num_kv_splits,
             self.max_kv_splits,
             layer.scaling,
-            layer.logit_cap,
+            logit_cap=logits_soft_cap,
             sinks=sinks,
+            xai_temperature_len=layer.xai_temperature_len,
         )
         return o
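For reference, a small torch sketch (not part of the diff) of what the positive cap value returned by logit_capping_mod does downstream: the Triton kernels apply it as a tanh soft cap on the attention scores.

import torch

def tanh_cap(qk: torch.Tensor, logit_cap: float) -> torch.Tensor:
    # Mirrors the kernel-side `qk = logit_cap * tanh(qk / logit_cap)`;
    # bounds the scores to (-logit_cap, logit_cap) while staying smooth near zero.
    if logit_cap <= 0:
        return qk
    return logit_cap * torch.tanh(qk / logit_cap)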
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel_stage1(
     logit_cap: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -85,6 +86,12 @@ def _fwd_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)

+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d

     kv_len_per_split = (

@@ -122,6 +129,9 @@ def _fwd_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)

+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg
+
             qk = tl.where(offs_n < split_kv_end, qk, float("-inf"))

             offs_buf_v = (

@@ -181,6 +191,7 @@ def _decode_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 64
     # [TODO] work around SGPR limit on MI3xx

@@ -230,6 +241,7 @@ def _decode_att_m_fwd(
         BLOCK_N=BLOCK,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=num_warps,
         num_stages=2,
         Lk=Lk,

@@ -266,6 +278,7 @@ def _fwd_grouped_kernel_stage1(
     BLOCK_H: tl.constexpr,
     MIN_BLOCK_KV: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lk: tl.constexpr,
     Lv: tl.constexpr,
 ):

@@ -291,6 +304,12 @@ def _fwd_grouped_kernel_stage1(
     cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
     kv_splits = tl.load(num_kv_splits + cur_batch)

+    if xai_temperature_len > 0:
+        offs_qidx = cur_batch_seq_len - 1
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        _qtemp = tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale
+        xai_temperature_reg = tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0)
+
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]

     if BLOCK_DPE > 0:

@@ -351,6 +370,9 @@ def _fwd_grouped_kernel_stage1(
             if logit_cap > 0:
                 qk = logit_cap * tanh(qk / logit_cap)

+            if xai_temperature_len > 0:
+                qk *= xai_temperature_reg[:, None]
+
             qk = tl.where(
                 mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf")
             )

@@ -413,6 +435,7 @@ def _decode_grouped_att_m_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap,
+    xai_temperature_len=-1,
 ):
     BLOCK = 32
     Lk = k_buffer.shape[-1]

@@ -480,6 +503,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_H=BLOCK_H,
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         num_warps=4,
         num_stages=num_stages,
         Lk=Lk,

@@ -620,6 +644,7 @@ def decode_attention_fwd_normal(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_att_m_fwd(
         q,

@@ -633,6 +658,7 @@ def decode_attention_fwd_normal(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -661,6 +687,7 @@ def decode_attention_fwd_grouped(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     _decode_grouped_att_m_fwd(
         q,

@@ -674,6 +701,7 @@ def decode_attention_fwd_grouped(
         max_kv_splits,
         sm_scale,
         logit_cap,
+        xai_temperature_len,
     )
     _decode_softmax_reducev_fwd(
         attn_logits,

@@ -702,6 +730,7 @@ def decode_attention_fwd(
     sm_scale,
     logit_cap=0.0,
     sinks=None,
+    xai_temperature_len=-1,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -725,6 +754,7 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
     else:
         # GQA/MQA/MLA

@@ -742,4 +772,5 @@ def decode_attention_fwd(
             sm_scale,
             logit_cap=logit_cap,
             sinks=sinks,
+            xai_temperature_len=xai_temperature_len,
         )
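A scalar reference sketch (not part of the diff) of the xai_temperature_reg factor the new kernel code computes per query position, written in plain Python for readability.

import math

def xai_temperature_scale(position_index: int, xai_temperature_len: int) -> float:
    # Positions up to xai_temperature_len are left unscaled; beyond that, attention
    # scores are multiplied by log2(position) / log2(xai_temperature_len), matching the
    # kernel's tl.where(offs_qidx > xai_temperature_len, _qtemp, 1.0) above.
    if xai_temperature_len <= 0 or position_index <= xai_temperature_len:
        return 1.0
    return math.log2(position_index) / math.log2(xai_temperature_len)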
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -69,6 +69,7 @@ def _fwd_kernel(
     stride_buf_vh,
     SLIDING_WINDOW_SIZE: tl.constexpr,
     logit_cap: tl.constexpr,
+    xai_temperature_len: tl.constexpr,
     Lq: tl.constexpr,
     Lv: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,

@@ -109,6 +110,15 @@ def _fwd_kernel(
     mask_d = offs_d < Lq
     mask_dv = offs_dv < Lv

+    if xai_temperature_len > 0:
+        offs_qidx = cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m
+        xai_temperature_scale = 1.0 / tl.log2(float(xai_temperature_len))
+        xai_temperature_reg = tl.where(
+            offs_qidx > xai_temperature_len,
+            tl.log2(offs_qidx.to(tl.float32)) * xai_temperature_scale,
+            1.0,
+        )
+
     offs_q = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_qbs

@@ -203,6 +213,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))

         row_max = tl.max(qk, 1)

@@ -306,6 +319,9 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

+        if xai_temperature_len > 0:
+            qk *= xai_temperature_reg[:, None]
+
         qk = tl.where(final_mask, qk, float("-inf"))

         row_max = tl.max(qk, 1)

@@ -373,6 +389,7 @@ def extend_attention_fwd(
     sliding_window_size=-1,
     sinks=None,
     window_kv_offsets=None,
+    xai_temperature_len=-1,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -477,6 +494,7 @@ def extend_attention_fwd(
         v_buffer.stride(1),
         SLIDING_WINDOW_SIZE=sliding_window_size,
         logit_cap=logit_cap,
+        xai_temperature_len=xai_temperature_len,
         BLOCK_DMODEL=BLOCK_DMODEL,
         BLOCK_DPE=BLOCK_DPE,
         BLOCK_DV=BLOCK_DV,
python/sglang/srt/layers/elementwise.py

@@ -486,3 +486,97 @@ def gelu_and_mul_triton(
         return out_hidden_states, out_scales
     else:
         return out_hidden_states, None
+
+
+# silu on first half of vector
+@triton.jit
+def silu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # silu
+    # cast down before mul to better match training?
+    silu_x1 = x1 * tl.sigmoid(x1)
+    out = x3 * silu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def silu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    max_warps = 16 if _is_hip else 32
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), max_warps), 4
+        ),
+    }
+
+    silu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
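A hypothetical usage sketch for the new silu_and_mul_triton helper (shapes taken from the kernel comments above; the tensor sizes are made up). The first half of the last dimension goes through SiLU and is multiplied with the second half.

import torch
from sglang.srt.layers.elementwise import silu_and_mul_triton

x = torch.randn(8, 2 * 4096, dtype=torch.bfloat16, device="cuda")  # (bs, hidden_dim * 2)
out, scales = silu_and_mul_triton(x)  # out: (bs, hidden_dim); scales is None when quantize is not given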
python/sglang/srt/layers/moe/router.py

@@ -45,11 +45,14 @@ def fused_moe_router_kernel(
     logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)

     # logit softcap
-    logits_scaled = logits / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    top = exped - 1
-    bottom = exped + 1
-    logits_softcapped = top / bottom * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = logits
+    else:
+        logits_scaled = logits / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        top = exped - 1
+        bottom = exped + 1
+        logits_softcapped = top / bottom * moe_softcapping

     # Add bias after softcapping
     if is_correction_bias:

@@ -207,9 +210,12 @@ def fused_moe_router_large_bs_kernel(
         b_ptrs += BLOCK_SIZE_K

     # 4. logit softcap
-    logits_scaled = acc / moe_softcapping
-    exped = tl.exp(2 * logits_scaled)
-    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+    if moe_softcapping == 0:
+        logits_softcapped = acc
+    else:
+        logits_scaled = acc / moe_softcapping
+        exped = tl.exp(2 * logits_scaled)
+        logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping

     # 5. top1
     arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]

@@ -234,7 +240,7 @@ def fused_moe_router_large_bs_kernel(
     # 7. handle topk == 2
     if topk == 2:
-        cond_top2 = (arange_block_size_n < num_experts) and (
+        cond_top2 = (arange_block_size_n < num_experts) & (
             arange_block_size_n != top1[:, None]
         )
         top2 = tl.argmax(
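A reference (non-Triton) sketch of the router soft-capping above, not part of the diff: since tanh(x) = (exp(2x) - 1) / (exp(2x) + 1), the kernel's exp-based form equals moe_softcapping * tanh(logits / moe_softcapping), and the new moe_softcapping == 0 branch simply disables capping.

import torch

def softcap(logits: torch.Tensor, moe_softcapping: float) -> torch.Tensor:
    # Equivalent closed form of the kernel math, for checking against the Triton path.
    if moe_softcapping == 0:
        return logits
    return moe_softcapping * torch.tanh(logits / moe_softcapping)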
python/sglang/srt/layers/radix_attention.py

@@ -52,6 +52,8 @@ class RadixAttention(nn.Module):
         v_head_dim: int = -1,
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
+        pos_encoding_mode: str = "NONE",
+        logit_capping_method: str = "tanh",
         quant_config: Optional[QuantizationConfig] = None,
         attn_type: AttentionType = AttentionType.DECODER,
         use_irope: bool = False,

@@ -81,6 +83,10 @@ class RadixAttention(nn.Module):
             self.quant_method.create_weights(self)

         self.attn_type = attn_type
+        self.pos_encoding_mode = pos_encoding_mode
+        self.logit_capping_method = logit_capping_method
+        self.xai_temperature_len = -1
+
     def forward(
         self,
         q,
python/sglang/srt/models/grok.py

@@ -16,7 +16,6 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 import functools
-import json
 import logging
 import math
 import os

@@ -35,9 +34,16 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.layers.elementwise import fused_dual_residual_rmsnorm, fused_rmsnorm
+from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.elementwise import (
+    experts_combine_triton,
+    fused_dual_residual_rmsnorm,
+    fused_rmsnorm,
+    gelu_and_mul_triton,
+)
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,

@@ -49,7 +55,12 @@ from sglang.srt.layers.moe.router import fused_moe_router_shim
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import (
+    RotaryEmbedding,
+    _yarn_find_correction_range,
+    _yarn_get_mscale,
+    get_rope,
+)
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,

@@ -58,13 +69,60 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import dump_to_file
+from sglang.srt.utils import add_prefix, dispose_tensor, dump_to_file

 logger = logging.getLogger(__name__)

 # Dump tensors for debugging
 debug_tensor_dump_output_folder = None
+debug_tensor_dump_prefill_only = False
+# Skip all the other tensor dumps, only dump the target logits
+debug_tensor_dump_only_target_logprobs = False
 debug_tensor_dump_inject = False
+debug_tensor_dump_layers = None
+debug_tensor_dump_test = False
+
+
+class Grok1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results=True,
+        use_presharded_weights: bool = False,
+        split_gate_up: bool = False,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+            use_presharded_weights=use_presharded_weights,
+        )
+        self.act_fn = GeluAndMul(approximate="tanh")
+        self.layer_id = layer_id
+
+    def forward(self, x):
+        gate_up, _ = self.gate_up_proj(x)
+        x, _ = gelu_and_mul_triton(gate_up)
+        x, _ = self.down_proj(x)
+        return x


 class Grok1MoE(nn.Module):

@@ -87,10 +145,11 @@ class Grok1MoE(nn.Module):
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
-        reduce_results=True,
+        reduce_results: bool = True,
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        prefix: str = "",
     ):
         super().__init__()
         self.hidden_size = hidden_size

@@ -145,6 +204,135 @@ class Grok1MoE(nn.Module):
         return self.experts(hidden_states, topk_output)


+def _yarn_linear_ramp_mask(
+    low: float, high: float, dim: int, dtype: torch.dtype
+) -> torch.Tensor:
+    if low == high:
+        low -= 0.001  # Prevent singularity
+
+    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
+    ramp_func = torch.clamp(linear_func, 0, 1)
+    return ramp_func
+
+
+def get_rope_scaling(config):
+    rope_type = getattr(config, "rope_type", None)
+    if rope_type:
+        original_max_position_embeddings = getattr(
+            config, "original_max_position_embeddings", None
+        )
+        scaling_factor = getattr(config, "scaling_factor", None)
+        extrapolation_factor = getattr(config, "extrapolation_factor", 1.0)
+        attn_factor = getattr(config, "attn_factor", 1.0)
+        beta_fast = getattr(config, "beta_fast", 32)
+        beta_slow = getattr(config, "beta_slow", 1)
+        rope_scaling = {
+            "extra_method": rope_type,
+            "max_position_embeddings": original_max_position_embeddings,
+            "scaling_factor": scaling_factor,
+            "extrapolation_factor": extrapolation_factor,
+            "attn_factor": attn_factor,
+            "beta_fast": beta_fast,
+            "beta_slow": beta_slow,
+            "dtype": torch.float,
+        }
+        return rope_scaling
+    else:
+        return None
+
+
+class ScalingRotaryEmbedding(RotaryEmbedding):
+    """Scale the RotaryEmbedding in a way similar to YaRN method. https://arxiv.org/pdf/2309.00071."""
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_factor: float,
+        dtype: torch.dtype,
+        *,
+        extra_method: str = "yarn_log",
+        extrapolation_factor: float = 1,
+        attn_factor: float = 1,
+        beta_fast: int = 32,
+        beta_slow: int = 1,
+    ) -> None:
+        self.scaling_factor = scaling_factor
+        self.extra_method = extra_method
+        self.extrapolation_factor = extrapolation_factor
+        self.attn_factor = attn_factor
+        self.beta_fast = beta_fast
+        self.beta_slow = beta_slow
+        # Get n-d magnitude scaling corrected for interpolation
+        self.mscale = float(_yarn_get_mscale(self.scaling_factor) * attn_factor)
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
+        pos_freqs = self.base ** (
+            torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+        )
+        inv_freq_extrapolation = 1.0 / pos_freqs
+        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)
+
+        low, high = _yarn_find_correction_range(
+            self.beta_fast,
+            self.beta_slow,
+            self.rotary_dim,
+            self.base,
+            self.max_position_embeddings,
+        )
+        # Get n-d rotational scaling corrected for extrapolation
+        inv_freq_mask = (
+            1
+            - _yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
+        ) * self.extrapolation_factor
+        if self.extra_method in ["original"]:
+            inv_freq = inv_freq_extrapolation
+        elif self.extra_method in ["yarn", "yarn_linear"]:
+            inv_freq = (
+                inv_freq_interpolation * (1 - inv_freq_mask)
+                + inv_freq_extrapolation * inv_freq_mask
+            )
+        elif self.extra_method == "yarn_log":
+            inv_freq = torch.exp(
+                torch.log(inv_freq_extrapolation) * inv_freq_mask
+                + torch.log(inv_freq_interpolation) * (1.0 - inv_freq_mask)
+            )
+        elif self.extra_method == "theta_scale":
+            exponents = torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
+            theta_scale_exponent = self.base ** (
+                math.log(
+                    self.max_position_embeddings * self.scaling_factor / (2 * math.pi)
+                )
+                / math.log(self.max_position_embeddings / (2 * math.pi))
+            )
+            inv_freq = torch.tensor(
+                1.0 / (theta_scale_exponent ** (exponents / self.rotary_dim)),
+                dtype=torch.float32,
+            )
+        else:
+            raise ValueError(f"Unknown extrapolation method: {self.extra_method}")
+        return inv_freq
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        inv_freq = self._compute_inv_freq(self.scaling_factor)
+        t = torch.arange(
+            self.max_position_embeddings * self.scaling_factor, dtype=torch.float32
+        )
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        # cos = freqs.cos() * self.mscale
+        # sin = freqs.sin() * self.mscale
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class Grok1Attention(nn.Module):
     def __init__(
         self,

@@ -157,7 +345,9 @@ class Grok1Attention(nn.Module):
         rope_theta: float = 10000,
         quant_config: Optional[QuantizationConfig] = None,
         reduce_results: bool = True,
+        alt_stream: Optional[torch.cuda.Stream] = None,
         load_presharded_attn: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

@@ -183,7 +373,9 @@ class Grok1Attention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
+        rope_scaling = get_rope_scaling(config)
         self.load_presharded_attn = load_presharded_attn
+        self.alt_stream = alt_stream or torch.cuda.Stream()

         self.qkv_proj = QKVParallelLinear(
             hidden_size,

@@ -195,6 +387,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             load_presharded_attn=self.load_presharded_attn,
+            prefix=add_prefix("qkv_proj", prefix),
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,

@@ -205,6 +398,7 @@ class Grok1Attention(nn.Module):
             tp_rank=attn_tp_rank,
             tp_size=attn_tp_size,
             use_presharded_weights=self.load_presharded_attn,
+            prefix=add_prefix("o_proj", prefix),
         )
         self.rotary_emb = get_rope(
             self.head_dim,

@@ -214,7 +408,37 @@ class Grok1Attention(nn.Module):
             is_neox_style=True,
         )
+        self.rope_rotate_half_dims = getattr(config, "rope_rotate_half_dims", False)
+
+        if rope_scaling is not None:
+            self.rotary_emb = ScalingRotaryEmbedding(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                base=int(self.rope_theta),
+                is_neox_style=True,
+                **rope_scaling,
+            )
+            pos_encoding_mode = "NONE"
+        else:
+            self.rotary_emb = get_rope(
+                self.head_dim,
+                rotary_dim=(
+                    self.head_dim
+                    if not self.rope_rotate_half_dims
+                    else self.head_dim // 2
+                ),
+                max_position=max_position,
+                base=int(self.rope_theta),
+                is_neox_style=True,
+            )
+            pos_encoding_mode = "NONE"

         logit_cap = max(getattr(config, "attn_logit_softcapping", 30.0), 0.0)
+        logit_capping_method = getattr(config, "attn_logit_softcapping_method", "tanh")

         self.attn = RadixAttention(
             self.num_heads,

@@ -224,7 +448,11 @@ class Grok1Attention(nn.Module):
             layer_id=layer_id,
             logit_cap=logit_cap,
             quant_config=quant_config,
+            pos_encoding_mode=pos_encoding_mode,
+            logit_capping_method=logit_capping_method,
+            prefix=add_prefix("attn", prefix),
         )
+        self.attn.xai_temperature_len = getattr(self.config, "attn_temperature_len", -1)

     def forward(
         self,

@@ -256,6 +484,8 @@ class Grok1Attention(nn.Module):
         )
         qkv, _ = self.qkv_proj(hidden_states)
+        dispose_tensor(hidden_states)
+
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)

@@ -288,6 +518,7 @@ class Grok1Attention(nn.Module):
         )
         attn_output = self.attn(q, k, v, forward_batch)
+        del q, k, v, qkv
         if debug_tensor_dump_output_folder:
             dump_to_file(

@@ -312,49 +543,89 @@ class Grok1DecoderLayer(nn.Module):
         load_presharded_moe: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        alt_stream: Optional[torch.cuda.Stream] = None,
+        skip_moe: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.num_experts = config.num_local_experts
         self.hidden_size = config.hidden_size
+        self.residual_moe = getattr(config, "residual_moe", False)

         self.layer_id = layer_id
+        self.alt_stream = alt_stream or torch.cuda.Stream()

         rope_theta = getattr(config, "rope_theta", 10000)
         self.self_attn = Grok1Attention(
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
-            max_position=config.max_position_embeddings,
+            max_position=(
+                config.context_len
+                if hasattr(config, "context_len")
+                else config.max_position_embeddings
+            ),
             num_kv_heads=config.num_key_value_heads,
             layer_id=layer_id,
             rope_theta=rope_theta,
             quant_config=quant_config,
             reduce_results=False,
+            alt_stream=self.alt_stream,
             load_presharded_attn=load_presharded_attn,
+            prefix=add_prefix("attn", prefix),
         )
-        self.block_sparse_moe = Grok1MoE(
-            config=config,
-            layer_id=layer_id,
-            num_experts=config.num_local_experts,
-            top_k=config.num_experts_per_tok,
-            hidden_size=config.hidden_size,
-            intermediate_size=getattr(
-                config,
-                "moe_intermediate_size",
-                getattr(config, "intermediate_size", None),
-            ),
-            quant_config=quant_config,
-            reduce_results=True,
-            use_presharded_weights=load_presharded_moe,
-            inplace=True,
-            no_combine=False,  # just a suggestion to not combine topk
-        )
+
+        split_gate_up = not getattr(config, "merge_gate_up", True)
+        if self.num_experts > 0:
+            self.block_sparse_moe = Grok1MoE(
+                config=config,
+                layer_id=layer_id,
+                num_experts=config.num_local_experts,
+                top_k=config.num_experts_per_tok,
+                hidden_size=config.hidden_size,
+                intermediate_size=getattr(
+                    config,
+                    "moe_intermediate_size",
+                    getattr(config, "intermediate_size", None),
+                ),
+                quant_config=quant_config,
+                reduce_results=not self.residual_moe,
+                use_presharded_weights=load_presharded_moe,
+                inplace=False,  # not self.residual_moe,
+                no_combine=False,  # self.residual_moe, # just a suggestion to not combine topk
+                prefix=add_prefix("block_sparse_moe", prefix),
+            )
+            if self.residual_moe:
+                self.mlp = Grok1MLP(
+                    hidden_size=config.hidden_size,
+                    intermediate_size=config.intermediate_size,
+                    quant_config=quant_config,
+                    reduce_results=False,
+                    use_presharded_weights=load_presharded_mlp,
+                    layer_id=layer_id,
+                    split_gate_up=split_gate_up,
+                )
+        else:
+            raise NotImplementedError()

         self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pre_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.post_moe_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

-        self.ffn = self.block_sparse_moe
+        if self.num_experts > 0:
+            if self.residual_moe:
+                # NOTE: self.block_sparse_moe modifies the input in-place,
+                # so we have to call it later. Be aware of any possible related errors.
+                if get_tensor_model_parallel_world_size() > 1:
+                    self.ffn = lambda x: tensor_model_parallel_all_reduce(
+                        self.moe_with_rmoe(x)
+                    )
+                else:
+                    self.ffn = self.moe_with_rmoe
+            else:
+                self.ffn = self.block_sparse_moe
+        else:
+            raise NotImplementedError()

@@ -364,6 +635,10 @@ class Grok1DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor] = None,
         deferred_norm: Optional[RMSNorm] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, RMSNorm]:
+        hidden_states_original = hidden_states
+        residual_original = residual
+
         # Self Attention
         if deferred_norm is not None:
             assert residual is not None

@@ -386,6 +661,14 @@ class Grok1DecoderLayer(nn.Module):
                 hidden_states,
             )

+        if residual_original is not None:
+            dispose_tensor(residual_original)
+
+        dispose_flag = False
+        if residual is not hidden_states_original:
+            dispose_flag = True
+            dispose_tensor(hidden_states_original)
+
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,

@@ -403,10 +686,23 @@ class Grok1DecoderLayer(nn.Module):
             self.post_attn_norm.variance_epsilon,
         )

+        if not dispose_flag:
+            dispose_tensor(hidden_states_original)
+
         # Fully Connected
         hidden_states = self.ffn(hidden_states)
         return hidden_states, residual, self.post_moe_norm  # defer layernorm

+    def moe_with_rmoe(self, x):
+        current_stream = torch.cuda.current_stream()
+        self.alt_stream.wait_stream(current_stream)
+        mlp_result = self.mlp(x)
+        with torch.cuda.stream(self.alt_stream):
+            # moe should not be inplace because of stream race condition
+            moe_result = self.block_sparse_moe(x)
+        current_stream.wait_stream(self.alt_stream)
+        return (mlp_result + moe_result) / 1.4142135623730951
+

 class Grok1Model(nn.Module):
     def __init__(

@@ -417,6 +713,8 @@ class Grok1Model(nn.Module):
         load_presharded_embedding: bool = False,
         load_presharded_attn: bool = False,
         load_presharded_mlp: bool = False,
+        replicate_embedding: bool = False,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

@@ -427,7 +725,11 @@ class Grok1Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
             use_presharded_weights=load_presharded_embedding,
+            enable_tp=not replicate_embedding,
+            prefix=add_prefix("embed_tokens", prefix),
         )
+
+        self.alt_stream = torch.cuda.Stream()
         self.layers = nn.ModuleList(
             [
                 Grok1DecoderLayer(

@@ -437,6 +739,7 @@ class Grok1Model(nn.Module):
                     load_presharded_moe=load_presharded_moe,
                     load_presharded_attn=load_presharded_attn,
                     load_presharded_mlp=load_presharded_mlp,
+                    alt_stream=self.alt_stream,
                 )
                 for i in range(config.num_hidden_layers)
             ]

@@ -506,6 +809,7 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
     ) -> None:
         super().__init__()
         self.config = config

@@ -514,7 +818,8 @@ class Grok1ForCausalLM(nn.Module):
         # Get presharded weights.
         self.load_presharded_mlp = getattr(config, "load_presharded_mlp", False)
         self.load_presharded_moe = (
-            self.config.num_local_experts > 0
+            getattr(config, "load_presharded_moe", True)
+            and self.config.num_local_experts > 0
             and get_tensor_model_parallel_world_size() > 1
         )
         self.load_presharded_attn = getattr(config, "load_presharded_attn", False)

@@ -529,6 +834,11 @@ class Grok1ForCausalLM(nn.Module):
             or self.load_presharded_embedding
         )

+        default_replicate_lm_head = False
+        self.replicate_lm_head = getattr(
+            config, "replicate_lm_head", default_replicate_lm_head
+        )
+
         if self.is_weights_presharded:
             setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)

@@ -536,6 +846,7 @@ class Grok1ForCausalLM(nn.Module):
         self.replicate_lm_head = getattr(
             config, "replicate_lm_head", default_replicate_lm_head
         )
+        self.replicate_embedding = getattr(config, "replicate_embedding", False)

         self.model = Grok1Model(
             config,

@@ -544,6 +855,8 @@ class Grok1ForCausalLM(nn.Module):
             load_presharded_embedding=self.load_presharded_embedding,
             load_presharded_attn=self.load_presharded_attn,
             load_presharded_mlp=self.load_presharded_mlp,
+            replicate_embedding=self.replicate_embedding,
+            prefix=add_prefix("model", prefix),
         )

         lm_head_params_dtype = None

@@ -553,6 +866,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.vocab_size,
                 bias=False,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
         else:

@@ -561,6 +875,7 @@ class Grok1ForCausalLM(nn.Module):
                 config.hidden_size,
                 use_presharded_weights=self.load_presharded_embedding,
                 params_dtype=lm_head_params_dtype,
+                prefix=add_prefix("lm_head", prefix),
             )
             self.logits_processor = LogitsProcessor(config)

@@ -577,6 +892,7 @@ class Grok1ForCausalLM(nn.Module):
             f"#parameters (analytical): {self.get_num_params_analytical() / 1e9:.2f} B, "
             f"#parameters (actual): {self.get_num_params_torch() / 1e9:.2f} B"
         )
+        self.loaded_param_names = set()

     def forward(
         self,

@@ -596,11 +912,13 @@ class Grok1ForCausalLM(nn.Module):
     def load_weights(
         self,
         weights: Iterable[Tuple[str, torch.Tensor]],
-        num_experts: Optional[int] = None,
         ignore_parent_name: bool = False,
+        check_hit_names: bool = True,
+        model_config: PretrainedConfig | None = None,
     ) -> dict[str, torch.Tensor]:
-        if num_experts is None:
-            num_experts = self.config.num_local_experts
+        if model_config is None:
+            model_config = self.config
+
         stacked_params_mapping = []
         stacked_params_mapping += [
             # (param_name, shard_name, shard_id)

@@ -616,6 +934,7 @@ class Grok1ForCausalLM(nn.Module):
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
+        num_experts = model_config.num_local_experts
         expert_params_mapping = FusedMoE.make_expert_params_mapping(
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",

@@ -630,23 +949,26 @@ class Grok1ForCausalLM(nn.Module):
         def load_weight_wrapper(
             name: str, loaded_weight: torch.Tensor, *args, **kwargs
         ):
-            if ignore_parent_name:
-                name = name.split(".")[-1]
-
-            if name not in params_dict:
-                return
-
             # Fuse constant multipliers into the weights
             if "lm_head" in name:
                 loaded_weight = (
                     loaded_weight.to(torch.float32)
-                    * self.config.output_multiplier_scale
+                    * model_config.output_multiplier_scale
                 )

+            original_name = name
+            if ignore_parent_name:
+                name = name.split(".")[-1]
+
+            if name not in params_dict:
+                logger.info(f"Skipping {name=} in load_weights_wrapper")
+                return
+
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight, *args, **kwargs)
             hit_names.add(name)
+            self.loaded_param_names.add(original_name)

         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:

@@ -685,19 +1007,22 @@ class Grok1ForCausalLM(nn.Module):
                 load_weight_wrapper(name=name, loaded_weight=loaded_weight)

-        if len(hit_names) > 5:
-            missing = all_names - hit_names
-            missing_exclude_scales = {x for x in missing if "scale" not in x}
-            logger.info(
-                f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
-            )
-            if len(missing_exclude_scales) > 0:
-                raise ValueError(
-                    f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
-                )
-        elif len(hit_names) == 0:
-            raise ValueError("load_weights failed because it did not hit any names.")
+        if check_hit_names:
+            missing = all_names - hit_names
+            if len(hit_names) > 5:
+                missing_exclude_scales = {x for x in missing if "scale" not in x}
+                logger.info(
+                    f"#all_names: {len(all_names)}, #hit_names: {len(hit_names)}, #missing_exclude_scales: {len(missing_exclude_scales)}",
+                )
+                if len(missing_exclude_scales) > 0:
+                    raise ValueError(
+                        f"load_weights failed because some weights are missing: {missing_exclude_scales=}."
+                    )
+            elif len(hit_names) == 0:
+                raise ValueError(
+                    f"load_weights failed because it did not hit any names. {all_names=} {hit_names=}"
+                )

         return hit_names

@@ -708,7 +1033,11 @@ class Grok1ForCausalLM(nn.Module):
             "moe_intermediate_size",
             getattr(cfg, "intermediate_size", None),
         )
-        num_experts = cfg.num_local_experts
+        residual_moe = getattr(cfg, "residual_moe", False)
+        if cfg.num_local_experts > 0:
+            num_experts = cfg.num_local_experts + (1 if residual_moe else 0)
+        else:
+            num_experts = 1
         wq = (
             cfg.num_hidden_layers
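One detail worth calling out from the new moe_with_rmoe path above: the dense-MLP and MoE branches run on separate CUDA streams and are then averaged with a 1/sqrt(2) factor. A tiny check (not part of the diff) of the constant and the intuition behind it:

import math

# 1.4142135623730951 in moe_with_rmoe is sqrt(2); dividing (mlp + moe) by sqrt(2)
# keeps the variance of the combined output comparable to a single branch when the
# two branch outputs are roughly uncorrelated.
assert abs(1.4142135623730951 - math.sqrt(2)) < 1e-15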
python/sglang/srt/tokenizer/tiktoken_tokenizer.py (new file, 0 → 100644)

import functools
import json
from typing import AbstractSet, Collection, List, Literal, Union


class TiktokenProcessor:
    def __init__(self, name: str):
        self.tokenizer = TiktokenTokenizer(name)

    def image_processor(self, image):
        return {"pixel_values": [image]}


RESERVED_TOKEN_TEXTS = [f"<|reserved_{i}|>" for i in range(3, 128)]
CONTROL_TOKEN_TEXTS = [f"<|control{i}|>" for i in range(1, 705)]

PAD = "<|pad|>"
EOS = "<|eos|>"
SEP = "<|separator|>"

DEFAULT_SPECIAL_TOKENS = [PAD, SEP, EOS]
DEFAULT_CONTROL_TOKENS = {"pad": PAD, "sep": EOS, "eos": SEP}

# default + separate each single digit
PAT_STR_B = r"""(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+"""


class TiktokenTokenizer:
    def __init__(self, tokenizer_path):
        import tiktoken
        from jinja2 import Template

        # Read the JSON
        with open(tokenizer_path, "rb") as fin:
            xtok_dict = json.load(fin)

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::from_xtok_dict
        mergeable_ranks = {
            bytes(item["bytes"]): item["token"] for item in xtok_dict["regular_tokens"]
        }
        special_tokens = {
            bytes(item["bytes"]).decode(): item["token"]
            for item in xtok_dict["special_tokens"]
        }
        if xtok_dict["word_split"] == "V1":
            pad_str = PAT_STR_B
        else:
            assert False, f"Unknown word_split: {xtok_dict['word_split']}"
        pad_str = xtok_dict.get("pat_str", pad_str)

        kwargs = {
            "name": tokenizer_path,
            "pat_str": pad_str,
            "mergeable_ranks": mergeable_ranks,
            "special_tokens": special_tokens,
        }
        if "default_allowed_special" in xtok_dict:
            default_allowed_special = set(
                [
                    bytes(bytes_list).decode()
                    for bytes_list in xtok_dict["default_allowed_special"]
                ]
            )
        if "vocab_size" in xtok_dict:
            kwargs["explicit_n_vocab"] = xtok_dict["vocab_size"]

        # Copy from train/xlm/tokenizers/tiktoken_wrapper.py::Encoding::__init__
        default_allowed_special = None
        control_tokens = DEFAULT_CONTROL_TOKENS
        tokenizer = tiktoken.Encoding(**kwargs)
        tokenizer._default_allowed_special = default_allowed_special or set()
        tokenizer._control_tokens = control_tokens

        def encode_patched(
            self,
            text: str,
            *,
            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
        ) -> List[int]:
            if isinstance(allowed_special, set):
                allowed_special |= self._default_allowed_special
            return tiktoken.Encoding.encode(
                self,
                text,
                allowed_special=allowed_special,
                disallowed_special=(),
            )

        tokenizer.encode = functools.partial(encode_patched, tokenizer)

        # Allow more tokens to prevent crash
        tokenizer._default_allowed_special |= set(DEFAULT_CONTROL_TOKENS.values())
        tokenizer._default_allowed_special |= set(
            CONTROL_TOKEN_TEXTS + RESERVED_TOKEN_TEXTS
        )

        # Convert to HF interface
        self.tokenizer = tokenizer
        self.bos_token_id = None
        self.eos_token_id = tokenizer._special_tokens[EOS]
        self.vocab_size = tokenizer.n_vocab
        self.chat_template = "{% for message in messages %}{% if message['role'] == 'user' %}{{ 'Human: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'system' %}{{ 'System: ' + message['content'].strip() + '<|separator|>\n\n' }}{% elif message['role'] == 'assistant' %}{{ 'Assistant: ' + message['content'] + '<|separator|>\n\n' }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ 'Assistant:' }}{% endif %}"
        self.chat_template_jinja = Template(self.chat_template)
        self.additional_stop_token_ids = None

    def encode(self, x, add_special_tokens=False):
        return self.tokenizer.encode(x)

    def decode(self, x, *args, **kwargs):
        return self.tokenizer.decode(x)

    def batch_decode(
        self, batch, skip_special_tokens=True, spaces_between_special_tokens=False
    ):
        if len(batch) > 0 and isinstance(batch[0], int):
            batch = [[x] for x in batch]
        return self.tokenizer.decode_batch(batch)

    def apply_chat_template(self, messages, tokenize, add_generation_prompt, tools=None):
        ret = self.chat_template_jinja.render(
            messages=messages, add_generation_prompt=add_generation_prompt
        )
        return self.encode(ret) if tokenize else ret

    def __call__(self, text, **kwargs):
        return {
            "input_ids": self.encode(text),
        }

    def init_xgrammar(self):
        from xgrammar import TokenizerInfo

        XGRAMMAR_SPECIAL_TOKEN_TEMPLATE = "<|xg_special_token_{}|>"

        enc = self.tokenizer
        encoded_vocab = {**enc._mergeable_ranks, **enc._special_tokens}
        encoded_vocab = [
            token for token, _ in sorted(encoded_vocab.items(), key=lambda x: x[1])
        ]
        override_stop_tokens = [2]  # eos

        # These are treated as special tokens in xgrammar; we want to avoid them
        # For now, xgrammar treats anything starting with b'\x00' as a special token
        xgrammar_special_token_ids = []
        for i, token in enumerate(encoded_vocab):
            if isinstance(token, bytes) and token.startswith(b"\x00"):
                xgrammar_special_token_ids.append(i)

        for i, id in enumerate(xgrammar_special_token_ids):
            encoded_vocab[id] = XGRAMMAR_SPECIAL_TOKEN_TEMPLATE.format(i)
        tokenizer_info = TokenizerInfo(
            encoded_vocab, stop_token_ids=override_stop_tokens
        )
        assert len(tokenizer_info.special_token_ids) == 0

        return tokenizer_info, override_stop_tokens
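A hypothetical usage sketch tying the pieces together (the tokenizer file path is made up): with the hf_transformers_utils.py change earlier in this commit, get_tokenizer() routes any ".json" tokenizer path to this TiktokenTokenizer.

from sglang.srt.hf_transformers_utils import get_tokenizer

tok = get_tokenizer("/path/to/grok_tokenizer.json")   # returns a TiktokenTokenizer
ids = tok.encode("Hello world")
text = tok.decode(ids)
prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Hi"}], tokenize=False, add_generation_prompt=True
)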