Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
580090db
Unverified
Commit
580090db
authored
Apr 03, 2026
by
Necofish
Committed by
GitHub
Apr 03, 2026
Browse files
[Kernel] Add swapAB support for SM120 CUTLASS blockwise FP8 GEMM (#38325)
parent
cb10b7e8
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
82 additions
and
26 deletions
+82
-26
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
...a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
+82
-26
No files found.
csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
View file @
580090db
...
@@ -26,8 +26,10 @@ using namespace cute;
...
@@ -26,8 +26,10 @@ using namespace cute;
template
<
class
OutType
,
int
ScaleGranularityM
,
template
<
class
OutType
,
int
ScaleGranularityM
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
class
MmaTileShape
,
class
ClusterShape
,
class
MmaTileShape
,
class
ClusterShape
,
class
EpilogueScheduler
,
class
MainloopScheduler
>
class
EpilogueScheduler
,
class
MainloopScheduler
,
bool
swap_ab_
=
false
>
struct
cutlass_3x_gemm_fp8_blockwise
{
struct
cutlass_3x_gemm_fp8_blockwise
{
static
constexpr
bool
swap_ab
=
swap_ab_
;
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
ElementA
=
ElementAB
;
...
@@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise {
...
@@ -55,9 +57,13 @@ struct cutlass_3x_gemm_fp8_blockwise {
using
ElementCompute
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
using
ElementBlockScale
=
float
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm120BlockwiseScaleConfig
<
using
ScaleConfig
=
conditional_t
<
swap_ab
,
cutlass
::
detail
::
Sm120BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
cute
::
UMMA
::
Major
::
K
,
cute
::
UMMA
::
Major
::
MN
>
,
cutlass
::
detail
::
Sm120BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>>
;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
...
@@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise {
...
@@ -78,17 +84,32 @@ struct cutlass_3x_gemm_fp8_blockwise {
ElementAccumulator
,
ElementAccumulator
,
ElementCompute
,
ElementCompute
,
ElementC
,
ElementC
,
LayoutC
,
conditional_t
<
swap_ab
,
LayoutC_Transpose
,
LayoutC
>
,
AlignmentC
,
AlignmentC
,
ElementD
,
ElementD
,
LayoutD
,
conditional_t
<
swap_ab
,
LayoutD_Transpose
,
LayoutD
>
,
AlignmentD
,
AlignmentD
,
EpilogueScheduler
,
EpilogueScheduler
,
DefaultOperation
DefaultOperation
>::
CollectiveOp
;
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
using
CollectiveMainloop
=
conditional_t
<
swap_ab
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementB
,
cute
::
tuple
<
LayoutB_Transpose
,
LayoutSFA
>
,
AlignmentB
,
ElementA
,
cute
::
tuple
<
LayoutA_Transpose
,
LayoutSFB
>
,
AlignmentA
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
,
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
ArchTag
,
OperatorClass
,
OperatorClass
,
...
@@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
...
@@ -103,7 +124,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
ClusterShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
MainloopScheduler
>::
CollectiveOp
;
>::
CollectiveOp
>
;
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
// SM12x family to support both SM120 (RTX 5090) and SM121 (DGX Spark)
using
KernelType
=
enable_sm120_family
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
using
KernelType
=
enable_sm120_family
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
...
@@ -115,7 +136,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
...
@@ -115,7 +136,7 @@ struct cutlass_3x_gemm_fp8_blockwise {
// Tile configurations for different M ranges
// Tile configurations for different M ranges
template
<
typename
OutType
>
template
<
typename
OutType
>
struct
sm120_blockwise_fp8_config_default
{
struct
sm120_blockwise_fp8_config_default
{
// M > 256: use 128x128x128 tile with Cooperative (Auto) schedule
// use 128x128x128 tile with Cooperative (Auto) schedule
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
...
@@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default {
...
@@ -127,8 +148,8 @@ struct sm120_blockwise_fp8_config_default {
};
};
template
<
typename
OutType
>
template
<
typename
OutType
>
struct
sm120_blockwise_fp8_config_
M64
{
struct
sm120_blockwise_fp8_config_
pingpong
{
// M in [1, 256]: use 64x128x128 tile with Pingpong schedule
// use 64x128x128 tile with Pingpong schedule
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwisePingpongSm120
;
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwisePingpongSm120
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
using
TileShape
=
Shape
<
_64
,
_128
,
_128
>
;
...
@@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 {
...
@@ -139,11 +160,24 @@ struct sm120_blockwise_fp8_config_M64 {
EpilogueSchedule
,
KernelSchedule
>
;
EpilogueSchedule
,
KernelSchedule
>
;
};
};
// Tile configuration for the swapped-operand (swap A/B) kernel variant.
// It instantiates cutlass_3x_gemm_fp8_blockwise with swap_ab_ = true.
template <typename OutType>
struct sm120_blockwise_fp8_config_swapab {
  // 128x32x128 tile on a 1x1x1 cluster.
  using ClusterShape = Shape<_1, _1, _1>;
  using TileShape = Shape<_128, _32, _128>;

  // Cooperative blockwise kernel schedule, paired with EpilogueScheduleAuto.
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
  using KernelSchedule =
      cutlass::gemm::KernelTmaWarpSpecializedBlockwiseCooperativeSm120;

  using Gemm = cutlass_3x_gemm_fp8_blockwise<
      OutType,
      /*ScaleGranularityM=*/128,
      /*ScaleGranularityN=*/1,
      /*ScaleGranularityK=*/128,
      TileShape, ClusterShape, EpilogueSchedule, KernelSchedule,
      /*swap_ab_=*/true>;
};
template
<
typename
Gemm
>
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
stable
::
Tensor
&
out
,
torch
::
stable
::
Tensor
const
&
a
,
void
cutlass_gemm_caller_blockwise
(
torch
::
stable
::
Tensor
&
out
,
torch
::
stable
::
Tensor
const
&
a
,
torch
::
stable
::
Tensor
const
&
b
,
torch
::
stable
::
Tensor
const
&
b
,
torch
::
stable
::
Tensor
const
&
a_scales
,
torch
::
stable
::
Tensor
const
&
a_scales
,
torch
::
stable
::
Tensor
const
&
b_scales
)
{
torch
::
stable
::
Tensor
const
&
b_scales
)
{
static
constexpr
bool
swap_ab
=
Gemm
::
swap_ab
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
...
@@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
...
@@ -167,11 +201,13 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
b_stride
=
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
cutlass
::
make_cute_packed_stride
(
StrideC
{},
swap_ab
?
cute
::
make_shape
(
n
,
m
,
1
)
:
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
LayoutSFA
layout_SFA
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
LayoutSFB
layout_SFB
=
swap_ab
?
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
n
,
m
,
k
,
1
))
:
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
...
@@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
...
@@ -180,15 +216,24 @@ void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Te
auto
b_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
b_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
ElementBlockScale
const
*>
(
b_scales
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{};
typename
GemmKernel
::
MainloopArguments
mainloop_args
{};
mainloop_args
.
layout_SFA
=
layout_SFA
;
mainloop_args
.
layout_SFB
=
layout_SFB
;
if
(
swap_ab
)
{
mainloop_args
.
ptr_A
=
b_ptr
;
mainloop_args
.
dA
=
b_stride
;
mainloop_args
.
ptr_B
=
a_ptr
;
mainloop_args
.
dB
=
a_stride
;
mainloop_args
.
ptr_SFA
=
b_scales_ptr
;
mainloop_args
.
ptr_SFB
=
a_scales_ptr
;
}
else
{
mainloop_args
.
ptr_A
=
a_ptr
;
mainloop_args
.
ptr_A
=
a_ptr
;
mainloop_args
.
dA
=
a_stride
;
mainloop_args
.
dA
=
a_stride
;
mainloop_args
.
ptr_B
=
b_ptr
;
mainloop_args
.
ptr_B
=
b_ptr
;
mainloop_args
.
dB
=
b_stride
;
mainloop_args
.
dB
=
b_stride
;
mainloop_args
.
ptr_SFA
=
a_scales_ptr
;
mainloop_args
.
ptr_SFA
=
a_scales_ptr
;
mainloop_args
.
layout_SFA
=
layout_SFA
;
mainloop_args
.
ptr_SFB
=
b_scales_ptr
;
mainloop_args
.
ptr_SFB
=
b_scales_ptr
;
mainloop_args
.
layout_SFB
=
layout_SFB
;
}
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
prob_shape
=
swap_ab
?
cute
::
make_shape
(
n
,
m
,
k
,
1
)
:
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
...
@@ -204,8 +249,12 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
...
@@ -204,8 +249,12 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
torch
::
stable
::
Tensor
const
&
a_scales
,
torch
::
stable
::
Tensor
const
&
a_scales
,
torch
::
stable
::
Tensor
const
&
b_scales
)
{
torch
::
stable
::
Tensor
const
&
b_scales
)
{
int
M
=
a
.
size
(
0
);
int
M
=
a
.
size
(
0
);
// more heuristic tuning can be done here by checking N/K dimensions as well
bool
swap_ab
=
(
M
<=
64
)
||
(
M
%
4
!=
0
);
if
(
!
swap_ab
)
{
if
(
M
<=
256
)
{
if
(
M
<=
256
)
{
using
Gemm
=
typename
sm120_blockwise_fp8_config_
M64
<
OutType
>::
Gemm
;
using
Gemm
=
typename
sm120_blockwise_fp8_config_
pingpong
<
OutType
>::
Gemm
;
return
cutlass_gemm_caller_blockwise
<
Gemm
>
(
return
cutlass_gemm_caller_blockwise
<
Gemm
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
...
@@ -213,6 +262,13 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
...
@@ -213,6 +262,13 @@ void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out,
using
Gemm
=
typename
sm120_blockwise_fp8_config_default
<
OutType
>::
Gemm
;
using
Gemm
=
typename
sm120_blockwise_fp8_config_default
<
OutType
>::
Gemm
;
return
cutlass_gemm_caller_blockwise
<
Gemm
>
(
return
cutlass_gemm_caller_blockwise
<
Gemm
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
// Swap A/B for small M to improve performance
// Use TILE_N=32 as the minimum compatible tile size.
using
Gemm
=
typename
sm120_blockwise_fp8_config_swapab
<
OutType
>::
Gemm
;
return
cutlass_gemm_caller_blockwise
<
Gemm
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
}
// namespace vllm
}
// namespace vllm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment