Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e13945f9
Unverified
Commit
e13945f9
authored
Jun 15, 2025
by
Ilya Markov
Committed by
GitHub
Jun 14, 2025
Browse files
[Perf] Further tunings for SM100 FP8 CUTLASS kernel (#19566)
parent
08500011
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
25 additions
and
5 deletions
+25
-5
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+25
-5
No files found.
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
View file @
e13945f9
...
@@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
...
@@ -15,11 +15,25 @@ using c3x::cutlass_gemm_caller;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_default
{
struct
sm100_fp8_config_default
{
// M in (
128
, inf)
// M in (
256
, inf)
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_256
,
_128
,
_64
>
;
using
TileShape
=
Shape
<
_256
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_M256
{
// M in (128, 256]
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
...
@@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
...
@@ -33,8 +47,8 @@ struct sm100_fp8_config_M128 {
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_128
,
_128
,
_6
4
>
;
using
TileShape
=
Shape
<
_128
,
_128
,
_
25
6
>
;
using
ClusterShape
=
Shape
<
_2
,
_
2
,
_1
>
;
using
ClusterShape
=
Shape
<
_2
,
_
4
,
_1
>
;
using
Cutlass3xGemm
=
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
KernelSchedule
,
EpilogueSchedule
>
;
...
@@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
...
@@ -72,6 +86,8 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
typename
sm100_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
typename
sm100_fp8_config_M64
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM128
=
using
Cutlass3xGemmM128
=
typename
sm100_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
typename
sm100_fp8_config_M128
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
using
Cutlass3xGemmM256
=
typename
sm100_fp8_config_M256
<
InType
,
OutType
,
Epilogue
>::
Cutlass3xGemm
;
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
m
=
a
.
size
(
0
);
uint32_t
const
mp2
=
uint32_t
const
mp2
=
...
@@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
...
@@ -85,8 +101,12 @@ inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
// m in (64, 128]
// m in (64, 128]
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
return
cutlass_gemm_caller
<
Cutlass3xGemmM128
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
if
(
mp2
<=
256
)
{
// m in (128, 256]
return
cutlass_gemm_caller
<
Cutlass3xGemmM256
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
else
{
}
else
{
// m in (
128
, inf)
// m in (
256
, inf)
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
return
cutlass_gemm_caller
<
Cutlass3xGemmDefault
>
(
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
out
,
a
,
b
,
std
::
forward
<
EpilogueArgs
>
(
args
)...);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment