Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5fd311d3
Unverified
Commit
5fd311d3
authored
Aug 22, 2025
by
kousakawang
Committed by
GitHub
Aug 21, 2025
Browse files
[code clean] add H20 cutlass groupGemm default config (#9333)
Co-authored-by:
wanghanpei
<
wanghanpei@bytedance.com
>
parent
53e2cd46
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
15 additions
and
35 deletions
+15
-35
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
+15
-35
No files found.
sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu
View file @
5fd311d3
...
...
@@ -437,34 +437,6 @@ void sm100_fp8_blockwise_group_mm_dispatch_shape(
}
}
// Identifier builders for the generated SM90 FP8 config structs. Each macro
// token-pastes its six arguments, separated by underscores, onto a fixed prefix:
//   JOIN_STRUCT_PP_NAME(64, 128, 128, 1, 2, 1) -> sm90_fp8_pp_config_64_128_128_1_2_1
//   JOIN_STRUCT_CO_NAME(64, 128, 128, 1, 2, 1) -> sm90_fp8_co_config_64_128_128_1_2_1
// Operands of ## are pasted verbatim by these macros, so every argument must be
// a token that can legally continue an identifier (e.g. an integer literal).
#define JOIN_STRUCT_PP_NAME(tile_m, tile_n, tile_k, cl0, cl1, cl2) sm90_fp8_pp_config##_##tile_m##_##tile_n##_##tile_k##_##cl0##_##cl1##_##cl2
#define JOIN_STRUCT_CO_NAME(tile_m, tile_n, tile_k, cl0, cl1, cl2) sm90_fp8_co_config##_##tile_m##_##tile_n##_##tile_k##_##cl0##_##cl1##_##cl2
// Emits a config struct whose name comes from JOIN_STRUCT_PP_NAME, i.e.
// sm90_fp8_pp_config_<TILE_M>_<TILE_N>_<TILE_K>_<CLUSTER_0>_<CLUSTER_1>_<CLUSTER_2>.
// The struct is a bundle of type aliases:
//   - ElementA          : cutlass::float_e4m3_t
//   - MmaTileShape      : Shape of the first three arguments (as cute::Int<>)
//   - ClusterShape      : Shape of the last three arguments (as cute::Int<>)
//   - KernelSchedule /
//     EpilogueSchedule  : the pointer-array TMA warp-specialized *pingpong* schedules
//   - ScaleConfig       : cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>
//   - LayoutSFA/SFB     : the layout types deduced from ScaleConfig
// The expansion already ends in "};", so invoke the macro without a trailing semicolon.
#define GENERATE_SM90_FP8_PP_CONFIG(TILE_M, TILE_N, TILE_K, CLUSTER_0, CLUSTER_1, CLUSTER_2)            \
  struct JOIN_STRUCT_PP_NAME(TILE_M, TILE_N, TILE_K, CLUSTER_0, CLUSTER_1, CLUSTER_2) {                 \
    using ElementA = cutlass::float_e4m3_t;                                                             \
    using MmaTileShape = Shape<cute::Int<TILE_M>, cute::Int<TILE_N>, cute::Int<TILE_K>>;                \
    using ClusterShape = Shape<cute::Int<CLUSTER_0>, cute::Int<CLUSTER_1>, cute::Int<CLUSTER_2>>;       \
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;  \
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;                     \
    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                         \
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                        \
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                        \
  };
// Emits a config struct whose name comes from JOIN_STRUCT_CO_NAME, i.e.
// sm90_fp8_co_config_<TILE_M>_<TILE_N>_<TILE_K>_<CLUSTER_0>_<CLUSTER_1>_<CLUSTER_2>.
// The struct is a bundle of type aliases:
//   - ElementA          : cutlass::float_e4m3_t
//   - MmaTileShape      : Shape of the first three arguments (as cute::Int<>)
//   - ClusterShape      : Shape of the last three arguments (as cute::Int<>)
//   - KernelSchedule /
//     EpilogueSchedule  : the pointer-array TMA warp-specialized *cooperative* schedules
//   - ScaleConfig       : cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>
//   - LayoutSFA/SFB     : the layout types deduced from ScaleConfig
// The expansion already ends in "};", so invoke the macro without a trailing semicolon.
#define GENERATE_SM90_FP8_CO_CONFIG(TILE_M, TILE_N, TILE_K, CLUSTER_0, CLUSTER_1, CLUSTER_2)              \
  struct JOIN_STRUCT_CO_NAME(TILE_M, TILE_N, TILE_K, CLUSTER_0, CLUSTER_1, CLUSTER_2) {                   \
    using ElementA = cutlass::float_e4m3_t;                                                               \
    using MmaTileShape = Shape<cute::Int<TILE_M>, cute::Int<TILE_N>, cute::Int<TILE_K>>;                  \
    using ClusterShape = Shape<cute::Int<CLUSTER_0>, cute::Int<CLUSTER_1>, cute::Int<CLUSTER_2>>;         \
    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum; \
    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;                    \
    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;                           \
    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());                                          \
    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());                                          \
  };
template
<
typename
OutType
>
void
sm90_fp8_blockwise_group_mm_dispatch_shape
(
torch
::
Tensor
&
output
,
...
...
@@ -509,20 +481,28 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
};
// [NOTE] Tuned for H20
GENERATE_SM90_FP8_PP_CONFIG
(
64
,
128
,
128
,
1
,
2
,
1
)
// [NOTE] Default config for H20.
// Pure bundle of type aliases (no data members) describing one grouped-GEMM
// configuration: a 64x128x128 MMA tile, a 1x2x1 cluster, and the pointer-array
// TMA warp-specialized pingpong kernel/epilogue schedules.
struct MmaConfigH20_default {
  // Element type for operand A.
  using ElementA = cutlass::float_e4m3_t;

  // Tile and cluster shapes, expressed with CuTe static integers.
  using MmaTileShape = Shape<_64, _128, _128>;
  using ClusterShape = Shape<_1, _2, _1>;

  // Mainloop and epilogue schedules (pingpong variants).
  using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8BlockScaledAccum;
  using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong;

  // Scale-factor configuration; the SFA/SFB layout types are whatever
  // ScaleConfig deduces for them.
  using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
  using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
  using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
};
int
num_experts
=
(
int
)
expert_offsets
.
size
(
0
);
torch
::
TensorOptions
options_int
=
torch
::
TensorOptions
().
dtype
(
torch
::
kInt64
).
device
(
a
.
device
());
torch
::
Tensor
problem_sizes_transpose
=
torch
::
empty
(
num_experts
*
3
,
options_int
);
bool
tuning_H20_kernel
=
getBoolEnv
(
"SGL_TUNE_DEVICE_KERNEL"
);
const
std
::
string
H20_device_type_str
=
"NVIDIA H20"
;
bool
is_h20
=
isDeviceType
(
H20_device_type_str
);
bool
is_h20
_device
=
isDeviceType
(
H20_device_type_str
);
if
(
is_h20
&&
tuning_H20_kernel
)
{
using
execute_gemm_config
=
sm90_fp8_pp_config_64_128_128_1_2_1
;
if
(
is_h20
_device
)
{
using
execute_gemm_config
=
MmaConfigH20_default
;
run_get_group_gemm_starts
<
execute_gemm_config
::
LayoutSFA
,
execute_gemm_config
::
LayoutSFB
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment