Project: change / sglang · Commit c1d2061f

Unverified commit c1d2061f, authored Aug 05, 2025 by Ying Sheng, committed via GitHub on Aug 05, 2025.

    Add initial support for gpt-oss (#8824)

Parent: 556e4143

Showing 12 changed files with 1595 additions and 47 deletions (+1595, -47):
  python/sglang/srt/layers/attention/triton_backend.py                 +85   -14
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py    +17    -0
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py    +36    -8
  python/sglang/srt/layers/linear.py                                    +0    -5
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py              +134    -8
  python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +178    -3
  python/sglang/srt/layers/quantization/fp8_utils.py                   +29    -0
  python/sglang/srt/layers/quantization/mxfp4_tensor.py               +133    -0
  python/sglang/srt/layers/quantization/unquant.py                     +52    -7
  python/sglang/srt/managers/schedule_batch.py                          +4    -2
  python/sglang/srt/models/gpt_oss.py                                 +923    -0
  python/sglang/srt/server_args.py                                      +4    -0
python/sglang/srt/layers/attention/triton_backend.py

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
         self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator

         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(

@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
             window_num_kv_splits = torch.empty(

@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
-            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,

@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.seq_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
             custom_mask = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (
                 forward_batch.seq_lens + self.num_draft_tokens

@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
             qo_indptr = self.qo_indptr

@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
            ):
                window_kv_indices = self.cuda_graph_window_kv_indices
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
-                    self.window_kv_indptr,
-                    window_kv_indices,
-                    self.req_to_token,
-                    self.sliding_window_size,
-                    seq_lens[:bs],
-                    req_pool_indices,
-                    bs,
-                )
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens[:bs],
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                self.req_to_token.stride(0),
            )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens,
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
            ):
                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                window_kv_indices = self.cuda_graph_window_kv_indices
-                _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                    self.window_kv_indptr,
                    window_kv_indices,
                    self.req_to_token,

@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                    seq_lens[:bs],
                    req_pool_indices[:bs],
                    bs,
+                    self.token_to_kv_pool_allocator,
                )
                self.get_num_kv_splits(
                    window_num_kv_splits[:num_token], window_kv_lens[:bs]

@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                kv_indices,
                self.req_to_token.stride(0),
            )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    self.window_kv_indptr,
+                    window_kv_indices,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    seq_lens,
+                    req_pool_indices,
+                    bs,
+                    self.token_to_kv_pool_allocator,
+                )
            custom_mask = self.cuda_graph_custom_mask
            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
+        sk=None,
    ):
        # TODO: reuse the buffer across layers
        if layer.qk_head_dim != layer.v_head_dim:

@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
            self.forward_metadata.max_extend_len,
            layer.scaling,
            layer.logit_cap,
-            sliding_window_size,
+            sliding_window_size=sliding_window_size,
+            sk=sk,
        )
        return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
        layer: RadixAttention,
        forward_batch: ForwardBatch,
        save_kv_cache=True,
+        sk=None,
    ):
        # During torch.compile, there is a bug in rotary_emb that causes the
        # output value to have a 3D tensor shape. This reshapes the output correctly.

@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
            self.max_kv_splits,
            layer.scaling,
            layer.logit_cap,
+            sk=sk,
        )
        return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
    req_pool_indices,
    bs,
    device,
+    token_to_kv_pool_allocator=None,
):
    window_kv_lens = torch.minimum(
        seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
        window_kv_indices,
        req_to_token.stride(0),
    )
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
    return window_kv_indptr, window_kv_indices, window_kv_lens

@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
    seq_lens,
    req_pool_indices,
    bs,
+    token_to_kv_pool_allocator=None,
):
    window_kv_lens = torch.minimum(
        seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
    window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
    window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
        window_kv_indices,
        req_to_token.stride(0),
    )
-    return window_kv_indptr, window_kv_lens
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
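
Note on the sliding-window changes above: the backend now threads `token_to_kv_pool_allocator` into `update_sliding_window_buffer*` so that, with hybrid full/SWA KV pools, window indices can be remapped from full-pool slots to SWA-pool slots, and the window length is clamped to `sliding_window_size` rather than `sliding_window_size + 1`. A minimal sketch of that host-side bookkeeping, assuming plain PyTorch tensors and an allocator exposing `translate_loc_from_full_to_swa` (in the real backend the indices are gathered from `req_to_token` by a Triton kernel):

```python
import torch

def sliding_window_bookkeeping(seq_lens, full_kv_indices, sliding_window_size, allocator=None):
    """Sketch of what update_sliding_window_buffer computes per batch.

    seq_lens: (bs,) int64 sequence lengths.
    full_kv_indices: flattened full-pool KV slot ids, already restricted to the
                     last `sliding_window_size` tokens of every request.
    """
    # Each request keeps at most `sliding_window_size` KV entries.
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(sliding_window_size))

    # CSR-style offsets into the flattened window indices.
    bs = seq_lens.numel()
    window_kv_indptr = torch.zeros(bs + 1, dtype=torch.int64)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)

    window_kv_indices = full_kv_indices.clone()
    # Hybrid allocators keep a separate, smaller SWA pool; remap slot ids into it.
    if allocator is not None and hasattr(allocator, "translate_loc_from_full_to_swa"):
        last = window_kv_indptr[-1]
        window_kv_indices[:last] = allocator.translate_loc_from_full_to_swa(
            window_kv_indices[:last]
        )
    return window_kv_indptr, window_kv_indices, window_kv_lens
```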
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
    O,
    kv_indptr,
    num_kv_splits,
+    sk_ptr,
    stride_mid_ob,
    stride_mid_oh,
    stride_mid_os,

@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
    MIN_BLOCK_KV: tl.constexpr,
    BLOCK_DV: tl.constexpr,
    Lv: tl.constexpr,
+    HAS_SK: tl.constexpr,
):
    cur_batch = tl.program_id(0)
    cur_head = tl.program_id(1)

@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
            e_sum = e_sum * old_scale + exp_logic
            e_max = n_e_max

+    if HAS_SK:
+        cur_sk = tl.load(sk_ptr + cur_head)
+        e_sum += tl.exp(cur_sk - e_max)
+
    tl.store(
        O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
        acc / e_sum,

@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
    kv_indptr,
    num_kv_splits,
    max_kv_splits,
+    sk=None,
):
    batch, head_num = q.shape[0], q.shape[1]
    Lv = v_buffer.shape[-1]
    BLOCK_DV = triton.next_power_of_2(Lv)

    MAX_KV_SPLITS = max_kv_splits
+    HAS_SK = sk is not None

    extra_kargs = {}
    if _is_hip:

@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
        o,
        kv_indptr,
        num_kv_splits,
+        sk,
        logits.stride(0),
        logits.stride(1),
        logits.stride(2),

@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
        MIN_BLOCK_KV=_MIN_BLOCK_KV,
        BLOCK_DV=BLOCK_DV,
        Lv=Lv,
+        HAS_SK=HAS_SK,
        num_warps=4,
        num_stages=2,
        **extra_kargs,

@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sk=None,
):
    _decode_att_m_fwd(
        q,

@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+        sk,
    )

@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sk=None,
):
    _decode_grouped_att_m_fwd(
        q,

@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
        kv_indptr,
        num_kv_splits,
        max_kv_splits,
+        sk,
    )

@@ -687,6 +701,7 @@ def decode_attention_fwd(
    max_kv_splits,
    sm_scale,
    logit_cap=0.0,
+    sk=None,
):
    assert max_kv_splits == attn_logits.shape[2]
    assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -709,6 +724,7 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+            sk=sk,
        )
    else:
        # GQA/MQA/MLA

@@ -725,4 +741,5 @@ def decode_attention_fwd(
            max_kv_splits,
            sm_scale,
            logit_cap=logit_cap,
+            sk=sk,
        )
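
Note on the new `sk` argument: it carries one learned "attention sink" logit per head into the split-KV reduction. After the partial results are merged, `exp(sk - e_max)` is added to the softmax denominator only, so every real attention weight is scaled down and the leftover probability mass goes to a sink that contributes no value vector. A plain PyTorch sketch of the same adjustment for a single query token (illustrative names, not the kernel's):

```python
import torch

def softmax_with_sink(qk: torch.Tensor, sk: torch.Tensor) -> torch.Tensor:
    """qk: (num_heads, kv_len) attention logits for one query token.
    sk: (num_heads,) per-head sink logit (the `sk` tensor threaded through the kernels).
    """
    e_max = qk.max(dim=-1).values                      # running max kept by the kernel
    p = torch.exp(qk - e_max[:, None])
    e_sum = p.sum(dim=-1) + torch.exp(sk - e_max)      # sink enters the denominator only
    return p / e_sum[:, None]

# usage: attn = softmax_with_sink(qk, sk); out = attn @ v   # v: (kv_len, head_dim)
```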
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,6 +51,7 @@ def _fwd_kernel(
    kv_indices,
    mask_ptr,
    mask_indptr,
+    sk_ptr,
    sm_scale,
    kv_group_num,
    stride_qbs,

@@ -78,6 +79,7 @@ def _fwd_kernel(
    IS_CAUSAL: tl.constexpr,
    SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
    STORE_TRANSPOSE: tl.constexpr,
+    HAS_SK: tl.constexpr,
):
    cur_seq = tl.program_id(0)
    cur_head = tl.program_id(1)

@@ -178,13 +180,17 @@ def _fwd_kernel(
                final_mask &= custom_mask
            if SLIDING_WINDOW_SIZE > 0:
                # Add mask where q_id <= kv_id + sliding_window_size
-                window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                    start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-                )
+                # q_id = prefix_len + cur_m, kv_id = cur_n
+                window_mask = (
+                    cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+                ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
                final_mask &= window_mask
            qk = tl.where(final_mask, qk, float("-inf"))

-            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
            re_scale = tl.exp(e_max - n_e_max)
            p = tl.exp(qk - n_e_max[:, None])
            deno = deno * re_scale + tl.sum(p, 1)

@@ -242,6 +248,7 @@ def _fwd_kernel(
        if logit_cap > 0:
            qk = logit_cap * tanh(qk / logit_cap)

+        final_mask = mask_m[:, None] & mask_n[None, :]
        if USE_CUSTOM_MASK:
            custom_mask = tl.load(
                mask_ptr

@@ -254,18 +261,30 @@ def _fwd_kernel(
                other=0,
            )
            custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
+            final_mask &= custom_mask
        elif IS_CAUSAL:
            mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                start_n + offs_n[None, :]
            )
            mask_causual &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_causual, qk, float("-inf"))
+            final_mask &= mask_causual
        else:
            mask_non_causal = mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_non_causal, qk, float("-inf"))
-
-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+            final_mask &= mask_non_causal
+
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask
+
+        qk = tl.where(final_mask, qk, float("-inf"))
+        row_max = tl.max(qk, 1)
+        row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+        n_e_max = tl.maximum(row_max_fixed, e_max)
        re_scale = tl.exp(e_max - n_e_max)
        p = tl.exp(qk - n_e_max[:, None])
        deno = deno * re_scale + tl.sum(p, 1)

@@ -283,6 +302,10 @@ def _fwd_kernel(
        e_max = n_e_max

+    if HAS_SK:
+        cur_sk = tl.load(sk_ptr + cur_head)
+        deno += tl.exp(cur_sk - e_max)
+
    offs_o = (
        (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
        * stride_obs

@@ -321,6 +344,7 @@ def extend_attention_fwd(
    logit_cap=0.0,
    skip_prefix_custom_mask=True,
    sliding_window_size=-1,
+    sk=None,
):
    """
    q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -386,6 +410,8 @@ def extend_attention_fwd(
    # Skip custom mask for prefix part
    SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+    HAS_SK = sk is not None
+
    grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
    num_stages = 1

@@ -405,6 +431,7 @@ def extend_attention_fwd(
        kv_indices,
        custom_mask,
        mask_indptr,
+        sk,
        sm_scale,
        kv_group_num,
        q_extend.stride(0),

@@ -431,6 +458,7 @@ def extend_attention_fwd(
        USE_CUSTOM_MASK=USE_CUSTOM_MASK,
        IS_CAUSAL=is_causal,
        SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        HAS_SK=HAS_SK,
        STORE_TRANSPOSE=_is_hip,
        num_warps=num_warps,
        num_stages=num_stages,
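
Note on the `row_max_fixed` change in `_fwd_kernel`: once the sliding-window mask is combined with the causal/custom masks, a whole tile of `qk` can be `-inf`, and the previous `n_e_max = tl.maximum(tl.max(qk, 1), e_max)` would then propagate `-inf` and turn `exp(qk - n_e_max)` into NaN. Clamping the row max to `-1e20` keeps the online-softmax update finite. A small PyTorch sketch of the guarded update step (illustrative, not the Triton code):

```python
import torch

def online_softmax_step(qk, e_max, deno, acc, v):
    """One flash-attention-style update over a score tile.

    qk: (M, N) masked scores (-inf where masked); v: (N, D) values.
    e_max, deno: running row max and denominator, shape (M,); acc: (M, D) accumulator.
    """
    row_max = qk.max(dim=1).values
    # Fully masked rows give -inf; clamp so we never compute exp(-inf - (-inf)).
    row_max_fixed = torch.where(row_max == float("-inf"),
                                torch.full_like(row_max, -1e20), row_max)
    n_e_max = torch.maximum(row_max_fixed, e_max)

    re_scale = torch.exp(e_max - n_e_max)
    p = torch.exp(qk - n_e_max[:, None])        # masked entries become exactly 0
    deno = deno * re_scale + p.sum(dim=1)
    acc = acc * re_scale[:, None] + p @ v
    return n_e_max, deno, acc
```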
python/sglang/srt/layers/linear.py

@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                else self.weight_loader
            ),
        )

-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError(
-                "When not reduce the results, adding bias to the "
-                "results can lead to incorrect results"
-            )
        if bias:
            self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
        no_combine: bool = False,
        routed_scaling_factor: Optional[float] = None,
        enable_flashinfer_cutlass_moe: Optional[bool] = False,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        use_weight_loader_fused: bool = False,
+        with_bias=False,
    ):
        super().__init__()

@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
        self.expert_map_cpu = None
        self.expert_map_gpu = None

+        # For activation
+        self.activation_alpha = activation_alpha
+        self.swiglu_limit = swiglu_limit
+
        if enable_flashinfer_cutlass_moe and quant_config is None:
            logger.warning("Disable flashinfer MoE when quantization config is None.")
            enable_flashinfer_cutlass_moe = False

@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
+                self.use_triton_kernels, with_bias=with_bias
            )
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)

@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
            intermediate_size=self.intermediate_size_per_partition,
            intermediate_size_per_partition=self.intermediate_size_per_partition,
            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
+            weight_loader=(
+                self.weight_loader
+                if not use_weight_loader_fused
+                else self.weight_loader_fused
+            ),
+            with_bias=with_bias,
        )

    def _load_per_tensor_weight_scale(

@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+        is_bias: bool = False,
    ):
        # Load grouped weight scales for group quantization
        # or model weights

@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
+                is_bias=is_bias,
            )
-        elif shard_id in ("w1", "w3"):
+        elif shard_id in ("w1", "w3", "w13"):
            self._load_w13(
                shard_id=shard_id,
                shard_dim=shard_dim,
                loaded_weight=loaded_weight,
                expert_data=expert_data,
                tp_rank=tp_rank,
+                is_bias=is_bias,
            )

    def _load_per_channel_weight_scale(

@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+        is_bias: bool = False,
    ):
        # Index the loaded weight for tp sharding.
        # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
-        shard_size = expert_data.shape[shard_dim] // 2
+        assert shard_id in {"w1", "w3", "w13"}
+        if is_bias:
+            # if this weight is a bias, the last dimension must be the sharded dimension
+            shard_dim = -1
+
+        if shard_id in {"w1", "w3"}:
+            # non-fused version
+            shard_size = expert_data.shape[shard_dim] // 2
+        elif shard_id in {"w13"}:
+            # fused version
+            shard_size = expert_data.shape[shard_dim]
+        else:
+            raise NotImplementedError

        # Narrow parameter and load.
        # w1, gate_proj: Load into first logical weight of w13.
        # w3, up_proj: Load into second logical weight of w13.
        # trtllm cutlass kernel assumes differently
-        assert shard_id in ("w1", "w3")
        switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
        if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
            start = shard_size

@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
            )
        else:
            if not self.use_presharded_weights:
-                if self.use_triton_kernels:
+                if not is_bias and self.use_triton_kernels:
+                    # do not transpose for bias
                    loaded_weight = loaded_weight.transpose(-2, -1)
                loaded_weight = loaded_weight.narrow(
                    shard_dim, shard_size * tp_rank, shard_size

@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
        shard_id: str,
        loaded_weight: torch.Tensor,
        tp_rank: int,
+        is_bias: bool = False,
    ):
        """Load w2 weights for down projection.

@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
        # Index the loaded weight for tp sharding.
        # down_proj: "RowParallel" so tp sharding on input_dim
        # Narrow parameter and load.
-        shard_size = expert_data.shape[shard_dim]
+        if is_bias:
+            # this expert_data is a bias, not weight,
+            # for w2_bias in TP, it does not need to be sharded
+            shard_size = expert_data.shape[-1]
+        else:
+            # this parameter is a weight matrix
+            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
+            shard_size = expert_data.shape[shard_dim]

        if _is_cpu:
            expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(

@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
                not self.use_presharded_weights,
            )
        else:
-            if not self.use_presharded_weights:
+            if not is_bias and not self.use_presharded_weights:
                if self.use_triton_kernels:
                    loaded_weight = loaded_weight.transpose(-2, -1)
                if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:

@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
            )
            return

+    def weight_loader_fused(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+    ) -> None:
+        tp_rank = self.moe_tp_rank
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO: check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w13", "w2"):
+            raise ValueError(
+                f"shard_id must be ['w13','w2'] but " f"got {shard_id}."
+            )
+
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+        expert_data = param.data
+        is_bias = expert_data.dim() == 2
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+        if self.use_triton_kernels:
+            is_transposed = True
+        shard_dim = (
+            SHARD_ID_TO_SHARDED_DIM[shard_id]
+            if not is_transposed
+            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+        )
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+                is_bias=is_bias,
+            )
+            return
+        else:
+            logging.warning(
+                f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+            )
+
    def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
        assert self.quant_method is not None

@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
        # Matrix multiply.
        with use_symmetric_memory(get_tp_group()) as sm:
+            kwargs = {}
+            if self.activation_alpha is not None:
+                kwargs["activation_alpha"] = self.activation_alpha
+            if self.swiglu_limit is not None:
+                kwargs["swiglu_limit"] = self.swiglu_limit
+
            final_hidden_states = self.quant_method.apply(
                layer=self,
                x=hidden_states,

@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
                    == "ModelOptNvFp4FusedMoEMethod"
                    else {}
                ),
+                **kwargs,
            )
            sm.tag(final_hidden_states)

@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
            ]
        ]

+    @classmethod
+    def make_expert_params_mapping_fused(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+        ]
+
    @classmethod
    def make_expert_input_scale_params_mapping(
        cls,
View file @
c1d2061f
...
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
...
@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
import
torch
import
torch
from
sgl_kernel
import
gelu_and_mul
,
silu_and_mul
from
sgl_kernel
import
gelu_and_mul
,
silu_and_mul
from
triton_kernels.matmul_ogs
import
matmul_ogs
from
triton_kernels.matmul_ogs
import
(
FlexCtx
,
FnSpecs
,
FusedActivation
,
PrecisionConfig
,
matmul_ogs
,
)
from
triton_kernels.numerics
import
InFlexData
from
triton_kernels.routing
import
GatherIndx
,
RoutingData
,
ScatterIndx
from
triton_kernels.routing
import
GatherIndx
,
RoutingData
,
ScatterIndx
from
triton_kernels.swiglu
import
swiglu_fn
from
sglang.srt.utils
import
direct_register_custom_op
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
sglang.srt.layers.moe.topk
import
TopKOutput
from
sglang.srt.layers.moe.topk
import
TopKOutput
def
quantize
(
w
,
dtype
,
dev
,
**
opt
):
if
dtype
==
"bf16"
:
return
w
.
to
(
torch
.
bfloat16
),
InFlexData
()
elif
dtype
==
"fp8"
:
wq
=
w
.
to
(
torch
.
float8_e4m3fn
).
transpose
(
-
1
,
-
2
).
contiguous
().
transpose
(
-
1
,
-
2
)
return
(
wq
,
InFlexData
(
dtype
=
wq
.
dtype
,
scale
=
w
.
abs
().
max
().
unsqueeze
(
0
)),
MicroscalingCtx
(),
)
else
:
assert
dtype
==
"mx4"
,
f
"
{
dtype
=
}
"
swizzle_mx_scale
=
opt
[
"swizzle_mx_scale"
]
swizzle_axis
=
2
if
swizzle_mx_scale
else
None
w
=
w
.
to
(
torch
.
bfloat16
)
w
,
mx_scales
,
weight_scale_shape
=
downcast_to_mxfp
(
w
,
torch
.
uint8
,
axis
=
1
,
swizzle_axis
=
swizzle_axis
)
return
(
w
,
InFlexData
(),
MicroscalingCtx
(
weight_scale
=
mx_scales
,
swizzle_mx
=
swizzle_mx_scale
,
actual_weight_scale_shape
=
weight_scale_shape
,
),
)
def
triton_kernel_moe_forward
(
def
triton_kernel_moe_forward
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
...
@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
...
@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
)
)
return
intermediate_cache3
return
intermediate_cache3
def
triton_kernel_moe_with_bias_forward
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
b1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
b2
:
torch
.
Tensor
,
topk_output
:
TopKOutput
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
activation_alpha
:
Optional
[
float
]
=
None
,
swiglu_limit
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
assert
topk_output
.
format
.
is_triton_kernel
()
routing_data
,
gather_idx
,
scatter_idx
=
topk_output
return
triton_kernel_fused_experts_with_bias
(
hidden_states
,
w1
,
b1
,
w2
,
b2
,
routing_data
,
gather_idx
,
scatter_idx
,
inplace
=
inplace
,
activation
=
activation
,
use_fp8_w8a8
=
use_fp8_w8a8
,
per_channel_quant
=
per_channel_quant
,
global_num_experts
=
global_num_experts
,
expert_map
=
expert_map
,
w1_scale
=
w1_scale
,
w2_scale
=
w2_scale
,
a1_scale
=
a1_scale
,
a2_scale
=
a2_scale
,
block_shape
=
block_shape
,
activation_alpha
=
activation_alpha
,
swiglu_limit
=
swiglu_limit
,
)
def
triton_kernel_fused_experts_with_bias
(
hidden_states
:
torch
.
Tensor
,
w1
:
torch
.
Tensor
,
b1
:
torch
.
Tensor
,
w2
:
torch
.
Tensor
,
b2
:
torch
.
Tensor
,
routing_data
:
RoutingData
,
gather_indx
:
GatherIndx
,
scatter_indx
:
ScatterIndx
,
inplace
:
bool
=
False
,
activation
:
str
=
"silu"
,
use_fp8_w8a8
:
bool
=
False
,
per_channel_quant
:
bool
=
False
,
global_num_experts
:
int
=
-
1
,
expert_map
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
block_shape
:
Optional
[
list
[
int
]]
=
None
,
activation_alpha
:
Optional
[
float
]
=
None
,
swiglu_limit
:
Optional
[
int
]
=
None
,
)
->
torch
.
Tensor
:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert
use_fp8_w8a8
==
False
,
"use_fp8_w8a8 is not supported"
assert
per_channel_quant
==
False
,
"per_channel_quant is not supported"
assert
expert_map
==
None
,
"expert_map is not supported"
assert
w1_scale
==
None
,
"w1_scale is not supported"
assert
w2_scale
==
None
,
"w2_scale is not supported"
assert
a1_scale
==
None
,
"a1_scale is not supported"
assert
a2_scale
==
None
,
"a2_scale is not supported"
assert
block_shape
==
None
,
"block_shape is not supported"
# type check
assert
hidden_states
.
dtype
==
torch
.
bfloat16
,
"hidden_states must be bfloat16"
assert
w1
.
dtype
==
torch
.
bfloat16
,
"w1 must be bfloat16"
assert
w2
.
dtype
==
torch
.
bfloat16
,
"w2 must be bfloat16"
# Shape check
assert
hidden_states
.
ndim
==
2
,
"hidden_states must be 2D"
assert
(
hidden_states
.
shape
[
-
1
]
==
w1
.
shape
[
-
2
]
),
f
"hidden_states shape[-1]
{
hidden_states
.
shape
}
must be equal to w1 shape[-2]
{
w1
.
shape
}
"
assert
(
w2
.
shape
[
-
1
]
==
w1
.
shape
[
1
]
),
f
"w2 shape[-1]
{
w2
.
shape
[
-
1
]
}
must be equal to w1 shape[1]
{
w1
.
shape
[
1
]
}
"
# feature check
assert
inplace
==
False
,
"Inplace is not supported in new triton MoE kernel"
E
,
_
,
_
=
w1
.
shape
if
global_num_experts
==
-
1
:
global_num_experts
=
E
device
=
"cuda"
optg
=
dict
()
w1
,
w1_flex
=
quantize
(
w1
,
"bf16"
,
device
,
**
optg
)
w1_pcg
=
PrecisionConfig
(
flex_ctx
=
FlexCtx
(
rhs_data
=
w1_flex
))
w2
,
w2_flex
=
quantize
(
w2
,
"bf16"
,
device
,
**
optg
)
w2_pcg
=
PrecisionConfig
(
flex_ctx
=
FlexCtx
(
rhs_data
=
w2_flex
))
act
=
FusedActivation
(
FnSpecs
(
"swiglu"
,
swiglu_fn
,
(
"alpha"
,
"limit"
)),
(
activation_alpha
,
swiglu_limit
),
2
,
)
intermediate_cache
=
matmul_ogs
(
hidden_states
,
w1
,
b1
,
routing_data
,
gather_indx
=
gather_indx
,
precision_config
=
w1_pcg
,
gammas
=
None
,
fused_activation
=
act
,
)
return
matmul_ogs
(
intermediate_cache
,
w2
,
b2
,
routing_data
,
scatter_indx
=
scatter_indx
,
precision_config
=
w2_pcg
,
gammas
=
routing_data
.
gate_scal
,
)
python/sglang/srt/layers/quantization/fp8_utils.py
View file @
c1d2061f
...
@@ -4,6 +4,7 @@ import torch
...
@@ -4,6 +4,7 @@ import torch
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization
import
deep_gemm_wrapper
from
sglang.srt.layers.quantization.fp8_kernel
import
sglang_per_token_group_quant_fp8
from
sglang.srt.layers.quantization.fp8_kernel
import
sglang_per_token_group_quant_fp8
from
sglang.srt.layers.quantization.mxfp4_tensor
import
MXFP4QuantizeUtil
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.layers.utils
import
is_sm100_supported
try
:
try
:
...
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
...
@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
)
from
sglang.srt.utils
import
(
from
sglang.srt.utils
import
(
align
,
align
,
ceil_div
,
get_bool_env_var
,
get_bool_env_var
,
get_cuda_version
,
get_cuda_version
,
get_device_capability
,
get_device_capability
,
...
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
...
@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
return
output
.
to
(
dtype
=
input_2d
.
dtype
).
view
(
*
output_shape
)
return
output
.
to
(
dtype
=
input_2d
.
dtype
).
view
(
*
output_shape
)
def
dequant_mxfp4
(
w_block
:
torch
.
Tensor
,
w_scale
:
torch
.
Tensor
,
out_dtype
,
)
->
torch
.
Tensor
:
"""
:param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
:param w_scale: (batch, n, k), uint8
:return: (batch, n, k * 32), float32
"""
assert
w_block
.
dtype
==
torch
.
uint8
assert
w_scale
.
dtype
==
torch
.
uint8
batch
,
n
,
k
,
pack_dim
=
w_block
.
shape
batch_
,
n_
,
k_
=
w_scale
.
shape
assert
pack_dim
==
16
assert
batch
==
batch_
assert
n
==
n_
assert
k
==
k_
out_raw
=
MXFP4QuantizeUtil
.
dequantize
(
quantized_data
=
w_block
,
scale
=
w_scale
,
dtype
=
out_dtype
,
block_sizes
=
[
32
]
)
return
out_raw
.
reshape
(
batch
,
n
,
k
*
32
)
def
input_to_float8
(
def
input_to_float8
(
x
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
fp8_dtype
x
:
torch
.
Tensor
,
dtype
:
torch
.
dtype
=
fp8_dtype
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
...
...
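
Note on `dequant_mxfp4`: weights arrive packed two FP4 codes per byte along a trailing dimension of 16, so each `(batch, n, k, 16)` block of bytes expands to `k * 32` values per row, with one E8M0 scale per 32-value group taken from `w_scale`. A small usage sketch with made-up shapes:

```python
import torch
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4

# Illustrative shapes: batch=1, n=4 output rows, k=2 scale groups per row.
w_block = torch.randint(0, 256, (1, 4, 2, 16), dtype=torch.uint8)  # two fp4 codes per byte
w_scale = torch.full((1, 4, 2), 127, dtype=torch.uint8)            # E8M0 exponent; 127 means 2**0

w = dequant_mxfp4(w_block, w_scale, torch.bfloat16)
print(w.shape)   # torch.Size([1, 4, 64]) -> k * 32 dequantized values per row
```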
python/sglang/srt/layers/quantization/mxfp4_tensor.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
class MXFP4QuantizeUtil:
    E2M1_max = 6.0
    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

    @classmethod
    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
        Args:
            input (torch.Tensor): The input tensor to be quantized.
            block_sizes (dict | None): The block sizes for quantization.
        """

        def cast_fp4(x):
            sign = torch.sign(x)
            sign_bit = (2 - sign) // 2
            ord_ = torch.sum(
                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
            )
            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
            return fp4_val

        def fuse_uint4_to_uint8(x):
            # If the last dimension is odd, pad with zeros
            # If this behavior is not desired, please modify the code accordingly
            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
            new_data = right_side.clone() << 4  # Put odd indices (higher addresses) in high bits
            new_data[..., : left_side.shape[-1]] += left_side  # Put even indices in low bits
            return new_data

        if block_size is None:
            block_size = 32

        original_shape = input.shape
        original_dtype = input.dtype
        input = input.view(-1, block_size)
        # get scales
        input_amax = input.abs().max(dim=-1, keepdim=True).values
        descale = input_amax / cls.E2M1_max
        min_value = torch.tensor(-127.0, device=descale.device)
        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))

        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
        input_q = cast_fp4(input)
        input_q = fuse_uint4_to_uint8(input_q)
        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
        return cls(original_shape, original_dtype, input_q), e8m0_scale

    @classmethod
    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
        """Dequantze MXFP4 packed tensor to a target dtype."""

        def unfuse_uint8_to_uint4(x):
            """Unfuse uint8 values back to uint4 values.
            This is the inverse operation of fuse_uint4_to_uint8.
            """
            # Extract the lower 4 bits (even indices)
            left_side = x & 0x0F
            # Extract the upper 4 bits (odd indices)
            right_side = (x >> 4) & 0x0F
            # Create a new tensor with alternating values
            shape = list(x.shape)
            shape[-1] = shape[-1] * 2
            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
            # Fill in the values - even indices get low bits, odd indices get high bits
            result[..., 0::2] = left_side  # Even indices from low bits
            result[..., 1::2] = right_side  # Odd indices from high bits
            return result

        e8m0_scale = scale
        block_size = block_sizes[-1]

        # Unfuse the uint8 values back to uint4
        x_unfused = unfuse_uint8_to_uint4(quantized_data)
        # Extract sign and magnitude
        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(torch.float32)  # Extract sign bit and convert to +1/-1
        magnitude = x_unfused & 0b0111  # Extract magnitude bits
        magnitude = magnitude.to(torch.long)

        # Create a tensor with the E2M1 values
        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)

        # Use gather to index the values tensor properly
        # We need to reshape magnitude to match the dimensions we want to gather along
        original_shape = magnitude.shape
        x_float = values[magnitude.reshape(-1)].reshape(original_shape)

        # Apply sign and scale
        x_float = sign.float() * x_float

        # Reshape to apply block-wise scaling
        x_float = x_float.reshape(-1, block_size)

        # Apply the E8M0 scale
        scale_factor = torch.exp2(e8m0_scale.float() - 127)
        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting

        # Apply scaling and reshape back to original shape
        x_float = x_float * scale_factor

        # Reshape back to the original shape
        return x_float.reshape(original_shape).to(dtype)
View file @
c1d2061f
...
@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
class
UnquantizedFusedMoEMethod
(
FusedMoEMethodBase
,
CustomOp
):
"""MoE method without quantization."""
"""MoE method without quantization."""
def
__init__
(
self
,
use_triton_kernels
:
bool
=
False
):
def
__init__
(
self
,
use_triton_kernels
:
bool
=
False
,
with_bias
:
bool
=
False
):
super
().
__init__
()
super
().
__init__
()
self
.
use_triton_kernels
=
use_triton_kernels
self
.
use_triton_kernels
=
use_triton_kernels
self
.
with_bias
=
with_bias
self
.
triton_kernel_moe_forward
=
None
self
.
triton_kernel_moe_forward
=
None
self
.
triton_kernel_moe_with_bias_forward
=
None
if
torch
.
cuda
.
is_available
()
and
has_triton_kernels
:
if
torch
.
cuda
.
is_available
()
and
has_triton_kernels
:
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
triton_kernel_moe_forward
as
_tk_forward
,
triton_kernel_moe_forward
as
_tk_forward
,
)
)
from
sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe
import
(
triton_kernel_moe_with_bias_forward
as
_tk_with_bias_forward
,
)
self
.
triton_kernel_moe_forward
=
_tk_forward
self
.
triton_kernel_moe_forward
=
_tk_forward
self
.
triton_kernel_moe_with_bias_forward
=
_tk_with_bias_forward
def
create_weights
(
def
create_weights
(
self
,
self
,
...
@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
layer
.
register_parameter
(
"w13_weight"
,
w13_weight
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w13_weight
,
extra_weight_attrs
)
if
self
.
with_bias
:
w13_weight_bias
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
2
*
intermediate_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w13_weight_bias"
,
w13_weight_bias
)
set_weight_attrs
(
w13_weight_bias
,
extra_weight_attrs
)
# down_proj (row parallel)
# down_proj (row parallel)
w2_weight_n
,
w2_weight_k
=
(
w2_weight_n
,
w2_weight_k
=
(
hidden_size
,
hidden_size
,
...
@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
layer
.
register_parameter
(
"w2_weight"
,
w2_weight
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
set_weight_attrs
(
w2_weight
,
extra_weight_attrs
)
if
self
.
with_bias
:
w2_weight_bias
=
torch
.
nn
.
Parameter
(
torch
.
empty
(
num_experts
,
hidden_size
,
dtype
=
torch
.
float32
),
requires_grad
=
False
,
)
layer
.
register_parameter
(
"w2_weight_bias"
,
w2_weight_bias
)
set_weight_attrs
(
w2_weight_bias
,
extra_weight_attrs
)
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
if
_use_aiter
:
if
_use_aiter
:
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
...
@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace
:
bool
=
True
,
inplace
:
bool
=
True
,
no_combine
:
bool
=
False
,
no_combine
:
bool
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
activation_alpha
:
Optional
[
float
]
=
None
,
swiglu_limit
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
kwargs
=
{}
if
activation_alpha
is
not
None
:
kwargs
[
"activation_alpha"
]
=
activation_alpha
if
swiglu_limit
is
not
None
:
kwargs
[
"swiglu_limit"
]
=
swiglu_limit
return
self
.
forward
(
return
self
.
forward
(
x
=
x
,
x
=
x
,
...
@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace
=
inplace
,
inplace
=
inplace
,
no_combine
=
no_combine
,
no_combine
=
no_combine
,
routed_scaling_factor
=
routed_scaling_factor
,
routed_scaling_factor
=
routed_scaling_factor
,
**
kwargs
,
)
)
def
forward_cuda
(
def
forward_cuda
(
...
@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
...
@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace
:
bool
=
True
,
inplace
:
bool
=
True
,
no_combine
:
bool
=
False
,
no_combine
:
bool
=
False
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
routed_scaling_factor
:
Optional
[
float
]
=
None
,
activation_alpha
:
Optional
[
float
]
=
None
,
swiglu_limit
:
Optional
[
float
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
if
self
.
use_triton_kernels
:
if
self
.
use_triton_kernels
:
return
self
.
triton_kernel_moe_forward
(
if
self
.
with_bias
:
hidden_states
=
x
,
return
self
.
triton_kernel_moe_with_bias_forward
(
w1
=
layer
.
w13_weight
,
hidden_states
=
x
,
w2
=
layer
.
w2_weight
,
w1
=
layer
.
w13_weight
,
topk_output
=
topk_output
,
w2
=
layer
.
w2_weight
,
)
b1
=
layer
.
w13_weight_bias
,
b2
=
layer
.
w2_weight_bias
,
topk_output
=
topk_output
,
activation
=
activation
,
activation_alpha
=
activation_alpha
,
swiglu_limit
=
swiglu_limit
,
)
else
:
return
self
.
triton_kernel_moe_forward
(
hidden_states
=
x
,
w1
=
layer
.
w13_weight
,
w2
=
layer
.
w2_weight
,
topk_output
=
topk_output
,
)
else
:
else
:
if
_use_aiter
:
if
_use_aiter
:
assert
not
no_combine
,
"unsupported"
assert
not
no_combine
,
"unsupported"
...
...
python/sglang/srt/managers/schedule_batch.py
View file @
c1d2061f
...
@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
...
@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_hybrid
=
False
is_hybrid
=
False
if
isinstance
(
token_to_kv_pool_allocator
,
SWATokenToKVPoolAllocator
):
if
isinstance
(
token_to_kv_pool_allocator
,
SWATokenToKVPoolAllocator
):
assert
isinstance
(
tree_cache
,
SWARadixCache
)
or
isinstance
(
assert
(
tree_cache
,
SWAChunkCache
tree_cache
is
None
or
isinstance
(
tree_cache
,
SWARadixCache
)
or
isinstance
(
tree_cache
,
SWAChunkCache
)
),
"SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
),
"SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
is_hybrid
=
True
is_hybrid
=
True
...
...
python/sglang/srt/models/gpt_oss.py  (new file, mode 100644, +923 lines)

The new gpt-oss model implementation is collapsed in the web diff view and is not reproduced here.
python/sglang/srt/server_args.py

@@ -457,6 +457,10 @@ class ServerArgs:
            raise ValueError(
                "trtllm_mla backend does not support speculative decoding yet."
            )
+        model_arch = self.get_hf_config().architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            self.attention_backend = "triton"
+            self.enable_triton_kernel_moe = True

        # Set page size
        if self.page_size is None:
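
Note on the server-args change: when the HF config reports `GptOssForCausalLM`, the Triton attention backend and the Triton-kernel MoE path are selected automatically. A small standalone sketch of the same check against a local HF `config.json` (the path is illustrative):

```python
import json

def gpt_oss_backend_overrides(config_path: str) -> dict:
    """Mirror the ServerArgs auto-selection for gpt-oss checkpoints."""
    with open(config_path) as f:
        model_arch = json.load(f)["architectures"][0]

    overrides = {}
    if model_arch in ["GptOssForCausalLM"]:
        overrides["attention_backend"] = "triton"
        overrides["enable_triton_kernel_moe"] = True
    return overrides

# gpt_oss_backend_overrides("/models/gpt-oss-20b/config.json")
# -> {"attention_backend": "triton", "enable_triton_kernel_moe": True}
```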