Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cdd70259
Unverified
Commit
cdd70259
authored
Nov 14, 2025
by
czhu-cohere
Committed by
GitHub
Nov 14, 2025
Browse files
[kernel] Improve FP8 PTPC on Hopper for larger shapes (#28692)
Signed-off-by:
czhu-cohere
<
conway.zhu@cohere.com
>
parent
08542480
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
27 additions
and
0 deletions
+27
-0
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
...tization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
+27
-0
No files found.
csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh
View file @
cdd70259
...
@@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
...
@@ -116,6 +116,26 @@ struct sm90_fp8_config_default {
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
;
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
;
};
};
template
<
typename
InType
,
typename
OutType
,
bool
EnableBias
>
struct
sm90_fp8_config_M8192_K6144
{
// M >= 8192, K >= 6144
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
KernelTmaWarpSpecializedCooperativeFP8FastAccum
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
using
TileShape
=
Shape
<
_256
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_1
,
_1
>
;
using
Cutlass3xGemm
=
conditional_t
<
EnableBias
,
cutlass_3x_gemm_sm90_fp8
<
InType
,
OutType
,
c3x
::
ScaledEpilogueBias
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
,
cutlass_3x_gemm_sm90_fp8
<
InType
,
OutType
,
c3x
::
ScaledEpilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
;
};
template
<
typename
InType
,
typename
OutType
,
bool
EnableBias
>
template
<
typename
InType
,
typename
OutType
,
bool
EnableBias
>
struct
sm90_fp8_config_M128
{
struct
sm90_fp8_config_M128
{
// M in (64, 128]
// M in (64, 128]
...
@@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -273,6 +293,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
using
Cutlass3xGemmDefault
=
using
Cutlass3xGemmDefault
=
typename
sm90_fp8_config_default
<
InType
,
OutType
,
typename
sm90_fp8_config_default
<
InType
,
OutType
,
EnableBias
>::
Cutlass3xGemm
;
EnableBias
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM8192_K6144
=
typename
sm90_fp8_config_M8192_K6144
<
InType
,
OutType
,
EnableBias
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
using
Cutlass3xGemmM128
=
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
EnableBias
>::
Cutlass3xGemm
;
typename
sm90_fp8_config_M128
<
InType
,
OutType
,
EnableBias
>::
Cutlass3xGemm
;
...
@@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -291,6 +314,7 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
n
=
b
.
size
(
1
);
uint32_t
const
n
=
b
.
size
(
1
);
uint32_t
const
k
=
a
.
size
(
1
);
if
(
m
<=
16
)
{
if
(
m
<=
16
)
{
// m in [1, 16]
// m in [1, 16]
...
@@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
...
@@ -312,6 +336,9 @@ inline void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out,
// m in (64, 128]
// m in (64, 128]
return
cutlass_gemm_caller_sm90_fp8
<
Cutlass3xGemmM128
>
(
return
cutlass_gemm_caller_sm90_fp8
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
out
,
a
,
b
,
a_scales
,
b_scales
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
m
>=
8192
&&
k
>=
6144
)
{
return
cutlass_gemm_caller_sm90_fp8
<
Cutlass3xGemmM8192_K6144
>
(
out
,
a
,
b
,
a_scales
,
b_scales
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
}
else
{
// m in (128, inf)
// m in (128, inf)
return
cutlass_gemm_caller_sm90_fp8
<
Cutlass3xGemmDefault
>
(
return
cutlass_gemm_caller_sm90_fp8
<
Cutlass3xGemmDefault
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment