Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
6e92da8f
Unverified
Commit
6e92da8f
authored
Jul 18, 2025
by
Qi Yuhang
Committed by
GitHub
Jul 17, 2025
Browse files
[Fix][Ready]Fix register spilling in cutlass nvfp4 gemm kernel on Blackwell (#8127)
parent
e1020dc5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
24 deletions
+28
-24
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
+28
-24
No files found.
sgl-kernel/csrc/gemm/nvfp4_scaled_mm_kernels.cu
View file @
6e92da8f
...
...
@@ -40,27 +40,21 @@ using namespace cute;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template
<
typename
T
>
struct
KernelTraits
;
template
<
>
struct
KernelTraits
<
float
>
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_128
,
_256
>
;
};
template
<
>
struct
KernelTraits
<
cutlass
::
half_t
>
{
struct
KernelTraits
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
Shape
<
_128
,
_64
>
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized2SmNvf4Sm100
;
};
template
<
>
struct
KernelTraits
<
cutlass
::
bfloat16_t
>
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
struct
KernelTraits
<
float
>
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
int
,
int
,
_1
>
;
using
EpilogueTile
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
;
using
MainloopSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecialized1SmNvf4Sm100
;
};
template
<
typename
T
>
...
...
@@ -90,23 +84,26 @@ struct Fp4GemmSm100 {
// Kernel Perf config
using
MmaTileShape
=
typename
KernelTraits
<
T
>::
MmaTileShape
;
using
ClusterShape
=
typename
KernelTraits
<
T
>::
ClusterShape
;
using
PerSmTileShape_MNK
=
typename
KernelTraits
<
T
>::
PerSmTileShape_MNK
;
using
EpilogueTile
=
typename
KernelTraits
<
T
>::
EpilogueTile
;
using
EpilogueSchedule
=
typename
KernelTraits
<
T
>::
EpilogueSchedule
;
using
MainloopSchedule
=
typename
KernelTraits
<
T
>::
MainloopSchedule
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
PerSm
TileShape
_MNK
,
cutlass
::
arch
::
OpClassTensorOp
,
Mma
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTile
Auto
,
EpilogueTile
,
ElementAccumulator
,
ElementAccumulator
,
ElementC
,
void
,
LayoutCTag
,
AlignmentC
,
ElementD
,
LayoutDTag
,
AlignmentD
,
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
>::
CollectiveOp
;
EpilogueSchedule
,
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
float
,
void
,
float
>>::
CollectiveOp
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
...
...
@@ -122,7 +119,7 @@ struct Fp4GemmSm100 {
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
cutlass
::
gemm
::
collective
::
Kernel
Schedule
Auto
>::
CollectiveOp
;
Mainloop
Schedule
>::
CollectiveOp
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
void
>
;
...
...
@@ -191,6 +188,13 @@ typename T::Gemm::Arguments args_from_options(
stride_D
}};
auto
&
fusion_args
=
arguments
.
epilogue
.
thread
;
fusion_args
.
alpha_ptr
=
static_cast
<
ElementCompute
const
*>
(
alpha
.
data_ptr
());
if
constexpr
(
std
::
is_same_v
<
T
,
float
>
)
{
arguments
.
hw_info
.
cluster_shape
=
dim3
(
1
,
4
,
1
);
arguments
.
hw_info
.
cluster_shape_fallback
=
dim3
(
1
,
1
,
1
);
}
else
{
arguments
.
hw_info
.
cluster_shape
=
dim3
(
4
,
4
,
1
);
arguments
.
hw_info
.
cluster_shape_fallback
=
dim3
(
2
,
1
,
1
);
}
return
arguments
;
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment