Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cd436736
Unverified
Commit
cd436736
authored
Feb 25, 2026
by
wenshuai
Committed by
GitHub
Feb 24, 2026
Browse files
[Perf] Optimize FP8 gemm of sm120. (#34424)
Signed-off-by:
wenshuai
<
wenshuai@xiaomi.com
>
parent
35d44b45
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
133 additions
and
1 deletion
+133
-1
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
...ization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
+133
-1
No files found.
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh
View file @
cd436736
...
@@ -12,6 +12,68 @@ namespace vllm {
...
@@ -12,6 +12,68 @@ namespace vllm {
using
c3x
::
cutlass_gemm_caller
;
using
c3x
::
cutlass_gemm_caller
;
// SM120 CUTLASS 3.x GEMM wrapper that takes the EpilogueTile as an explicit
// template parameter and forwards it to the epilogue CollectiveBuilder. It is
// intended for small-M configurations that need to pick the epilogue tile.
//
// Template parameters:
//   ElementAB_        element type shared by operands A and B
//   ElementD_         element type of the output D
//   Epilogue_         epilogue template; instantiated below as
//                     Epilogue_<ElementAcc, ElementD, TileShape>
//   TileShape         CTA tile shape passed to both collective builders
//   ClusterShape      cluster shape passed to both collective builders
//   KernelSchedule    schedule tag for the mainloop builder
//   EpilogueSchedule  schedule tag for the epilogue builder
//   EpilogueTile      epilogue tile passed to the epilogue builder
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule, typename EpilogueTile>
struct cutlass_3x_gemm_sm120_custom {
  // --- Operands A / B: same element type, row-major A, column-major B. ---
  // Alignments are the number of elements that fit in 128 bits.
  using ElementAB = ElementAB_;

  using LayoutA = cutlass::layout::RowMajor;
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  using LayoutB = cutlass::layout::ColumnMajor;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;

  // --- C / D tensors. ---
  // ElementC is void; C's alignment is still derived from the D element type
  // and D reuses it.
  using ElementC = void;
  using LayoutC = cutlass::layout::RowMajor;
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD_>::value;

  using ElementD = ElementD_;
  using LayoutD = cutlass::layout::RowMajor;
  static constexpr int AlignmentD = AlignmentC;

  // Accumulator type handed to the epilogue template: int32_t for int8_t
  // inputs, float for everything else.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // MMA type
  // NOTE(review): the collective builders below use this fixed float
  // accumulator rather than ElementAcc; the two agree for non-int8 inputs.
  using ElementAccumulator = float;

  // Epilogue types
  using ElementBias = cutlass::half_t;
  using ElementCompute = float;
  using ElementAux = ElementD;
  using LayoutAux = LayoutD;
  using ElementAmax = float;

  using EVTCompute = typename Epilogue::EVTCompute;

  // Epilogue collective, built with the caller-supplied EpilogueTile.
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, TileShape,
          ClusterShape, EpilogueTile,  // Use custom EpilogueTile
          ElementAccumulator, ElementCompute, ElementC, LayoutC, AlignmentC,
          ElementD, LayoutD, AlignmentD, EpilogueSchedule,
          EVTCompute>::CollectiveOp;

  // Mainloop collective. The byte size of the epilogue's SharedStorage is
  // passed to StageCountAutoCarveout.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm120, cutlass::arch::OpClassTensorOp, ElementAB,
          LayoutA, AlignmentA, ElementAB, LayoutB, AlignmentB,
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule, void>::CollectiveOp;

  // Final kernel type: a GemmUniversal over the two collectives, wrapped in
  // enable_sm120_only (defined elsewhere).
  using GemmKernel = enable_sm120_only<cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue,
      void>>;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_default
{
struct
sm120_fp8_config_default
{
...
@@ -25,6 +87,54 @@ struct sm120_fp8_config_default {
...
@@ -25,6 +87,54 @@ struct sm120_fp8_config_default {
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
};
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M64
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
// SM120 Cooperative kernel requires Tile M >= 128.
// For M=64 tile, we use Pingpong schedule which is more flexible with small
// tiles.
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_64
,
_64
,
_128
>
;
// CUTLASS 3.x on SM120 currently restricts programmatic multicast (Cluster >
// 1) for certain schedules/types. Reverting to 1x1x1 to ensure compilation.
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M32
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_32
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
// Use custom gemm to specify EpilogueTile M=32
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120_custom
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
,
Shape
<
_32
,
_32
>>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm120_fp8_config_M16
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_16
,
_64
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
// Use custom gemm to specify EpilogueTile M=16
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm120_custom
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
,
Shape
<
_16
,
_32
>>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
,
typename
...
EpilogueArgs
>
typename
...
EpilogueArgs
>
...
@@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
...
@@ -36,6 +146,28 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out,
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
int
M
=
a
.
size
(
0
);
if
(
M
<=
16
)
{
using
Cutlass3xGemmM16
=
typename
sm120_fp8_config_M16
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM16
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
if
(
M
<=
32
)
{
using
Cutlass3xGemmM32
=
typename
sm120_fp8_config_M32
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM32
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
if
(
M
<=
256
)
{
using
Cutlass3xGemmM64
=
typename
sm120_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
return
cutlass_gemm_caller
<
Cutlass3xGemmM64
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
using
Cutlass3xGemmDefault
=
using
Cutlass3xGemmDefault
=
typename
sm120_fp8_config_default
<
InType
,
OutType
,
typename
sm120_fp8_config_default
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
Epilogue
>::
Cutlass3xGemm
;
...
@@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
...
@@ -64,4 +196,4 @@ void cutlass_scaled_mm_sm120_fp8_epilogue(torch::Tensor& out,
}
}
}
}
}
// namespace vllm
}
// namespace vllm
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment