Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5bb0accb
Unverified
Commit
5bb0accb
authored
Apr 29, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Apr 28, 2025
Browse files
cutlass 3.9 supported to improve fp8_blockwise_gemm (#5820)
parent
8d463fe3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
27 deletions
+18
-27
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+1
-1
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+17
-26
No files found.
sgl-kernel/CMakeLists.txt
View file @
5bb0accb
...
@@ -43,7 +43,7 @@ include(FetchContent)
...
@@ -43,7 +43,7 @@ include(FetchContent)
FetchContent_Declare
(
FetchContent_Declare
(
repo-cutlass
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG
5e497243f7ad13a2aa842143f9b10bbb23d98292
GIT_TAG
e94e888df3551224738bfa505787b515eae8352f
GIT_SHALLOW OFF
GIT_SHALLOW OFF
)
)
FetchContent_Populate
(
repo-cutlass
)
FetchContent_Populate
(
repo-cutlass
)
...
...
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
View file @
5bb0accb
...
@@ -34,12 +34,7 @@
...
@@ -34,12 +34,7 @@
using
namespace
cute
;
using
namespace
cute
;
template
<
template
<
typename
SchedulerType
,
typename
OutType
,
typename
TileShape
,
typename
ClusterShape
>
typename
SchedulerType
,
typename
OutType
,
typename
TileShape
,
typename
ClusterShape
,
typename
ScaleGranularity
>
void
launch_sm90_fp8_blockwise_scaled_mm
(
void
launch_sm90_fp8_blockwise_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
a
,
...
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
constexpr
int
AlignmentD
=
AlignmentC
;
constexpr
int
AlignmentD
=
AlignmentC
;
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
ScaleGranularity
{});
using
ScaleTileShape
=
Shape
<
_1
,
_128
,
_128
>
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
ScaleGranularity
{});
using
ScaleConfig
=
decltype
(
cutlass
::
detail
::
sm90_trivial_blockwise_scale_config
(
ScaleTileShape
{}));
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
...
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
KernelSchedule
=
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
;
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
<
ScaleGranularityM
,
ScaleGranularityN
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
ArchTag
,
OperatorClass
,
OperatorClass
,
...
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
ArchTag
,
ArchTag
,
OperatorClass
,
OperatorClass
,
ElementA
,
ElementA
,
LayoutA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
AlignmentA
,
ElementB
,
ElementB
,
LayoutB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
AlignmentB
,
ElementAccumulator
,
ElementAccumulator
,
TileShape
,
TileShape
,
...
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
StrideC
stride_c
;
StrideC
stride_c
;
StrideD
stride_d
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
m
,
n
,
1
));
StrideD
stride_d
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
m
,
n
,
1
));
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
stride_a
,
b_ptr
,
stride_b
,
4
,
a_s_ptr
,
b_s_ptr
};
LayoutSFA
layout_sfa
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_sfb
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
stride_a
,
b_ptr
,
stride_b
,
4
,
a_s_ptr
,
layout_sfa
,
b_s_ptr
,
layout_sfb
};
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{{},
nullptr
,
stride_d
,
o_ptr
,
stride_d
};
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{{},
nullptr
,
stride_d
,
o_ptr
,
stride_d
};
typename
Gemm
::
Arguments
args
=
{
typename
Gemm
::
Arguments
args
=
{
...
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
...
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
const
torch
::
Tensor
&
scales_b
)
{
const
torch
::
Tensor
&
scales_b
)
{
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
ScaleGranularity
=
Shape
<
_1
,
_128
,
_128
>
;
auto
k
=
a
.
size
(
1
);
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
if
(
k
>
3
*
n
)
{
if
(
k
>
3
*
n
)
{
launch_sm90_fp8_blockwise_scaled_mm
<
launch_sm90_fp8_blockwise_scaled_mm
<
cutlass
::
gemm
::
StreamKScheduler
,
OutType
,
TileShape
,
ClusterShape
>
(
cutlass
::
gemm
::
StreamKScheduler
,
out
,
a
,
b
,
scales_a
,
scales_b
);
OutType
,
TileShape
,
ClusterShape
,
ScaleGranularity
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
else
{
}
else
{
launch_sm90_fp8_blockwise_scaled_mm
<
launch_sm90_fp8_blockwise_scaled_mm
<
cutlass
::
gemm
::
PersistentScheduler
,
OutType
,
TileShape
,
ClusterShape
>
(
cutlass
::
gemm
::
PersistentScheduler
,
out
,
a
,
b
,
scales_a
,
scales_b
);
OutType
,
TileShape
,
ClusterShape
,
ScaleGranularity
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment