Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5bb0accb
"docs/source/vscode:/vscode.git/clone" did not exist on "8a07ab77376a99b7114d0850ff99331ed88a648e"
Unverified
Commit
5bb0accb
authored
Apr 29, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Apr 28, 2025
Browse files
cutlass 3.9 supported to improve fp8_blockwise_gemm (#5820)
parent
8d463fe3
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
18 additions
and
27 deletions
+18
-27
sgl-kernel/CMakeLists.txt
sgl-kernel/CMakeLists.txt
+1
-1
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
+17
-26
No files found.
sgl-kernel/CMakeLists.txt
View file @
5bb0accb
...
@@ -43,7 +43,7 @@ include(FetchContent)
...
@@ -43,7 +43,7 @@ include(FetchContent)
FetchContent_Declare
(
FetchContent_Declare
(
repo-cutlass
repo-cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_REPOSITORY https://github.com/NVIDIA/cutlass
GIT_TAG
5e497243f7ad13a2aa842143f9b10bbb23d98292
GIT_TAG
e94e888df3551224738bfa505787b515eae8352f
GIT_SHALLOW OFF
GIT_SHALLOW OFF
)
)
FetchContent_Populate
(
repo-cutlass
)
FetchContent_Populate
(
repo-cutlass
)
...
...
sgl-kernel/csrc/gemm/fp8_blockwise_gemm_kernel.cu
View file @
5bb0accb
...
@@ -34,12 +34,7 @@
...
@@ -34,12 +34,7 @@
using
namespace
cute
;
using
namespace
cute
;
template
<
template
<
typename
SchedulerType
,
typename
OutType
,
typename
TileShape
,
typename
ClusterShape
>
typename
SchedulerType
,
typename
OutType
,
typename
TileShape
,
typename
ClusterShape
,
typename
ScaleGranularity
>
void
launch_sm90_fp8_blockwise_scaled_mm
(
void
launch_sm90_fp8_blockwise_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
const
torch
::
Tensor
&
a
,
const
torch
::
Tensor
&
a
,
...
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -66,8 +61,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
constexpr
int
AlignmentD
=
AlignmentC
;
constexpr
int
AlignmentD
=
AlignmentC
;
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
ScaleGranularity
{});
using
ScaleTileShape
=
Shape
<
_1
,
_128
,
_128
>
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
ScaleGranularity
{});
using
ScaleConfig
=
decltype
(
cutlass
::
detail
::
sm90_trivial_blockwise_scale_config
(
ScaleTileShape
{}));
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
ArchTag
=
cutlass
::
arch
::
Sm90
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
...
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -75,8 +72,7 @@ void launch_sm90_fp8_blockwise_scaled_mm(
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
EpilogueTileType
=
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
StoreEpilogueCompute
=
typename
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
>
;
using
KernelSchedule
=
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
;
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8BlockScaledAccum
<
ScaleGranularityM
,
ScaleGranularityN
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
ArchTag
,
OperatorClass
,
OperatorClass
,
...
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -98,10 +94,10 @@ void launch_sm90_fp8_blockwise_scaled_mm(
ArchTag
,
ArchTag
,
OperatorClass
,
OperatorClass
,
ElementA
,
ElementA
,
LayoutA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
AlignmentA
,
ElementB
,
ElementB
,
LayoutB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
AlignmentB
,
ElementAccumulator
,
ElementAccumulator
,
TileShape
,
TileShape
,
...
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
...
@@ -140,7 +136,11 @@ void launch_sm90_fp8_blockwise_scaled_mm(
StrideC
stride_c
;
StrideC
stride_c
;
StrideD
stride_d
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
m
,
n
,
1
));
StrideD
stride_d
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
m
,
n
,
1
));
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
stride_a
,
b_ptr
,
stride_b
,
4
,
a_s_ptr
,
b_s_ptr
};
LayoutSFA
layout_sfa
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_sfb
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
stride_a
,
b_ptr
,
stride_b
,
4
,
a_s_ptr
,
layout_sfa
,
b_s_ptr
,
layout_sfb
};
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{{},
nullptr
,
stride_d
,
o_ptr
,
stride_d
};
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{{},
nullptr
,
stride_d
,
o_ptr
,
stride_d
};
typename
Gemm
::
Arguments
args
=
{
typename
Gemm
::
Arguments
args
=
{
...
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
...
@@ -306,24 +306,15 @@ void sm90_fp8_blockwise_dispatch_shape(
const
torch
::
Tensor
&
scales_b
)
{
const
torch
::
Tensor
&
scales_b
)
{
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
ScaleGranularity
=
Shape
<
_1
,
_128
,
_128
>
;
auto
k
=
a
.
size
(
1
);
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
if
(
k
>
3
*
n
)
{
if
(
k
>
3
*
n
)
{
launch_sm90_fp8_blockwise_scaled_mm
<
launch_sm90_fp8_blockwise_scaled_mm
<
cutlass
::
gemm
::
StreamKScheduler
,
OutType
,
TileShape
,
ClusterShape
>
(
cutlass
::
gemm
::
StreamKScheduler
,
out
,
a
,
b
,
scales_a
,
scales_b
);
OutType
,
TileShape
,
ClusterShape
,
ScaleGranularity
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
else
{
}
else
{
launch_sm90_fp8_blockwise_scaled_mm
<
launch_sm90_fp8_blockwise_scaled_mm
<
cutlass
::
gemm
::
PersistentScheduler
,
OutType
,
TileShape
,
ClusterShape
>
(
cutlass
::
gemm
::
PersistentScheduler
,
out
,
a
,
b
,
scales_a
,
scales_b
);
OutType
,
TileShape
,
ClusterShape
,
ScaleGranularity
>
(
out
,
a
,
b
,
scales_a
,
scales_b
);
}
}
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment