change / sglang · Commits · 3f2e315f

Unverified commit 3f2e315f, authored Aug 09, 2025 by tql.99, committed by GitHub on Aug 09, 2025.

    optimize: reduce shuffle and quantization overhead in cutlass_moe sm90 (#8962)

    Co-authored-by: 戚余航 <qiyuhang@bytedance.com>

parent 6e215118

Changes: 2 changed files with 70 additions and 16 deletions (+70 -16)

    python/sglang/srt/layers/moe/cutlass_moe.py            +11 -16
    python/sglang/srt/layers/quantization/fp8_kernel.py    +59  -0
python/sglang/srt/layers/moe/cutlass_moe.py

@@ -9,7 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch

 from sglang.srt.layers.moe.cutlass_moe_params import CutlassMoEParams
-from sglang.srt.layers.utils import is_sm100_supported
+from sglang.srt.layers.utils import is_sm90_supported, is_sm100_supported
 from sglang.srt.utils import is_cuda

 _is_cuda = is_cuda()
@@ -124,6 +124,7 @@ def cutlass_fused_experts_fp8(
     if is_cuda:
         from sglang.srt.layers.quantization.fp8_kernel import (
+            per_group_transpose,
             per_token_group_quant_fp8_hopper_moe_mn_major,
             sglang_per_token_group_quant_fp8,
         )
@@ -152,15 +153,12 @@ def cutlass_fused_experts_fp8(
         k,
     )

-    if is_sm100_supported():
-        a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
-        rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
-        rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    else:
-        rep_a = shuffle_rows(a, a_map, (m * topk, k))
-        rep_a_q, rep_a1_scales = per_token_group_quant_fp8_hopper_moe_mn_major(
-            rep_a, expert_offsets, problem_sizes1, 128
-        )
+    a_q, a1_scale = sglang_per_token_group_quant_fp8(a, 128)
+    rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
+    rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
+    if not is_sm100_supported():
+        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)

     w1_scale = w1_scale.contiguous()

     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
@@ -193,12 +191,9 @@ def cutlass_fused_experts_fp8(
     intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)

     silu_and_mul(c1, intermediate)
-    if is_sm100_supported():
-        intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    else:
-        intemediate_q, a2_scale = per_token_group_quant_fp8_hopper_moe_mn_major(
-            intermediate, expert_offsets, problem_sizes2, 128
-        )
+    intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
+    if not is_sm100_supported():
+        a2_scale = per_group_transpose(a2_scale, expert_offsets)

     w2_scale = w2_scale.contiguous()

     fp8_blockwise_scaled_grouped_mm(
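Why this helps on SM90: the old Hopper branch shuffled the bf16 activations into expert order first (replicating each token topk times) and then quantized all m * topk rows with the MN-major kernel; the unified path quantizes the original m rows once, shuffles the already-quantized fp8 tensor and its scales, and only repairs the scale layout afterwards with the new per_group_transpose. A minimal sketch of the row-count arithmetic, with toy stand-ins (fake_quant and the sizes below are hypothetical, not the sglang kernels):

import torch

m, k, topk, group = 4, 256, 2, 128

def fake_quant(x):
    # Toy stand-in for per-token-group fp8 quantization; it only counts rows.
    fake_quant.rows += x.shape[0]
    scale = x.view(x.shape[0], -1, group).abs().amax(dim=2).clamp(min=1e-6)
    return x, scale  # the real kernel returns an fp8 tensor plus group scales

fake_quant.rows = 0
a = torch.randn(m, k)
a_map = torch.randint(0, m, (m * topk,))  # expert-sorted row -> source token

# Old SM90 path: shuffle bf16 rows first, then quantize m * topk rows.
rep_a = a[a_map]
fake_quant(rep_a)
rows_old = fake_quant.rows

# New path: quantize the m rows once, then shuffle the quantized result.
fake_quant.rows = 0
a_q, a1_scale = fake_quant(a)
rep_a_q, rep_a1_scales = a_q[a_map], a1_scale[a_map]
rows_new = fake_quant.rows

print(rows_old, rows_new)  # 8 vs 4: quantization work drops by a factor of topk

The leftover SM90-specific cost is putting the group scales back into the layout the grouped GEMM expects, which the Triton kernel added in the second file performs in a single pass.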
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -1356,3 +1356,62 @@ def per_token_group_quant_fp8_hopper_moe_mn_major(
         expert_tokens_alignment,
     )
     return a_q, sfa
+
+
+@triton.jit
+def _per_group_transpose(
+    data_ptr: torch.Tensor,
+    trans_data_ptr: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    k: int,
+    M_ALIGNMENT: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    expert_id = tl.program_id(0)
+    m_id = tl.program_id(1)
+    k_id = tl.program_id(2)
+
+    curr_expert_offset = tl.load(expert_offsets + expert_id)
+    next_expert_offset = tl.load(expert_offsets + expert_id + 1)
+    num_tokens_of_expert = next_expert_offset - curr_expert_offset
+    tl.multiple_of(curr_expert_offset, M_ALIGNMENT)
+    tl.multiple_of(next_expert_offset, M_ALIGNMENT)
+
+    data_start_ptr = data_ptr + curr_expert_offset * k
+    trans_data_start_ptr = trans_data_ptr + curr_expert_offset * k
+
+    k_coord = k_id * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    k_mask = k_coord < k
+    for start_m in tl.range(0, num_tokens_of_expert, BLOCK_SIZE_M * tl.num_programs(1)):
+        m_coord = start_m + m_id * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+        m_mask = m_coord < num_tokens_of_expert
+        off = m_coord[:, None] * k + k_coord[None, :]
+        trans_off = m_coord[:, None] + k_coord[None, :] * num_tokens_of_expert
+        mask = m_mask[:, None] & k_mask[None, :]
+        data = tl.load(data_start_ptr + off, mask=mask)
+        tl.store(trans_data_start_ptr + trans_off, data, mask=mask)
+
+
+def per_group_transpose(
+    a: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    M_ALIGNMENT: int = 1,
+) -> torch.Tensor:
+    assert a.dim() == 2
+    assert a.is_contiguous(), "`a` is not contiguous"
+
+    m, k = a.size()
+    trans_a = torch.empty_like(a)
+    num_experts = expert_offsets.size(0) - 1
+    grid = lambda META: (
+        num_experts,
+        triton.cdiv((m + num_experts - 1) // num_experts, META["BLOCK_SIZE_M"]),
+        triton.cdiv(k, META["BLOCK_SIZE_K"]),
+    )
+    _per_group_transpose[grid](
+        a, trans_a, expert_offsets, k, M_ALIGNMENT, BLOCK_SIZE_M=16, BLOCK_SIZE_K=8
+    )
+    return trans_a
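For intuition, the kernel transposes each expert's contiguous block of rows independently: element (m, k) of an expert's [tokens, k] slice moves to flat offset m + k * tokens within that same slice, so the group scales become column-major per expert while the overall [m, k] buffer shape is unchanged. A plain-PyTorch reference of the same permutation (the toy sizes and offsets below are assumptions, not taken from the commit):

import torch

def per_group_transpose_ref(a: torch.Tensor, expert_offsets: torch.Tensor) -> torch.Tensor:
    # Reference for _per_group_transpose: reorder each expert's row block
    # column-major while keeping the overall [m, k] shape.
    out = torch.empty_like(a)
    k = a.size(1)
    for e in range(expert_offsets.numel() - 1):
        lo, hi = expert_offsets[e].item(), expert_offsets[e + 1].item()
        # [tokens_e, k] row-major block -> same bytes in column-major order
        out[lo:hi] = a[lo:hi].t().contiguous().view(hi - lo, k)
    return out

a = torch.arange(12, dtype=torch.float32).view(6, 2)
offsets = torch.tensor([0, 2, 6])  # two experts holding 2 and 4 tokens
print(per_group_transpose_ref(a, offsets))

On CUDA, per_group_transpose(a, offsets) should produce the same reordering: grid axis 0 walks experts while axes 1 and 2 tile the token and k dimensions, so each expert's block is covered by BLOCK_SIZE_M x BLOCK_SIZE_K tiles regardless of how many tokens it received.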