sglang — commit 8e9fb43d (unverified)
Optimize Hopper CUTLASS FP8 Blockwise Grouped GEMM Kernel in Small K Scenario (#7782)
Authored Jul 05, 2025 by Qi Yuhang; committed by GitHub Jul 04, 2025
Parent: 83646089

Showing 1 changed file with 86 additions and 38 deletions.

sgl-kernel/csrc/moe/fp8_blockwise_moe_kernel.cu (+86, -38)
@@ -61,7 +61,12 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
   using ArchTag = cutlass::arch::Sm90;
   using OperatorClass = cutlass::arch::OpClassTensorOp;
-  using FusionOperation = cutlass::epilogue::fusion::LinearCombination<ElementD, ElementAccumulator>;
+  static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;
+  using CustomEVTIdentity =  // acc
+      cutlass::epilogue::fusion::Sm90EVT<
+          cutlass::epilogue::fusion::Sm90Compute<cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
+          cutlass::epilogue::fusion::Sm90AccFetch>;
   using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
       ArchTag,
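
For context (a sketch, not part of the commit): CUTLASS's LinearCombination fusion computes D = alpha * acc + beta * C, so the old epilogue carries alpha/beta state and a C operand even when they are unused. The replacement is an epilogue visitor tree whose root is a single Identity compute node fed by the accumulator fetch, i.e. effectively D = acc. A minimal sketch of the same composition follows, assuming the ElementD/ElementAccumulator aliases and the CUTLASS 3.x epilogue-fusion headers that this file already pulls in; the IdentityEpilogue name is illustrative, not from the patch.

// Sketch only; relies on headers and type aliases already present in
// fp8_blockwise_moe_kernel.cu. Names marked "illustrative" are not from the patch.
static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest;

// Leaf node: fetch the accumulator fragment produced by the mainloop.
// Root node: convert it to ElementD via the Identity functor -- no alpha/beta
// scaling and no load of a C source tensor.
using IdentityEpilogue =  // illustrative alias; the patch names it CustomEVTIdentity
    cutlass::epilogue::fusion::Sm90EVT<
        cutlass::epilogue::fusion::Sm90Compute<
            cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, RoundStyle>,
        cutlass::epilogue::fusion::Sm90AccFetch>;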
@@ -78,7 +83,7 @@ void launch_sm90_fp8_blockwise_scaled_group_mm(
       LayoutC*,
       AlignmentC,
       typename ScheduleConfig::EpilogueSchedule,
-      FusionOperation>::CollectiveOp;
+      CustomEVTIdentity>::CollectiveOp;
   using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
       ArchTag,
@@ -452,7 +457,7 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     const torch::Tensor& problem_sizes,
     const torch::Tensor& expert_offsets,
     const torch::Tensor& workspace) {
-  struct MmaConfig {
+  struct MmaConfig0 {
     using ElementA = cutlass::float_e4m3_t;
     using MmaTileShape = Shape<_64, _128, _128>;
     using ClusterShape = Shape<_2, _1, _1>;
@@ -464,40 +469,86 @@ void sm90_fp8_blockwise_group_mm_dispatch_shape(
     using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
   };
+  struct MmaConfig1 {
+    using ElementA = cutlass::float_e4m3_t;
+    using MmaTileShape = Shape<_128, _128, _128>;
+    using ClusterShape = Shape<_1, _2, _1>;
+    using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8BlockScaledAccum;
+    using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative;
+    using ScaleConfig = cutlass::detail::Sm90BlockwiseScaleConfig<1, 128, 128>;
+    using LayoutSFA = decltype(ScaleConfig::deduce_layoutSFA());
+    using LayoutSFB = decltype(ScaleConfig::deduce_layoutSFB());
+  };
   int num_experts = (int)expert_offsets.size(0);
   torch::TensorOptions options_int = torch::TensorOptions().dtype(torch::kInt64).device(a.device());
   torch::Tensor problem_sizes_transpose = torch::empty(num_experts * 3, options_int);
-  run_get_group_gemm_starts<MmaConfig::LayoutSFA, MmaConfig::LayoutSFB, MmaConfig::ScaleConfig>(
-      expert_offsets,
-      a_ptrs,
-      b_ptrs,
-      out_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      a,
-      b,
-      output,
-      scales_a,
-      scales_b,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      problem_sizes_transpose);
-  launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig, cutlass::layout::RowMajor>(
-      out_ptrs,
-      a_ptrs,
-      b_ptrs,
-      a_scales_ptrs,
-      b_scales_ptrs,
-      stride_a,
-      stride_b,
-      stride_c,
-      layout_sfa,
-      layout_sfb,
-      problem_sizes,
-      expert_offsets,
-      workspace);
+  if (a.size(1) > 128) {
+    run_get_group_gemm_starts<MmaConfig0::LayoutSFA, MmaConfig0::LayoutSFB, MmaConfig0::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig0, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  } else {
+    // Small K
+    run_get_group_gemm_starts<MmaConfig1::LayoutSFA, MmaConfig1::LayoutSFB, MmaConfig1::ScaleConfig>(
+        expert_offsets,
+        a_ptrs,
+        b_ptrs,
+        out_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        a,
+        b,
+        output,
+        scales_a,
+        scales_b,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        problem_sizes_transpose);
+    launch_sm90_fp8_blockwise_scaled_group_mm<OutType, MmaConfig1, cutlass::layout::RowMajor>(
+        out_ptrs,
+        a_ptrs,
+        b_ptrs,
+        a_scales_ptrs,
+        b_scales_ptrs,
+        stride_a,
+        stride_b,
+        stride_c,
+        layout_sfa,
+        layout_sfb,
+        problem_sizes,
+        expert_offsets,
+        workspace);
+  }
 }

 /**
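
For context (a sketch, not part of the commit): the new dispatch reduces to one comparison on K = a.size(1). K > 128 keeps the original 64x128x128 MMA tile with a 2x1x1 cluster (MmaConfig0); K <= 128, the "small K" case this commit targets, switches to a 128x128x128 tile with a 1x2x1 cluster (MmaConfig1). A self-contained illustration of that selection logic, with a hypothetical helper name:

// Hypothetical helper mirroring the `if (a.size(1) > 128)` branch above; not from the patch.
#include <cstdint>
#include <cstdio>

enum class Sm90GroupGemmTile { kConfig0_64x128x128, kConfig1_128x128x128 };

inline Sm90GroupGemmTile select_sm90_group_gemm_tile(int64_t k) {
  // K > 128  -> MmaConfig0: MmaTileShape 64x128x128,  ClusterShape 2x1x1
  // K <= 128 -> MmaConfig1: MmaTileShape 128x128x128, ClusterShape 1x2x1 (small-K path)
  return k > 128 ? Sm90GroupGemmTile::kConfig0_64x128x128
                 : Sm90GroupGemmTile::kConfig1_128x128x128;
}

int main() {
  for (int64_t k : {64, 128, 256, 512}) {
    std::printf("K=%lld -> %s\n", static_cast<long long>(k),
                select_sm90_group_gemm_tile(k) == Sm90GroupGemmTile::kConfig1_128x128x128
                    ? "MmaConfig1 (small K)"
                    : "MmaConfig0");
  }
  return 0;
}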
@@ -641,7 +692,7 @@ void fp8_blockwise_scaled_grouped_mm(
 #endif
 #if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) && defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
-  if (sm_version == 90 && a.size(1) > 256) {
+  if (sm_version == 90) {
     if (output.scalar_type() == torch::kBFloat16) {
       sm90_fp8_blockwise_group_mm_dispatch_shape<cutlass::bfloat16_t>(
           output,
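
For context (a sketch, not shown in this diff): before this change the SM90 branch was only taken when K > 256, so smaller K fell through to the not-implemented check below; now every K on SM90 is routed into sm90_fp8_blockwise_group_mm_dispatch_shape, which performs the small-K split internally. The sm_version value tested here is, presumably, derived from the device's compute capability; one common way to obtain such a value with the CUDA runtime API:

// Assumed derivation of sm_version; the diff does not show how it is computed.
#include <cuda_runtime.h>

inline int get_sm_version(int device_id) {
  cudaDeviceProp prop{};
  cudaGetDeviceProperties(&prop, device_id);
  return prop.major * 10 + prop.minor;  // e.g. 90 on Hopper (SM90)
}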
@@ -687,8 +738,5 @@ void fp8_blockwise_scaled_grouped_mm(
   }
 #endif
-  TORCH_CHECK_NOT_IMPLEMENTED(
-      can_implement,
-      "No implemented fp8_blockwise_scaled_mm for current compute capability: ",
-      sm_version);
+  TORCH_CHECK_NOT_IMPLEMENTED(
+      can_implement, "No implemented fp8_blockwise_scaled_mm for current compute capability or K size: ", sm_version, a.size(1));
 }
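
For context (a hypothetical illustration, using plain C++ exceptions rather than the real TORCH_CHECK_NOT_IMPLEMENTED macro): this check presumably fires only when none of the guarded branches above handled the call (can_implement stays false), and with this commit the message also reports K so that unsupported shapes are easier to diagnose:

// Stand-in for the behavior of the final check; names and exception type are
// illustrative only.
#include <cstdint>
#include <stdexcept>
#include <string>

inline void check_fp8_blockwise_supported(bool can_implement, int sm_version, int64_t k) {
  if (!can_implement) {
    throw std::runtime_error(
        "No implemented fp8_blockwise_scaled_mm for current compute capability or K size: " +
        std::to_string(sm_version) + " " + std::to_string(k));
  }
}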