sglang · Commits · fda47926

Update CUTLASS 4.2 & Enable K-Major Scale Factor for SM90 FP8 Blockwise Group GEMM (#9559)

Unverified commit fda47926, authored Aug 25, 2025 by Qi Yuhang, committed by GitHub Aug 24, 2025.
Parent: a0b22f2f

Showing 5 changed files with 103 additions and 133 deletions (+103 -133).
python/sglang/srt/layers/moe/cutlass_moe.py        +0   -7
python/sglang/test/test_cutlass_moe.py             +33  -28
sgl-kernel/CMakeLists.txt                          +1   -1
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu    +49  -40
sgl-kernel/tests/test_fp8_blockwise_moe.py         +20  -57
python/sglang/srt/layers/moe/cutlass_moe.py

@@ -157,10 +157,6 @@ def cutlass_fused_experts_fp8(
     rep_a_q = shuffle_rows(a_q, a_map, (m * topk, k))
     rep_a1_scales = shuffle_rows(a1_scale, a_map, (m * topk, int(k / 128)))
-    if not is_sm100_supported():
-        rep_a1_scales = per_group_transpose(rep_a1_scales, expert_offsets)
-        w1_scale = w1_scale.contiguous()
-
     c1 = torch.empty((m * topk, n * 2), device=device, dtype=out_dtype)
     c2 = torch.empty((m * topk, k), device=device, dtype=out_dtype)

@@ -192,9 +188,6 @@ def cutlass_fused_experts_fp8(
     silu_and_mul(c1, intermediate)
     intemediate_q, a2_scale = sglang_per_token_group_quant_fp8(intermediate, 128)
-    if not is_sm100_supported():
-        a2_scale = per_group_transpose(a2_scale, expert_offsets)
-        w2_scale = w2_scale.contiguous()
     fp8_blockwise_scaled_grouped_mm(
         c2,
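With CUTLASS 4.2 the SM90 grouped GEMM consumes K-major blockwise scale factors directly, so the pre-SM100 branch that transposed the per-token-group scales (and forced w1_scale / w2_scale contiguous) can be dropped. The following is a minimal PyTorch sketch of the two layouts, assuming 128-wide quantization groups along K; it is illustrative only and is not the repository's per_group_transpose implementation.

import torch

M, K, GROUP = 4, 512, 128
k_groups = K // GROUP

# Per-token-group scales as produced by 128-group quantization: one fp32
# scale per (token, K-group). Row-major (M, k_groups) is the K-major layout,
# where the scales of one token's K-groups are contiguous in memory.
scales_k_major = torch.rand(M, k_groups, dtype=torch.float32)

# The old pre-SM100 path converted this into an MN-major buffer, where the
# scales of all tokens for a given K-group are contiguous instead.
scales_mn_major = scales_k_major.t().contiguous()  # shape (k_groups, M)

assert scales_k_major.stride() == (k_groups, 1)
assert scales_mn_major.stride() == (M, 1)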
python/sglang/test/test_cutlass_moe.py

@@ -8,6 +8,15 @@ from transformers import AutoConfig
 from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.moe_runner.base import MoeRunnerConfig
+
+
+# Copy from: https://github.com/deepseek-ai/DeepGEMM/blob/main/deep_gemm/utils.py
+def calc_diff(x, y):
+    x, y = x.double(), y.double()
+    denominator = (x * x + y * y).sum()
+    sim = 2 * (x * y).sum() / denominator
+    return 1 - sim


 def get_model_config(tp_size: int):

@@ -69,16 +78,11 @@ def run_test(tp_size, batch_size, model_config, check=False):
     # --- Input Data ---
     # Use bf16/fp16 for input activation based on model config
-    x = torch.randn((batch_size, H), device="cuda", dtype=dtype) * 0.0001
+    x = torch.randn((batch_size, H), device="cuda", dtype=dtype)

     # --- Weights (Generate in higher precision, then convert to FP8) ---
     # Generate weights suitable for FP8 conversion (e.g., scaled appropriately)
-    w1_hp = (
-        torch.randn((E, I, H), device="cuda", dtype=torch.float32) * 0.00001 + 0.00001
-    )
-    w2_hp = (
-        torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32) * 0.00001
-        + 0.00001
-    )
+    w1_hp = torch.randn((E, I, H), device="cuda", dtype=torch.float32)
+    w2_hp = torch.randn((E, H, I // 2), device="cuda", dtype=torch.float32)
     w1 = to_fp8(w1_hp)
     w2 = to_fp8(w2_hp)

@@ -149,13 +153,13 @@ def run_test(tp_size, batch_size, model_config, check=False):
     )
     # Note: Triton expects non-transposed weights
+    moe_config = MoeRunnerConfig(inplace=False)
     triton_lambda = lambda: fused_experts(
         x,
         w1,
         w2,
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,
-        activation="silu",  # Assuming SiLU activation common in MoEs
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,

@@ -221,32 +225,19 @@ def run_test(tp_size, batch_size, model_config, check=False):
         w1,  # Original shape
         w2,  # Original shape
         (topk_weights, topk_ids, "dummy"),
-        inplace=False,  # Important: Use False to get output tensor
-        activation="silu",
+        moe_config,
         use_fp8_w8a8=True,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
         block_shape=block_shape,
     )

-    # Ensure outputs are same dtype for comparison
-    y_cutlass = y_cutlass.to(dtype)
-    y_triton = y_triton.to(dtype)
-    abs_error = torch.abs(y_cutlass - y_triton)
-    rel_error = abs_error / torch.clamp(torch.abs(y_triton), min=1e-2)
-    max_abs_err = abs_error.max().item()
-    max_rel_err = rel_error.max().item()
-    print("y_cutlass:", y_cutlass[:, :10])
-    print("y_triton:", y_triton[:, :10])
-    print(f"Max absolute error: {max_abs_err:.6f}")
-    print(f"Max relative error: {max_rel_err:.6f}")
+    diff = calc_diff(y_cutlass, y_triton)
+    print(f"Diff: {diff:.6f}")

     # Tolerance might need adjustment based on FP8 specifics and kernel differences
     # FP8 comparisons often require higher tolerance than FP16/BF16
-    assert max_rel_err < 5e-1, f"Relative error too high! {max_rel_err}"
+    assert diff < 1e-4, f"Diff too high! {diff}"
     print("Correctness check passed.")

@@ -264,7 +255,21 @@ if __name__ == "__main__":
         "--batch-sizes",
         type=int,
         nargs="+",
-        default=[1, 4, 8, 16, 32, 64, 128, 256, 512, 1024],  # Adjusted default
+        default=[
+            1,
+            4,
+            8,
+            16,
+            32,
+            64,
+            128,
+            256,
+            512,
+            1024,
+            2048,
+            4096,
+            8192,
+        ],  # Adjusted default
         help="List of batch sizes to test",
     )
     parser.add_argument("--check", action="store_true", help="Enable check mode")
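For reference, calc_diff is the DeepGEMM similarity metric 1 - 2*sum(xy) / (sum(x^2) + sum(y^2)): it is about 0 for identical tensors and approaches 1 for uncorrelated ones, which is why the correctness check becomes "diff < 1e-4" instead of the earlier max-relative-error bound. A small standalone sketch of how it behaves (the definition is copied in so the snippet runs on its own):

import torch

def calc_diff(x, y):
    # Same definition as in the test (copied from DeepGEMM's utils).
    x, y = x.double(), y.double()
    denominator = (x * x + y * y).sum()
    sim = 2 * (x * y).sum() / denominator
    return 1 - sim

a = torch.randn(1024, 1024)
print(calc_diff(a, a))                                # ~0 for identical tensors
print(calc_diff(a, a + 1e-3 * torch.randn_like(a)))   # small, well below 1e-4
print(calc_diff(a, torch.randn_like(a)))              # close to 1 for unrelated tensors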
sgl-kernel/CMakeLists.txt

@@ -45,7 +45,7 @@ include(FetchContent)
 FetchContent_Declare(
   repo-cutlass
   GIT_REPOSITORY https://github.com/NVIDIA/cutlass
-  GIT_TAG        664c4f7b3ed1959414905025728eef5568209479
+  GIT_TAG        a49a78ffefc86a87160dfe0ccc3a3a2d1622c918
   GIT_SHALLOW    OFF
 )
 FetchContent_Populate(repo-cutlass)
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu

@@ -457,39 +457,40 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace) {
-  struct MmaConfig0 {
+  struct MmaConfigSmallM {
+    // Swap A/B
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_64, _128, _128>;
+    using MmaTileShape = Shape<_128, _32, _128>;
     using ClusterShape = Shape<_2, _1, _1>;
+    // TODO: Check Pingpong or Cooperative
     using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
     using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<128, 1, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };

-  struct MmaConfig1 {
+  struct MmaConfigH20LargeK {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_128, _128, _128>;
-    using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using MmaTileShape = Shape<_64, _128, _128>;
+    using ClusterShape = Shape<_2, _1, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };

-  // [NOTE] default for H20
-  struct MmaConfigH20_default {
+  struct MmaConfigHx00AndH20SmallK {
     using ElementA = cutlass::float_e4m3_t;
-    using MmaTileShape = Shape<_64, _128, _128>;
+    using MmaTileShape = Shape<_128, _128, _128>;
     using ClusterShape = Shape<_1, _2, _1>;
-    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
-    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;
-    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128, cute::GMMA::Major::K, cute::GMMA::Major::K>;
     using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
@@ -497,33 +498,34 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
+  torch::Tensor output_t = output.t();
+  torch::Tensor a_t = a.t();
+  torch::Tensor b_t = b.transpose(1, 2);
+  torch::Tensor scales_a_t = scales_a.t();
+  torch::Tensor scales_b_t = scales_b.transpose(1, 2);
-  const std::string H20_device_type_str = "NVIDIA H20";
-  bool is_h20_device = isDeviceType(H20_device_type_str);
+  const std::string H20_device_type_str("NVIDIA H20");
+  bool is_h20_device = std::string(at::cuda::getCurrentDeviceProperties()->name) == H20_device_type_str;
-  if (is_h20_device) {
-    using execute_gemm_config = MmaConfigH20_default;
-    run_get_group_gemm_starts<execute_gemm_config::LayoutSFA, execute_gemm_config::LayoutSFB, execute_gemm_config::ScaleConfig>(
+  if (a.size(0) <= 2048) {
+    run_get_group_gemm_starts<MmaConfigSmallM::LayoutSFA, MmaConfigSmallM::LayoutSFB, MmaConfigSmallM::ScaleConfig>(
         expert_offsets,
         a_ptrs,
         b_ptrs,
         out_ptrs,
         a_scales_ptrs,
         b_scales_ptrs,
-        a,
-        b,
-        output,
-        scales_a,
-        scales_b,
+        b_t,
+        a_t,
+        output_t,
+        scales_b_t,
+        scales_a_t,
         layout_sfa,
         layout_sfb,
         problem_sizes,
-        problem_sizes_transpose);
-    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, execute_gemm_config, cutlass::layout::RowMajor>(
+        problem_sizes_transpose,
+        true);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigSmallM, cutlass::layout::ColumnMajor>(
         out_ptrs,
         a_ptrs,
         b_ptrs,

@@ -534,13 +536,17 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
         stride_c,
         layout_sfa,
         layout_sfb,
-        problem_sizes,
+        problem_sizes_transpose,
         expert_offsets,
         workspace);
+    output = output_t.t();
   } else {
-    if (at::cuda::getCurrentDeviceProperties()->multiProcessorCount == 78 && a.size(1) > 128) {
+    if (is_h20_device && a.size(1) > 128) {
       // For H20 with K > 128, use Pingpong Schedule
-      run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
+      run_get_group_gemm_starts<MmaConfigH20LargeK::LayoutSFA, MmaConfigH20LargeK::LayoutSFB, MmaConfigH20LargeK::ScaleConfig>(
          expert_offsets,
          a_ptrs,
          b_ptrs,

@@ -556,7 +562,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
-      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
+      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigH20LargeK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,

@@ -572,7 +578,10 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          workspace);
     } else {
       // For H20 with K <= 128, and H100 & H200 & H800, use Cooperative Schedule
-      run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
+      run_get_group_gemm_starts<
+          MmaConfigHx00AndH20SmallK::LayoutSFA,
+          MmaConfigHx00AndH20SmallK::LayoutSFB,
+          MmaConfigHx00AndH20SmallK::ScaleConfig>(
          expert_offsets,
          a_ptrs,
          b_ptrs,

@@ -588,7 +597,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
          layout_sfb,
          problem_sizes,
          problem_sizes_transpose);
-      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
+      launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfigHx00AndH20SmallK, cutlass::layout::RowMajor>(
          out_ptrs,
          a_ptrs,
          b_ptrs,
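The new MmaConfigSmallM branch appears to run the grouped GEMM with A and B swapped and a column-major (transposed) output, which is why the dispatch builds a_t, b_t, output_t and the transposed scale tensors. The underlying trick is just the identity (A @ B^T)^T = B @ A^T, so a small-M problem can be executed with the operand roles exchanged. A hedged PyTorch illustration of that identity (illustrative only, not the CUDA dispatch itself; the sizes are made up):

import torch

m, n, k = 8, 4096, 512          # small M, as in the a.size(0) <= 2048 branch
a = torch.randn(m, k)           # activations, row-major (M, K)
b = torch.randn(n, k)           # weights, row-major (N, K)

direct = a @ b.t()              # the usual (M, N) GEMM

# Swap A/B: compute the (N, M) product instead and hand the caller a
# transposed view of it, i.e. (A @ B^T)^T == B @ A^T.
out_t = b @ a.t()               # (N, M)
swapped = out_t.t()             # view with shape (M, N)

torch.testing.assert_close(direct, swapped)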
sgl-kernel/tests/test_fp8_blockwise_moe.py

@@ -5,10 +5,6 @@ import pytest
 import torch
 from sgl_kernel import fp8_blockwise_scaled_grouped_mm
-from sglang.srt.layers.quantization.fp8_kernel import (
-    per_token_group_quant_fp8_hopper_moe_mn_major,
-)
-

 def cdiv(a: int, b: int) -> int:
     return -(a // -b)

@@ -106,24 +102,19 @@ def is_sm90_supported(device=None) -> bool:
     not (is_sm100_supported() or is_sm90_supported()),
     reason="fp8_blockwise_scaled_grouped_mm at sgl-kernel is only supported on sm100 or sm90",
 )
-@pytest.mark.parametrize("num_experts", [8, 16])
+@pytest.mark.parametrize("num_experts", [8, 16, 32, 64, 128])
 @pytest.mark.parametrize("out_dtype", [torch.half, torch.bfloat16])
-@pytest.mark.parametrize("use_custom_kernel", [True, False])
-def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kernel):
-    cc = torch.cuda.get_device_capability(None)[0]
-    if cc == 10 and use_custom_kernel:
-        return
+def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype):
     device = "cuda"
-    alignment = 16
-    n_g = alignment * random.randint(1, 5) * 128
-    k_g = alignment * random.randint(1, 5) * 128
+    alignment = 128
+    n_g = random.randint(1, 64) * 128
+    k_g = random.randint(1, 64) * 128
     expert_offsets = torch.zeros((num_experts + 1), device=device, dtype=torch.int32)
     problem_sizes = torch.zeros((num_experts, 3), device=device, dtype=torch.int32)
     layout_sfa = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
     layout_sfb = torch.zeros((num_experts, 5), device=device, dtype=torch.int32)
-    a_original_tensors = []
     a_tensors = []
     b_tensors = []
     a_scales_tensors = []

@@ -131,7 +122,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
     baseline_tensors = []
     for g in range(num_experts):
-        m_g = alignment * random.randint(1, 64)
+        m_g = random.randint(1, 256)
         expert_offsets[g + 1] = expert_offsets[g] + m_g
         problem_sizes[g][:] = torch.tensor([m_g, n_g, k_g], device=device)

@@ -144,7 +135,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         b_g, b_scale = per_block_cast_to_fp8(
             b
         )  # bg -- (K, N):(N, 1), b_scale() -- (k, n):(n, 1)
-        a_original_tensors.append(a)
         a_tensors.append(a_g)
         b_tensors.append(b_g)
         a_scales_tensors.append(a_scale)

@@ -152,9 +142,6 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         baseline = torch.mm(a, b)
         baseline_tensors.append(baseline)

-    a_original_stack = torch.empty(
-        (expert_offsets[-1], k_g), device=device, dtype=out_dtype
-    )
     a_stack = torch.empty(
         (expert_offsets[-1], k_g), device=device, dtype=torch.float8_e4m3fn
     )

@@ -162,52 +149,28 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
         (num_experts, n_g, k_g), device=device, dtype=torch.float8_e4m3fn
     )
     a_scale_stack = torch.empty(
-        (expert_offsets[-1] * (k_g // 128)), device=device, dtype=torch.float32
+        (expert_offsets[-1], (k_g // 128)), device=device, dtype=torch.float32
     )
     b_scale_stack = torch.empty(
-        (num_experts, k_g // 128, n_g // 128), device=device, dtype=torch.float32
+        (num_experts, n_g // 128, k_g // 128), device=device, dtype=torch.float32
     )

     for g in range(num_experts):
         # Matrix A is Row-Major.
-        a_original_stack[expert_offsets[g] : expert_offsets[g + 1]] = (
-            a_original_tensors[g]
-        )
-        a_stack[expert_offsets[g] : expert_offsets[g + 1]] = a_tensors[g]
-        # a_stack[expert_offsets[g] : expert_offsets[g + 1]] -- (M, K):(K, 1)
+        a_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_tensors[g]
+        # a_stack[expert_offsets[g] : expert_offsets[g + 1], :] -- (M, K):(K, 1)
         b_stack[g] = b_tensors[g].t()  # b_stack[g] -- (N, K):(K, 1)
-        if cc == 9:
-            # For SM90, we need MN-Major scale factor
-            # a_scales_tensors[g].t().contiguous() -- (k, M):(M, 1)
-            a_scale_stack[
-                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
-            ] = (a_scales_tensors[g].t().contiguous().view(-1))
-            b_scale_stack[g] = b_scales_tensors[g]
-            # b_scale_stack[g] -- (k, n):(n, 1)
-        elif cc == 10:
-            # For SM100, we need K-Major scale factor
-            # a_scales_tensors[g] -- (M, k):(k, 1)
-            a_scale_stack[
-                expert_offsets[g] * (k_g // 128) : expert_offsets[g + 1] * (k_g // 128)
-            ] = a_scales_tensors[g].view(-1)
-            b_scale_stack[g] = b_scales_tensors[g]
-            # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later
-
-    a_scale_stack = a_scale_stack.view(expert_offsets[-1], k_g // 128)
+        # We need K-Major scale factor
+        # a_scales_tensors[g] -- (M, k):(k, 1)
+        a_scale_stack[expert_offsets[g] : expert_offsets[g + 1], :] = a_scales_tensors[g]
+        b_scale_stack[g] = b_scales_tensors[g].t()
+        # b_scale_stack[g] -- (k, n):(n, 1), we need transpose & contiguous later

     b_stack = b_stack.transpose(1, 2)  # Transpose Matrix B to Column-Major.
-    if cc == 10:
-        b_scale_stack = b_scale_stack.transpose(1, 2)
-
-    if use_custom_kernel:
-        # Replace a_stack, a_scale_stack with custom kernel output
-        a_stack, a_scale_stack = per_token_group_quant_fp8_hopper_moe_mn_major(
-            a_original_stack,
-            expert_offsets[:-1],
-            problem_sizes,
-            128,
-            expert_tokens_alignment=alignment,
-        )
+    b_scale_stack = b_scale_stack.transpose(1, 2).contiguous()

     c_out = torch.empty((expert_offsets[-1], n_g), device=device, dtype=out_dtype)
     a_strides = torch.full(

@@ -250,7 +213,7 @@ def test_fp8_blockwise_scaled_grouped_mm(num_experts, out_dtype, use_custom_kern
     diff = calc_diff(actual, baseline)
     assert diff < 0.001
     print(
-        f"cc={cc} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
+        f"m_g={baseline.shape[0]} n_g={n_g} k_g={k_g} num_experts={num_experts}, out_dtype={out_dtype}, diff={diff:.5f}: OK"
     )
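The rewritten test keeps every expert's activation scales K-major and simply concatenates experts along the token axis, so no per-architecture (cc == 9 vs cc == 10) branch is needed. A small standalone sketch of that bookkeeping, using the same cdiv convention as the test; the expert and group sizes below are made up for illustration:

import torch

def cdiv(a: int, b: int) -> int:
    return -(a // -b)

GROUP = 128
num_experts, k_g = 4, 512
tokens_per_expert = [3, 7, 5, 2]

offsets = [0]
for m_g in tokens_per_expert:
    offsets.append(offsets[-1] + m_g)

# One K-major scale row per token: shape (total_tokens, k_g // GROUP).
a_scale_stack = torch.empty(offsets[-1], cdiv(k_g, GROUP), dtype=torch.float32)
for g, m_g in enumerate(tokens_per_expert):
    # Per-expert scales land in contiguous rows, no per-group transpose needed.
    a_scale_stack[offsets[g] : offsets[g + 1]] = torch.rand(m_g, cdiv(k_g, GROUP))

assert a_scale_stack.shape == (sum(tokens_per_expert), k_g // GROUP)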