zhaoyu6 / sglang · Commit c1d2061f (unverified)

Add initial support for gpt-oss (#8824)

Authored Aug 05, 2025 by Ying Sheng; committed by GitHub on Aug 05, 2025.
Parent: 556e4143

12 changed files with 1595 additions and 47 deletions (+1595, -47).
Changed files:

  python/sglang/srt/layers/attention/triton_backend.py                 +85   -14
  python/sglang/srt/layers/attention/triton_ops/decode_attention.py    +17    -0
  python/sglang/srt/layers/attention/triton_ops/extend_attention.py    +36    -8
  python/sglang/srt/layers/linear.py                                     +0    -5
  python/sglang/srt/layers/moe/fused_moe_triton/layer.py               +134    -8
  python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py  +178    -3
  python/sglang/srt/layers/quantization/fp8_utils.py                    +29    -0
  python/sglang/srt/layers/quantization/mxfp4_tensor.py                +133    -0
  python/sglang/srt/layers/quantization/unquant.py                      +52    -7
  python/sglang/srt/managers/schedule_batch.py                           +4    -2
  python/sglang/srt/models/gpt_oss.py                                  +923    -0
  python/sglang/srt/server_args.py                                       +4    -0
python/sglang/srt/layers/attention/triton_backend.py

@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
         self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
         if not self.skip_prefill:
             self.qo_indptr = torch.zeros(

@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
                         forward_batch.req_pool_indices,
                         bs,
                         self.device,
+                        self.token_to_kv_pool_allocator,
                     )
                 )
             window_num_kv_splits = torch.empty(

@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr = None
             max_extend_len = None
         elif forward_batch.forward_mode.is_target_verify():
-            # TODO: Support sliding window in spec inference
             bs = len(forward_batch.req_pool_indices)
             qo_indptr = torch.arange(
                 0,

@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indptr, window_kv_indices, window_kv_lens = (
+                    update_sliding_window_buffer(
+                        self.window_kv_indptr,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        forward_batch.seq_lens,
+                        forward_batch.req_pool_indices,
+                        bs,
+                        self.device,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
             custom_mask = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (
                 forward_batch.seq_lens + self.num_draft_tokens

@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
                     forward_batch.req_pool_indices,
                     bs,
                     self.device,
+                    self.token_to_kv_pool_allocator,
                 )
             qo_indptr = self.qo_indptr

@@ -423,7 +439,8 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_kv_indices = self.cuda_graph_window_kv_indices
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
-                window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
                         self.window_kv_indptr,
                         window_kv_indices,
                         self.req_to_token,

@@ -431,6 +448,8 @@ class TritonAttnBackend(AttentionBackend):
                         seq_lens[:bs],
                         req_pool_indices,
                         bs,
+                        self.token_to_kv_pool_allocator,
                     )
+                )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices

@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indptr, window_kv_indices, _ = (
+                    update_sliding_window_buffer_cuda_graph(
+                        self.window_kv_indptr,
+                        window_kv_indices,
+                        self.req_to_token,
+                        self.sliding_window_size,
+                        seq_lens,
+                        req_pool_indices,
+                        bs,
+                        self.token_to_kv_pool_allocator,
+                    )
+                )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
             ):
                 window_num_kv_splits = self.cuda_graph_window_num_kv_splits
                 window_kv_indices = self.cuda_graph_window_kv_indices
-                _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
                     self.window_kv_indptr,
                     window_kv_indices,
                     self.req_to_token,

@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
                     seq_lens[:bs],
                     req_pool_indices[:bs],
                     bs,
+                    self.token_to_kv_pool_allocator,
                 )
                 self.get_num_kv_splits(
                     window_num_kv_splits[:num_token], window_kv_lens[:bs]

@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
                 kv_indices,
                 self.req_to_token.stride(0),
             )
+            if self.sliding_window_size is not None and self.sliding_window_size > 0:
+                window_num_kv_splits = self.cuda_graph_window_num_kv_splits
+                window_kv_indices = self.cuda_graph_window_kv_indices
+                _, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
+                    self.window_kv_indptr,
+                    window_kv_indices,
+                    self.req_to_token,
+                    self.sliding_window_size,
+                    seq_lens,
+                    req_pool_indices,
+                    bs,
+                    self.token_to_kv_pool_allocator,
+                )
             custom_mask = self.cuda_graph_custom_mask
             custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)

@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sk=None,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:

@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
             self.forward_metadata.max_extend_len,
             layer.scaling,
             layer.logit_cap,
-            sliding_window_size,
+            sliding_window_size=sliding_window_size,
+            sk=sk,
         )
         return o

@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
+        sk=None,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.

@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
             self.max_kv_splits,
             layer.scaling,
             layer.logit_cap,
+            sk=sk,
         )
         return o

@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
     req_pool_indices,
     bs,
     device,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
     )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
         window_kv_indices,
         req_to_token.stride(0),
     )
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
     return window_kv_indptr, window_kv_indices, window_kv_lens

@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
     seq_lens,
     req_pool_indices,
     bs,
+    token_to_kv_pool_allocator=None,
 ):
     window_kv_lens = torch.minimum(
         seq_lens,
-        torch.tensor(sliding_window_size + 1),
+        torch.tensor(sliding_window_size),
    )
     window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
     window_kv_indptr = window_kv_indptr[: bs + 1]

@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
         window_kv_indices,
         req_to_token.stride(0),
     )
-    return window_kv_indptr, window_kv_lens
+    # full to swa index mapping
+    if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
+        kv_last_index = window_kv_indptr[-1]
+        window_kv_indices[:kv_last_index] = (
+            token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
+                window_kv_indices[:kv_last_index]
+            )
+        )
+    return window_kv_indptr, window_kv_indices, window_kv_lens
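
The new `token_to_kv_pool_allocator` argument threaded through `update_sliding_window_buffer` and its CUDA-graph variant exists so that window KV indices, which are gathered from the full KV pool, can be remapped into a separate SWA pool whenever the allocator provides `translate_loc_from_full_to_swa`. A minimal sketch of that remapping under the same `hasattr` guard; the toy allocator below is hypothetical, only the method name and the slicing by `window_kv_indptr[-1]` come from this diff:

import torch


class ToyHybridAllocator:
    # Hypothetical allocator that keeps a full-pool -> SWA-pool location table.
    def __init__(self, full_to_swa: torch.Tensor):
        self.full_to_swa = full_to_swa

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self.full_to_swa[loc]


def remap_window_indices(window_kv_indices, window_kv_indptr, allocator):
    # Mirrors the guard added above: indices are left untouched for
    # allocators that do not maintain a separate sliding-window pool.
    if hasattr(allocator, "translate_loc_from_full_to_swa"):
        kv_last_index = window_kv_indptr[-1]
        window_kv_indices[:kv_last_index] = allocator.translate_loc_from_full_to_swa(
            window_kv_indices[:kv_last_index]
        )
    return window_kv_indices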
python/sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
     O,
     kv_indptr,
     num_kv_splits,
+    sk_ptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,

@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
     MIN_BLOCK_KV: tl.constexpr,
     BLOCK_DV: tl.constexpr,
     Lv: tl.constexpr,
+    HAS_SK: tl.constexpr,
 ):
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
             e_sum = e_sum * old_scale + exp_logic
             e_max = n_e_max

+    if HAS_SK:
+        cur_sk = tl.load(sk_ptr + cur_head)
+        e_sum += tl.exp(cur_sk - e_max)
+
     tl.store(
         O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
         acc / e_sum,

@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
     kv_indptr,
     num_kv_splits,
     max_kv_splits,
+    sk=None,
 ):
     batch, head_num = q.shape[0], q.shape[1]
     Lv = v_buffer.shape[-1]
     BLOCK_DV = triton.next_power_of_2(Lv)
     MAX_KV_SPLITS = max_kv_splits
+    HAS_SK = sk is not None

     extra_kargs = {}
     if _is_hip:

@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
         o,
         kv_indptr,
         num_kv_splits,
+        sk,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),

@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
         MIN_BLOCK_KV=_MIN_BLOCK_KV,
         BLOCK_DV=BLOCK_DV,
         Lv=Lv,
+        HAS_SK=HAS_SK,
         num_warps=4,
         num_stages=2,
         **extra_kargs,

@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sk=None,
 ):
     _decode_att_m_fwd(
         q,

@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sk,
     )

@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sk=None,
 ):
     _decode_grouped_att_m_fwd(
         q,

@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
         kv_indptr,
         num_kv_splits,
         max_kv_splits,
+        sk,
     )

@@ -687,6 +701,7 @@ def decode_attention_fwd(
     max_kv_splits,
     sm_scale,
     logit_cap=0.0,
+    sk=None,
 ):
     assert max_kv_splits == attn_logits.shape[2]
     assert q.shape[0] <= kv_indptr.shape[0] - 1

@@ -709,6 +724,7 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sk=sk,
         )
     else:
         # GQA/MQA/MLA

@@ -725,4 +741,5 @@ def decode_attention_fwd(
             max_kv_splits,
             sm_scale,
             logit_cap=logit_cap,
+            sk=sk,
         )
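
The `sk` argument added throughout this file is a per-head attention-sink logit: in `_fwd_kernel_stage2` it only contributes `exp(sk - e_max)` to the softmax denominator, so each head's output is scaled down without adding any value vector. A plain-PyTorch reference for a single decode token (the name `sk` and the denominator-only update come from the diff; everything else is an illustrative sketch):

import torch


def decode_attention_with_sink(q, k, v, sm_scale, sk=None):
    # q: [H, D], k/v: [H, T, D], sk: [H] optional per-head sink logit
    logits = torch.einsum("hd,htd->ht", q, k) * sm_scale
    e_max = logits.max(dim=-1, keepdim=True).values
    p = torch.exp(logits - e_max)
    e_sum = p.sum(dim=-1, keepdim=True)
    if sk is not None:
        # Matches `e_sum += tl.exp(cur_sk - e_max)` in _fwd_kernel_stage2:
        # the sink enlarges the denominator but contributes no value vector.
        e_sum = e_sum + torch.exp(sk.unsqueeze(-1) - e_max)
    return torch.einsum("ht,htd->hd", p / e_sum, v)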
python/sglang/srt/layers/attention/triton_ops/extend_attention.py

@@ -51,6 +51,7 @@ def _fwd_kernel(
     kv_indices,
     mask_ptr,
     mask_indptr,
+    sk_ptr,
     sm_scale,
     kv_group_num,
     stride_qbs,

@@ -78,6 +79,7 @@ def _fwd_kernel(
     IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
+    HAS_SK: tl.constexpr,
 ):
     cur_seq = tl.program_id(0)
     cur_head = tl.program_id(1)

@@ -178,13 +180,17 @@ def _fwd_kernel(
                 final_mask &= custom_mask
             if SLIDING_WINDOW_SIZE > 0:
                 # Add mask where q_id <= kv_id + sliding_window_size
-                window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
-                    start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
-                )
+                # q_id = prefix_len + cur_m, kv_id = cur_n
+                window_mask = (
+                    cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
+                ) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
                 final_mask &= window_mask
             qk = tl.where(final_mask, qk, float("-inf"))
-            n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+            row_max = tl.max(qk, 1)
+            row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+            n_e_max = tl.maximum(row_max_fixed, e_max)
             re_scale = tl.exp(e_max - n_e_max)
             p = tl.exp(qk - n_e_max[:, None])
             deno = deno * re_scale + tl.sum(p, 1)

@@ -242,6 +248,7 @@ def _fwd_kernel(
         if logit_cap > 0:
             qk = logit_cap * tanh(qk / logit_cap)

+        final_mask = mask_m[:, None] & mask_n[None, :]
         if USE_CUSTOM_MASK:
             custom_mask = tl.load(
                 mask_ptr

@@ -254,18 +261,30 @@ def _fwd_kernel(
                 other=0,
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(custom_mask, qk, float("-inf"))
+            final_mask &= custom_mask
         elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_causual, qk, float("-inf"))
+            final_mask &= mask_causual
         else:
             mask_non_causal = mask_m[:, None] & mask_n[None, :]
-            qk = tl.where(mask_non_causal, qk, float("-inf"))
+            final_mask &= mask_non_causal
+        if SLIDING_WINDOW_SIZE > 0:
+            # Add mask where q_id <= kv_id + sliding_window_size
+            window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
+                start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
+            )
+            final_mask &= window_mask
+        qk = tl.where(final_mask, qk, float("-inf"))
-        n_e_max = tl.maximum(tl.max(qk, 1), e_max)
+        row_max = tl.max(qk, 1)
+        row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
+        n_e_max = tl.maximum(row_max_fixed, e_max)
         re_scale = tl.exp(e_max - n_e_max)
         p = tl.exp(qk - n_e_max[:, None])
         deno = deno * re_scale + tl.sum(p, 1)

@@ -283,6 +302,10 @@ def _fwd_kernel(
         e_max = n_e_max

+    if HAS_SK:
+        cur_sk = tl.load(sk_ptr + cur_head)
+        deno += tl.exp(cur_sk - e_max)
+
     offs_o = (
         (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
         * stride_obs

@@ -321,6 +344,7 @@ def extend_attention_fwd(
     logit_cap=0.0,
     skip_prefix_custom_mask=True,
     sliding_window_size=-1,
+    sk=None,
 ):
     """
     q_extend, k_extend, v_extend, o_extend: contiguous tensors

@@ -386,6 +410,8 @@ def extend_attention_fwd(
     # Skip custom mask for prefix part
     SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask

+    HAS_SK = sk is not None
+
     grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
     num_stages = 1

@@ -405,6 +431,7 @@ def extend_attention_fwd(
         kv_indices,
         custom_mask,
         mask_indptr,
+        sk,
         sm_scale,
         kv_group_num,
         q_extend.stride(0),

@@ -431,6 +458,7 @@ def extend_attention_fwd(
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
         IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
+        HAS_SK=HAS_SK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
         num_stages=num_stages,
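
Two details of the extend-kernel rewrite are easy to miss: all masks (custom, causal or non-causal, and the new sliding-window mask) are now folded into a single `final_mask` before one `tl.where`, and a row whose logits are all `-inf` (possible once a small sliding window is applied) has its maximum clamped to `-1e20` so the online-softmax update never produces NaN. A small sketch of that row-max fix, assuming the usual running-max bookkeeping:

import torch


def stable_row_update(qk: torch.Tensor, e_max: torch.Tensor):
    # qk: [M, N] tile of masked logits, e_max: [M] running row maxima
    row_max = qk.max(dim=1).values
    # Without the clamp, a fully masked row can drive the running max to -inf,
    # and exp(-inf - (-inf)) is NaN.
    row_max_fixed = torch.where(
        row_max == float("-inf"), torch.full_like(row_max, -1e20), row_max
    )
    n_e_max = torch.maximum(row_max_fixed, e_max)
    p = torch.exp(qk - n_e_max[:, None])  # fully masked rows simply give p == 0
    return n_e_max, p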
python/sglang/srt/layers/linear.py

@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
                 else self.weight_loader
             ),
         )
-        if not reduce_results and (bias and not skip_bias_add):
-            raise ValueError(
-                "When not reduce the results, adding bias to the "
-                "results can lead to incorrect results"
-            )
         if bias:
             self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
python/sglang/srt/layers/moe/fused_moe_triton/layer.py

@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
+        use_weight_loader_fused: bool = False,
+        with_bias=False,
     ):
         super().__init__()

@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
         self.expert_map_cpu = None
         self.expert_map_gpu = None

+        # For activation
+        self.activation_alpha = activation_alpha
+        self.swiglu_limit = swiglu_limit
+
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False

@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
+                self.use_triton_kernels, with_bias=with_bias
             )
         else:
             self.quant_method = quant_config.get_quant_method(self, prefix)

@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
             intermediate_size=self.intermediate_size_per_partition,
             intermediate_size_per_partition=self.intermediate_size_per_partition,
             params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
+            weight_loader=(
+                self.weight_loader
+                if not use_weight_loader_fused
+                else self.weight_loader_fused
+            ),
+            with_bias=with_bias,
         )

     def _load_per_tensor_weight_scale(

@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         # Load grouped weight scales for group quantization
         # or model weights

@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )
-        elif shard_id in ("w1", "w3"):
+        elif shard_id in ("w1", "w3", "w13"):
             self._load_w13(
                 shard_id=shard_id,
                 shard_dim=shard_dim,
                 loaded_weight=loaded_weight,
                 expert_data=expert_data,
                 tp_rank=tp_rank,
+                is_bias=is_bias,
             )

     def _load_per_channel_weight_scale(

@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         # Index the loaded weight for tp sharding.
         # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
+        assert shard_id in {"w1", "w3", "w13"}
+        if is_bias:
+            # if this weight is a bias, the last dimension must be the sharded dimension
+            shard_dim = -1
+        if shard_id in {"w1", "w3"}:
+            # non-fused version
             shard_size = expert_data.shape[shard_dim] // 2
+        elif shard_id in {"w13"}:
+            # fused version
+            shard_size = expert_data.shape[shard_dim]
+        else:
+            raise NotImplementedError
         # Narrow parameter and load.
         # w1, gate_proj: Load into first logical weight of w13.
         # w3, up_proj: Load into second logical weight of w13.
         # trtllm cutlass kernel assumes differently
-        assert shard_id in ("w1", "w3")
         switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
         if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
             start = shard_size

@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
             )
         else:
             if not self.use_presharded_weights:
-                if self.use_triton_kernels:
+                if not is_bias and self.use_triton_kernels:
+                    # do not transpose for bias
                     loaded_weight = loaded_weight.transpose(-2, -1)
                 loaded_weight = loaded_weight.narrow(
                     shard_dim, shard_size * tp_rank, shard_size

@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
         shard_id: str,
         loaded_weight: torch.Tensor,
         tp_rank: int,
+        is_bias: bool = False,
     ):
         """Load w2 weights for down projection.

@@ -356,6 +387,13 @@ class FusedMoE(torch.nn.Module):
         # Index the loaded weight for tp sharding.
         # down_proj: "RowParallel" so tp sharding on input_dim
         # Narrow parameter and load.
+        if is_bias:
+            # this expert_data is a bias, not weight,
+            # for w2_bias in TP, it does not need to be sharded
+            shard_size = expert_data.shape[-1]
+        else:
+            # this parameter is a weight matrix
+            # for w2 in TP, it shards the input_features, i.e., shard_dim=2
             shard_size = expert_data.shape[shard_dim]
         if _is_cpu:

@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
                 not self.use_presharded_weights,
             )
         else:
-            if not self.use_presharded_weights:
+            if not is_bias and not self.use_presharded_weights:
                 if self.use_triton_kernels:
                     loaded_weight = loaded_weight.transpose(-2, -1)
                 if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:

@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
             )
             return

+    def weight_loader_fused(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+    ) -> None:
+        tp_rank = self.moe_tp_rank
+
+        # compressed-tensors checkpoints with packed weights are stored flipped
+        # TODO: check self.quant_method.quant_config.quant_format
+        # against known CompressionFormat enum values that have this quality
+        loaded_weight = (
+            loaded_weight.t().contiguous()
+            if (
+                self.quant_method.__class__.__name__
+                == "CompressedTensorsWNA16MoEMethod"
+            )
+            else loaded_weight
+        )
+
+        if shard_id not in ("w13", "w2"):
+            raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
+
+        # Fetch the dim to shard the parameter/loaded weight
+        # based on the shard id. This will be whatever
+        # dimension intermediate_size is used.
+        SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
+        SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
+
+        expert_data = param.data
+        is_bias = expert_data.dim() == 2
+
+        # is_transposed: if the dim to shard the weight
+        # should be flipped. Required by GPTQ, compressed-tensors
+        # should be whatever dimension intermediate_size is
+        is_transposed = getattr(param, "is_transposed", False)
+        if self.use_triton_kernels:
+            is_transposed = True
+        shard_dim = (
+            SHARD_ID_TO_SHARDED_DIM[shard_id]
+            if not is_transposed
+            else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
+        )
+
+        # Case model weights
+        if "weight" in weight_name:
+            self._load_model_weight_or_group_weight_scale(
+                shard_id=shard_id,
+                shard_dim=shard_dim,
+                loaded_weight=loaded_weight,
+                expert_data=expert_data,
+                tp_rank=tp_rank,
+                is_bias=is_bias,
+            )
+            return
+        else:
+            logging.warning(
+                f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
+            )
+
     def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
         assert self.quant_method is not None

@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
         # Matrix multiply.
         with use_symmetric_memory(get_tp_group()) as sm:
+            kwargs = {}
+            if self.activation_alpha is not None:
+                kwargs["activation_alpha"] = self.activation_alpha
+            if self.swiglu_limit is not None:
+                kwargs["swiglu_limit"] = self.swiglu_limit
+
             final_hidden_states = self.quant_method.apply(
                 layer=self,
                 x=hidden_states,

@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
                     == "ModelOptNvFp4FusedMoEMethod"
                     else {}
                 ),
+                **kwargs,
             )
             sm.tag(final_hidden_states)

@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
             ]
         ]

+    @classmethod
+    def make_expert_params_mapping_fused(
+        cls,
+        ckpt_gate_up_proj_name: str,
+        ckpt_down_proj_name: str,
+        ckpt_gate_up_proj_bias_name: str,
+        ckpt_down_proj_bias_name: str,
+    ):
+        return [
+            ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
+            (
+                "experts.w13_weight_bias",
+                f"experts.{ckpt_gate_up_proj_bias_name}",
+                "w13",
+            ),
+            ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
+            ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
+        ]
+
     @classmethod
     def make_expert_input_scale_params_mapping(
         cls,
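
Checkpoints that store each expert's gate and up projections as one fused tensor (as gpt-oss does) are loaded through the new `w13`/`w2` shard ids and the `make_expert_params_mapping_fused` helper. A usage sketch; the checkpoint key names below are illustrative assumptions, only the helper and the shape of the returned tuples come from this diff:

# Hypothetical checkpoint naming; only the mapping helper is from the diff.
mapping = FusedMoE.make_expert_params_mapping_fused(
    ckpt_gate_up_proj_name="gate_up_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
    ckpt_down_proj_bias_name="down_proj_bias",
)
# mapping == [
#     ("experts.w13_weight", "experts.gate_up_proj", "w13"),
#     ("experts.w13_weight_bias", "experts.gate_up_proj_bias", "w13"),
#     ("experts.w2_weight", "experts.down_proj", "w2"),
#     ("experts.w2_weight_bias", "experts.down_proj_bias", "w2"),
# ]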
python/sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py

@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
 import torch
 from sgl_kernel import gelu_and_mul, silu_and_mul
-from triton_kernels.matmul_ogs import matmul_ogs
+from triton_kernels.matmul_ogs import (
+    FlexCtx,
+    FnSpecs,
+    FusedActivation,
+    PrecisionConfig,
+    matmul_ogs,
+)
+from triton_kernels.numerics import InFlexData
 from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+from triton_kernels.swiglu import swiglu_fn
-from sglang.srt.utils import direct_register_custom_op

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput


+def quantize(w, dtype, dev, **opt):
+    if dtype == "bf16":
+        return w.to(torch.bfloat16), InFlexData()
+    elif dtype == "fp8":
+        wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
+        return (
+            wq,
+            InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
+            MicroscalingCtx(),
+        )
+    else:
+        assert dtype == "mx4", f"{dtype=}"
+        swizzle_mx_scale = opt["swizzle_mx_scale"]
+        swizzle_axis = 2 if swizzle_mx_scale else None
+        w = w.to(torch.bfloat16)
+        w, mx_scales, weight_scale_shape = downcast_to_mxfp(
+            w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
+        )
+        return (
+            w,
+            InFlexData(),
+            MicroscalingCtx(
+                weight_scale=mx_scales,
+                swizzle_mx=swizzle_mx_scale,
+                actual_weight_scale_shape=weight_scale_shape,
+            ),
+        )
+
+
 def triton_kernel_moe_forward(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,

@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
     )

     return intermediate_cache3
+
+
+def triton_kernel_moe_with_bias_forward(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    b2: torch.Tensor,
+    topk_output: TopKOutput,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    assert topk_output.format.is_triton_kernel()
+    routing_data, gather_idx, scatter_idx = topk_output
+
+    return triton_kernel_fused_experts_with_bias(
+        hidden_states,
+        w1,
+        b1,
+        w2,
+        b2,
+        routing_data,
+        gather_idx,
+        scatter_idx,
+        inplace=inplace,
+        activation=activation,
+        use_fp8_w8a8=use_fp8_w8a8,
+        per_channel_quant=per_channel_quant,
+        global_num_experts=global_num_experts,
+        expert_map=expert_map,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+        block_shape=block_shape,
+        activation_alpha=activation_alpha,
+        swiglu_limit=swiglu_limit,
+    )
+
+
+def triton_kernel_fused_experts_with_bias(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    b1: torch.Tensor,
+    w2: torch.Tensor,
+    b2: torch.Tensor,
+    routing_data: RoutingData,
+    gather_indx: GatherIndx,
+    scatter_indx: ScatterIndx,
+    inplace: bool = False,
+    activation: str = "silu",
+    use_fp8_w8a8: bool = False,
+    per_channel_quant: bool = False,
+    global_num_experts: int = -1,
+    expert_map: Optional[torch.Tensor] = None,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+    block_shape: Optional[list[int]] = None,
+    activation_alpha: Optional[float] = None,
+    swiglu_limit: Optional[int] = None,
+) -> torch.Tensor:
+    # print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
+    assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
+    assert per_channel_quant == False, "per_channel_quant is not supported"
+    assert expert_map == None, "expert_map is not supported"
+    assert w1_scale == None, "w1_scale is not supported"
+    assert w2_scale == None, "w2_scale is not supported"
+    assert a1_scale == None, "a1_scale is not supported"
+    assert a2_scale == None, "a2_scale is not supported"
+    assert block_shape == None, "block_shape is not supported"
+
+    # type check
+    assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
+    assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
+    assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
+
+    # Shape check
+    assert hidden_states.ndim == 2, "hidden_states must be 2D"
+    assert (
+        hidden_states.shape[-1] == w1.shape[-2]
+    ), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
+    assert (
+        w2.shape[-1] == w1.shape[1]
+    ), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"

+    # feature check
+    assert inplace == False, "Inplace is not supported in new triton MoE kernel"
+
+    E, _, _ = w1.shape
+
+    if global_num_experts == -1:
+        global_num_experts = E
+
+    device = "cuda"
+    optg = dict()
+    w1, w1_flex = quantize(w1, "bf16", device, **optg)
+    w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
+
+    w2, w2_flex = quantize(w2, "bf16", device, **optg)
+    w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
+
+    act = FusedActivation(
+        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
+        (activation_alpha, swiglu_limit),
+        2,
+    )
+
+    intermediate_cache = matmul_ogs(
+        hidden_states,
+        w1,
+        b1,
+        routing_data,
+        gather_indx=gather_indx,
+        precision_config=w1_pcg,
+        gammas=None,
+        fused_activation=act,
+    )
+
+    return matmul_ogs(
+        intermediate_cache,
+        w2,
+        b2,
+        routing_data,
+        scatter_indx=scatter_indx,
+        precision_config=w2_pcg,
+        gammas=routing_data.gate_scal,
+    )
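
For each routed token and selected expert e, `triton_kernel_fused_experts_with_bias` computes `act(x @ w1[e] + b1[e]) @ w2[e] + b2[e]` and applies the routing gates through `gammas=routing_data.gate_scal` on the second `matmul_ogs`. A slow, unfused PyTorch reference of that contract can help when sanity-checking the fused path; the gate/up split order and the exact clamped SwiGLU that gpt-oss configures via `FusedActivation` are assumptions here, so a plain SiLU-gated product stands in for the real activation:

import torch
import torch.nn.functional as F


def reference_moe_with_bias(x, w13, b13, w2, b2, topk_ids, topk_weights):
    # x: [T, H]; w13: [E, H, 2*I]; b13: [E, 2*I]; w2: [E, I, H]; b2: [E, H]
    # topk_ids / topk_weights: [T, K]
    T, K = topk_ids.shape
    out = torch.zeros_like(x)
    for t in range(T):
        for k in range(K):
            e = int(topk_ids[t, k])
            h = x[t] @ w13[e] + b13[e]          # [2*I]
            gate, up = h.chunk(2, dim=-1)
            h = F.silu(gate) * up               # generic SwiGLU stand-in
            out[t] += topk_weights[t, k] * (h @ w2[e] + b2[e])
    return out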
python/sglang/srt/layers/quantization/fp8_utils.py

@@ -4,6 +4,7 @@ import torch
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
+from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
 from sglang.srt.layers.utils import is_sm100_supported

 try:

@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
 )
 from sglang.srt.utils import (
     align,
+    ceil_div,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,

@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
     return output.to(dtype=input_2d.dtype).view(*output_shape)


+def dequant_mxfp4(
+    w_block: torch.Tensor,
+    w_scale: torch.Tensor,
+    out_dtype,
+) -> torch.Tensor:
+    """
+    :param w_block: (batch, n, k, 16), uint8, pack two mxfp4 into one byte
+    :param w_scale: (batch, n, k), uint8
+    :return: (batch, n, k * 32), float32
+    """
+    assert w_block.dtype == torch.uint8
+    assert w_scale.dtype == torch.uint8
+
+    batch, n, k, pack_dim = w_block.shape
+    batch_, n_, k_ = w_scale.shape
+    assert pack_dim == 16
+    assert batch == batch_
+    assert n == n_
+    assert k == k_
+
+    out_raw = MXFP4QuantizeUtil.dequantize(
+        quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
+    )
+    return out_raw.reshape(batch, n, k * 32)
+
+
 def input_to_float8(
     x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
python/sglang/srt/layers/quantization/mxfp4_tensor.py  (new file, mode 100644)

# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch


# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
class MXFP4QuantizeUtil:
    E2M1_max = 6.0
    E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
    E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])

    @classmethod
    def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
        """Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
        Args:
            input (torch.Tensor): The input tensor to be quantized.
            block_sizes (dict | None): The block sizes for quantization.
        """

        def cast_fp4(x):
            sign = torch.sign(x)
            sign_bit = (2 - sign) // 2
            ord_ = torch.sum(
                (x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
            )
            fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
            return fp4_val

        def fuse_uint4_to_uint8(x):
            # If the last dimension is odd, pad with zeros
            # If this behavior is not desired, please modify the code accordingly
            left_side = x[..., 0::2]  # Even indices (0, 2, 4...)
            right_side = x[..., 1::2]  # Odd indices (1, 3, 5...)
            new_data = right_side.clone() << 4  # Put odd indices (higher addresses) in high bits
            new_data[..., : left_side.shape[-1]] += left_side  # Put even indices in low bits
            return new_data

        if block_size is None:
            block_size = 32

        original_shape = input.shape
        original_dtype = input.dtype
        input = input.view(-1, block_size)
        # get scales
        input_amax = input.abs().max(dim=-1, keepdim=True).values
        descale = input_amax / cls.E2M1_max
        min_value = torch.tensor(-127.0, device=descale.device)
        e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))

        input = (input / torch.exp2(e8m0_scale)).view(original_shape)
        input_q = cast_fp4(input)
        input_q = fuse_uint4_to_uint8(input_q)
        e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
        return cls(original_shape, original_dtype, input_q), e8m0_scale

    @classmethod
    def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
        """Dequantze MXFP4 packed tensor to a target dtype."""

        def unfuse_uint8_to_uint4(x):
            """Unfuse uint8 values back to uint4 values.

            This is the inverse operation of fuse_uint4_to_uint8.
            """
            # Extract the lower 4 bits (even indices)
            left_side = x & 0x0F
            # Extract the upper 4 bits (odd indices)
            right_side = (x >> 4) & 0x0F
            # Create a new tensor with alternating values
            shape = list(x.shape)
            shape[-1] = shape[-1] * 2
            result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
            # Fill in the values - even indices get low bits, odd indices get high bits
            result[..., 0::2] = left_side  # Even indices from low bits
            result[..., 1::2] = right_side  # Odd indices from high bits
            return result

        e8m0_scale = scale
        block_size = block_sizes[-1]

        # Unfuse the uint8 values back to uint4
        x_unfused = unfuse_uint8_to_uint4(quantized_data)
        # Extract sign and magnitude
        sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(torch.float32)  # Extract sign bit and convert to +1/-1
        magnitude = x_unfused & 0b0111  # Extract magnitude bits
        magnitude = magnitude.to(torch.long)

        # Create a tensor with the E2M1 values
        values = torch.tensor(cls.E2M1_values, device=quantized_data.device)

        # Use gather to index the values tensor properly
        # We need to reshape magnitude to match the dimensions we want to gather along
        original_shape = magnitude.shape
        x_float = values[magnitude.reshape(-1)].reshape(original_shape)

        # Apply sign and scale
        x_float = sign.float() * x_float

        # Reshape to apply block-wise scaling
        x_float = x_float.reshape(-1, block_size)

        # Apply the E8M0 scale
        scale_factor = torch.exp2(e8m0_scale.float() - 127)
        scale_factor = scale_factor.reshape(-1, 1)  # Reshape for proper broadcasting

        # Apply scaling and reshape back to original shape
        x_float = x_float * scale_factor

        # Reshape back to the original shape
        return x_float.reshape(original_shape).to(dtype)
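
A quick worked example of the packing and scaling arithmetic implemented above: each uint8 stores two E2M1 codes (low nibble = even index, high nibble = odd index), every 32-value block shares one E8M0 exponent stored with a +127 bias, and a code decodes as `sign * E2M1_values[code & 0b0111] * 2**(scale - 127)`. The sketch below just replays that arithmetic on a single byte:

# Low nibble 0x3 -> E2M1_values[3] = 1.5, sign bit clear -> +1.5 (unscaled)
# High nibble 0xA -> sign bit set, magnitude 0b010 -> -1.0 (unscaled)
packed = 0b1010_0011
low, high = packed & 0x0F, (packed >> 4) & 0x0F
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]


def decode(code, scale_u8):
    sign = -1.0 if code & 0b1000 else 1.0
    return sign * E2M1_values[code & 0b0111] * 2.0 ** (scale_u8 - 127)


print(decode(low, 128), decode(high, 128))  # 3.0 -2.0 with a block scale of 2**1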
python/sglang/srt/layers/quantization/unquant.py

@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
 class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
     """MoE method without quantization."""

-    def __init__(self, use_triton_kernels: bool = False):
+    def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
         super().__init__()
         self.use_triton_kernels = use_triton_kernels
+        self.with_bias = with_bias

         self.triton_kernel_moe_forward = None
+        self.triton_kernel_moe_with_bias_forward = None
         if torch.cuda.is_available() and has_triton_kernels:
             from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
                 triton_kernel_moe_forward as _tk_forward,
             )
+            from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+                triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
+            )

             self.triton_kernel_moe_forward = _tk_forward
+            self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward

     def create_weights(
         self,

@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w13_weight", w13_weight)
         set_weight_attrs(w13_weight, extra_weight_attrs)

+        if self.with_bias:
+            w13_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w13_weight_bias", w13_weight_bias)
+            set_weight_attrs(w13_weight_bias, extra_weight_attrs)
+
         # down_proj (row parallel)
         w2_weight_n, w2_weight_k = (
             hidden_size,

@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+        if self.with_bias:
+            w2_weight_bias = torch.nn.Parameter(
+                torch.empty(num_experts, hidden_size, dtype=torch.float32),
+                requires_grad=False,
+            )
+            layer.register_parameter("w2_weight_bias", w2_weight_bias)
+            set_weight_attrs(w2_weight_bias, extra_weight_attrs)
+
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
         if _use_aiter:
             layer.w13_weight = torch.nn.Parameter(

@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
+        kwargs = {}
+        if activation_alpha is not None:
+            kwargs["activation_alpha"] = activation_alpha
+        if swiglu_limit is not None:
+            kwargs["swiglu_limit"] = swiglu_limit
         return self.forward(
             x=x,

@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             inplace=inplace,
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
+            **kwargs,
         )

     def forward_cuda(

@@ -226,9 +256,24 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         inplace: bool = True,
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
+        activation_alpha: Optional[float] = None,
+        swiglu_limit: Optional[float] = None,
     ) -> torch.Tensor:
         if self.use_triton_kernels:
+            if self.with_bias:
+                return self.triton_kernel_moe_with_bias_forward(
+                    hidden_states=x,
+                    w1=layer.w13_weight,
+                    w2=layer.w2_weight,
+                    b1=layer.w13_weight_bias,
+                    b2=layer.w2_weight_bias,
+                    topk_output=topk_output,
+                    activation=activation,
+                    activation_alpha=activation_alpha,
+                    swiglu_limit=swiglu_limit,
+                )
+            else:
                 return self.triton_kernel_moe_forward(
                     hidden_states=x,
                     w1=layer.w13_weight,
python/sglang/srt/managers/schedule_batch.py

@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         is_hybrid = False
         if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            assert isinstance(tree_cache, SWARadixCache) or isinstance(
-                tree_cache, SWAChunkCache
+            assert (
+                tree_cache is None
+                or isinstance(tree_cache, SWARadixCache)
+                or isinstance(tree_cache, SWAChunkCache)
             ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
             is_hybrid = True
python/sglang/srt/models/gpt_oss.py  (new file, mode 100644, +923 lines)

(The diff for this new model implementation is collapsed in this view.)
python/sglang/srt/server_args.py

@@ -457,6 +457,10 @@ class ServerArgs:
             raise ValueError(
                 "trtllm_mla backend does not support speculative decoding yet."
             )
+        model_arch = self.get_hf_config().architectures[0]
+        if model_arch in ["GptOssForCausalLM"]:
+            self.attention_backend = "triton"
+            self.enable_triton_kernel_moe = True

         # Set page size
         if self.page_size is None: