Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
e67276ec
Unverified
Commit
e67276ec
authored
Aug 04, 2025
by
tql.99
Committed by
GitHub
Aug 03, 2025
Browse files
feat: support cutlass_moe_fp8 kernel for fusedmoe in sm90 (#8678)
parent
0242bb9c
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
32 additions
and
9 deletions
+32
-9
python/sglang/srt/layers/moe/cutlass_moe.py
python/sglang/srt/layers/moe/cutlass_moe.py
+20
-6
python/sglang/srt/layers/quantization/fp8.py
python/sglang/srt/layers/quantization/fp8.py
+3
-3
python/sglang/srt/layers/utils.py
python/sglang/srt/layers/utils.py
+9
-0
No files found.
python/sglang/srt/layers/moe/cutlass_moe.py
View file @
e67276ec
...
...
@@ -9,6 +9,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
import
torch
from
sglang.srt.layers.moe.cutlass_moe_params
import
CutlassMoEParams
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.utils
import
is_cuda
_is_cuda
=
is_cuda
()
...
...
@@ -123,6 +124,7 @@ def cutlass_fused_experts_fp8(
if
is_cuda
:
from
sglang.srt.layers.quantization.fp8_kernel
import
(
per_token_group_quant_fp8_hopper_moe_mn_major
,
sglang_per_token_group_quant_fp8
,
)
...
...
@@ -133,9 +135,7 @@ def cutlass_fused_experts_fp8(
n
=
w2_q
.
size
(
1
)
topk
=
topk_ids
.
size
(
1
)
a_q
,
a1_scale
=
sglang_per_token_group_quant_fp8
(
a
,
128
)
device
=
a_q
.
device
device
=
a
.
device
a_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
c_map
=
torch
.
empty
((
topk_ids
.
numel
()),
dtype
=
torch
.
int32
,
device
=
device
)
...
...
@@ -152,8 +152,16 @@ def cutlass_fused_experts_fp8(
k
,
)
rep_a_q
=
shuffle_rows
(
a_q
,
a_map
,
(
m
*
topk
,
k
))
rep_a1_scales
=
shuffle_rows
(
a1_scale
,
a_map
,
(
m
*
topk
,
int
(
k
/
128
)))
if
is_sm100_supported
():
a_q
,
a1_scale
=
sglang_per_token_group_quant_fp8
(
a
,
128
)
rep_a_q
=
shuffle_rows
(
a_q
,
a_map
,
(
m
*
topk
,
k
))
rep_a1_scales
=
shuffle_rows
(
a1_scale
,
a_map
,
(
m
*
topk
,
int
(
k
/
128
)))
else
:
rep_a
=
shuffle_rows
(
a
,
a_map
,
(
m
*
topk
,
k
))
rep_a_q
,
rep_a1_scales
=
per_token_group_quant_fp8_hopper_moe_mn_major
(
rep_a
,
expert_offsets
,
problem_sizes1
,
128
)
w1_scale
=
w1_scale
.
contiguous
()
c1
=
torch
.
empty
((
m
*
topk
,
n
*
2
),
device
=
device
,
dtype
=
out_dtype
)
c2
=
torch
.
empty
((
m
*
topk
,
k
),
device
=
device
,
dtype
=
out_dtype
)
...
...
@@ -185,7 +193,13 @@ def cutlass_fused_experts_fp8(
intermediate
=
torch
.
empty
((
m
*
topk
,
n
),
device
=
device
,
dtype
=
out_dtype
)
silu_and_mul
(
c1
,
intermediate
)
intemediate_q
,
a2_scale
=
sglang_per_token_group_quant_fp8
(
intermediate
,
128
)
if
is_sm100_supported
():
intemediate_q
,
a2_scale
=
sglang_per_token_group_quant_fp8
(
intermediate
,
128
)
else
:
intemediate_q
,
a2_scale
=
per_token_group_quant_fp8_hopper_moe_mn_major
(
intermediate
,
expert_offsets
,
problem_sizes2
,
128
)
w2_scale
=
w2_scale
.
contiguous
()
fp8_blockwise_scaled_grouped_mm
(
c2
,
...
...
python/sglang/srt/layers/quantization/fp8.py
View file @
e67276ec
...
...
@@ -63,7 +63,7 @@ from sglang.srt.layers.quantization.utils import (
per_tensor_dequantize
,
requantize_with_max_scale
,
)
from
sglang.srt.layers.utils
import
is_sm100_supported
from
sglang.srt.layers.utils
import
is_sm90_supported
,
is_sm100_supported
from
sglang.srt.utils
import
(
cpu_has_amx_support
,
get_bool_env_var
,
...
...
@@ -619,7 +619,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
if
(
get_bool_env_var
(
"SGLANG_CUTLASS_MOE"
)
and
self
.
cutlass_fp8_supported
and
is_sm100_supported
()
and
(
is_sm100_supported
()
or
is_sm90_supported
())
):
self
.
ab_strides1
=
torch
.
full
(
(
num_experts
,),
...
...
@@ -1034,7 +1034,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
get_bool_env_var
(
"SGLANG_CUTLASS_MOE"
)
and
self
.
cutlass_fp8_supported
and
self
.
block_quant
and
is_sm100_supported
()
and
(
is_sm100_supported
()
or
is_sm90_supported
())
):
from
sglang.srt.layers.moe.cutlass_moe
import
cutlass_fused_experts_fp8
...
...
python/sglang/srt/layers/utils.py
View file @
e67276ec
import
logging
import
re
from
functools
import
lru_cache
import
torch
...
...
@@ -35,7 +36,15 @@ class PPMissingLayer(torch.nn.Identity):
return
(
input
,)
if
self
.
return_tuple
else
input
@lru_cache(maxsize=None)
def is_sm100_supported(device=None) -> bool:
    """Return whether ``device`` is an SM100 GPU usable with CUDA >= 12.8.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``;
            ``None`` lets PyTorch pick the current CUDA device.

    Returns:
        ``True`` only when the device's compute-capability major version is
        10 and the CUDA version PyTorch was built against is at least 12.8.
        ``False`` when PyTorch has no CUDA build, no CUDA device is
        available, or the version string cannot be parsed.
    """
    # The cache is unbounded because the key space is just the handful of
    # distinct ``device`` values callers pass; a single slot would be evicted
    # every time two different devices are queried alternately.
    cuda_version = torch.version.cuda
    # ``torch.version.cuda`` is None on builds without CUDA, so check it
    # before comparing and before touching any CUDA API.
    if not cuda_version or not torch.cuda.is_available():
        return False
    try:
        # Compare numerically: as plain strings "12.10" sorts below "12.8".
        toolkit = tuple(int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False
    major = torch.cuda.get_device_capability(device)[0]
    return major == 10 and toolkit >= (12, 8)
@lru_cache(maxsize=None)
def is_sm90_supported(device=None) -> bool:
    """Return whether ``device`` is an SM90 GPU usable with CUDA >= 12.3.

    Args:
        device: Forwarded unchanged to ``torch.cuda.get_device_capability``;
            ``None`` lets PyTorch pick the current CUDA device.

    Returns:
        ``True`` only when the device's compute-capability major version is
        9 and the CUDA version PyTorch was built against is at least 12.3.
        ``False`` when PyTorch has no CUDA build, no CUDA device is
        available, or the version string cannot be parsed.
    """
    # The cache is unbounded because the key space is just the handful of
    # distinct ``device`` values callers pass; a single slot would be evicted
    # every time two different devices are queried alternately.
    cuda_version = torch.version.cuda
    # ``torch.version.cuda`` is None on builds without CUDA; such a build can
    # still report a capability major of 9, so reject it before comparing.
    if not cuda_version or not torch.cuda.is_available():
        return False
    try:
        # Compare numerically: as plain strings "12.10" sorts below "12.3".
        toolkit = tuple(int(part) for part in cuda_version.split(".")[:2])
    except ValueError:
        return False
    major = torch.cuda.get_device_capability(device)[0]
    return major == 9 and toolkit >= (12, 3)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment