Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
d47661f0
Unverified
Commit
d47661f0
authored
Jul 12, 2025
by
Michael Goin
Committed by
GitHub
Jul 11, 2025
Browse files
[Kernel] Basic tuned configs for NVFP4 CUTLASS dense GEMM (#20646)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
53fa4573
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
85 additions
and
50 deletions
+85
-50
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
+85
-50
No files found.
csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu
View file @
d47661f0
...
...
@@ -30,35 +30,40 @@
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
using
namespace
cute
;
#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED)
// Kernel Perf config
template
<
typename
T
>
struct
KernelTraits
;
template
<
>
struct
KernelTraits
<
float
>
{
using
MmaTileShape
=
Shape
<
_128
,
_128
,
_256
>
;
using
ClusterShape
=
Shape
<
_1
,
_1
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_128
,
_256
>
;
// Tuned kernel configuration used when M is in (256, inf).
struct sm100_fp4_config_default {
  // Tile / cluster geometry (read by Fp4GemmSm100 through Config::...).
  using TileShape = Shape<_256, _256, _256>;
  using ClusterShape = Shape<_2, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _256, _256>;

  // Schedule tags: both use the CUTLASS "Auto" variants.
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
};
template
<
>
struct
KernelTraits
<
cutlass
::
half_t
>
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
// Tuned kernel configuration used when M is in (16, 256].
struct sm100_fp4_config_M256 {
  // Tile / cluster geometry (read by Fp4GemmSm100 through Config::...).
  using TileShape = Shape<_256, _128, _256>;
  using ClusterShape = Shape<_2, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;

  // Schedule tags: both use the CUTLASS "Auto" variants.
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
};
template
<
>
struct
KernelTraits
<
cutlass
::
bfloat16_t
>
{
using
MmaTileShape
=
Shape
<
_256
,
_256
,
_256
>
;
using
ClusterShape
=
Shape
<
_4
,
_4
,
_1
>
;
using
PerSmTileShape_MNK
=
Shape
<
_128
,
_256
,
_256
>
;
// Tuned kernel configuration used when M is in [1, 16].
struct sm100_fp4_config_M16 {
  // Tile / cluster geometry (read by Fp4GemmSm100 through Config::...).
  using TileShape = Shape<_128, _128, _256>;
  using ClusterShape = Shape<_1, _1, _1>;
  using PerSmTileShape_MNK = Shape<_128, _128, _256>;

  // Schedule tags: both use the CUTLASS "Auto" variants.
  using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto;
  using EpilogueSchedule = cutlass::epilogue::collective::EpilogueScheduleAuto;
};
template
<
typename
T
>
template
<
typename
Config
,
typename
OutType
>
struct
Fp4GemmSm100
{
// A matrix configuration
using
ElementA
=
cutlass
::
nv_float4_t
<
cutlass
::
float_e2m1_t
>
;
...
...
@@ -71,21 +76,22 @@ struct Fp4GemmSm100 {
static
constexpr
int
AlignmentB
=
32
;
// C/D matrix configuration
using
ElementD
=
T
;
using
ElementC
=
T
;
using
ElementD
=
OutType
;
using
ElementC
=
OutType
;
using
LayoutCTag
=
cutlass
::
layout
::
RowMajor
;
using
LayoutDTag
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
static
constexpr
int
AlignmentC
=
128
/
cutlass
::
sizeof_bits
<
ElementC
>::
value
;
// Kernel functional config
using
ElementAccumulator
=
float
;
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassBlockScaledTensorOp
;
//
Kernel Perf config
using
MmaTileShape
=
typename
KernelTraits
<
T
>::
Mma
TileShape
;
using
ClusterShape
=
typename
KernelTraits
<
T
>
::
ClusterShape
;
using
PerSmTileShape_MNK
=
typename
KernelTraits
<
T
>
::
PerSmTileShape_MNK
;
//
Use config's tile shapes
using
MmaTileShape
=
typename
Config
::
TileShape
;
using
ClusterShape
=
typename
Config
::
ClusterShape
;
using
PerSmTileShape_MNK
=
typename
Config
::
PerSmTileShape_MNK
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
...
...
@@ -119,22 +125,22 @@ struct Fp4GemmSm100 {
using
LayoutD
=
decltype
(
cute
::
make_layout
(
make_shape
(
0
,
0
,
0
),
StrideD
{}));
};
template
<
typename
T
>
typename
T
::
Gemm
::
Arguments
args_from_options
(
template
<
typename
Config
>
typename
Config
::
Gemm
::
Arguments
args_from_options
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
M
,
int64_t
N
,
int64_t
K
)
{
using
ElementA
=
typename
T
::
Gemm
::
ElementA
;
using
ElementB
=
typename
T
::
Gemm
::
ElementB
;
using
ElementA
=
typename
Config
::
Gemm
::
ElementA
;
using
ElementB
=
typename
Config
::
Gemm
::
ElementB
;
using
ElementSFA
=
cutlass
::
float_ue4m3_t
;
using
ElementSFB
=
cutlass
::
float_ue4m3_t
;
using
ElementD
=
typename
T
::
Gemm
::
ElementD
;
using
ElementD
=
typename
Config
::
Gemm
::
ElementD
;
using
ElementCompute
=
float
;
using
StrideA
=
typename
T
::
StrideA
;
using
StrideB
=
typename
T
::
StrideB
;
using
StrideD
=
typename
T
::
StrideD
;
using
Sm100BlkScaledConfig
=
typename
T
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
using
StrideA
=
typename
Config
::
StrideA
;
using
StrideB
=
typename
Config
::
StrideB
;
using
StrideD
=
typename
Config
::
StrideD
;
using
Sm100BlkScaledConfig
=
typename
Config
::
Gemm
::
GemmKernel
::
CollectiveMainloop
::
Sm1xxBlkScaledConfig
;
int
m
=
static_cast
<
int
>
(
M
);
int
n
=
static_cast
<
int
>
(
N
);
...
...
@@ -148,7 +154,7 @@ typename T::Gemm::Arguments args_from_options(
auto
layout_SFB
=
Sm100BlkScaledConfig
::
tile_atom_to_shape_SFB
(
cute
::
make_shape
(
m
,
n
,
k
,
1
));
typename
T
::
Gemm
::
Arguments
arguments
{
typename
Config
::
Gemm
::
Arguments
arguments
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
{
m
,
n
,
k
,
1
},
{
// Mainloop arguments
...
...
@@ -167,17 +173,17 @@ typename T::Gemm::Arguments args_from_options(
return
arguments
;
}
template
<
typename
T
>
template
<
typename
Config
>
void
runGemm
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
typename
Fp4GemmSm100
<
T
>
::
Gemm
gemm
;
typename
Config
::
Gemm
gemm
;
auto
arguments
=
args_from_options
<
Fp4GemmSm100
<
T
>
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
);
args_from_options
<
Config
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
);
size_t
workspace_size
=
Fp4GemmSm100
<
T
>
::
Gemm
::
get_workspace_size
(
arguments
);
size_t
workspace_size
=
Config
::
Gemm
::
get_workspace_size
(
arguments
);
auto
const
workspace_options
=
torch
::
TensorOptions
().
dtype
(
torch
::
kUInt8
).
device
(
A
.
device
());
auto
workspace
=
torch
::
empty
(
workspace_size
,
workspace_options
);
...
...
@@ -188,12 +194,40 @@ void runGemm(at::Tensor& D, at::Tensor const& A, at::Tensor const& B,
CUTLASS_CHECK
(
gemm
.
run
(
arguments
,
workspace
.
data_ptr
(),
stream
));
}
// Picks one of the tuned sm100_fp4_config_* structs from the M dimension and
// forwards every argument unchanged to runGemm instantiated with
// Fp4GemmSm100<config, OutType>.
//
// Template parameter:
//   OutType - passed through as the second template argument of Fp4GemmSm100.
// Parameters D, A, B, A_sf, B_sf, alpha, m, n, k and stream are handed to
// runGemm as-is; only m is inspected here.
template <typename OutType>
void cutlass_fp4_gemm_dispatch(torch::Tensor& D, torch::Tensor const& A,
                               torch::Tensor const& B,
                               torch::Tensor const& A_sf,
                               torch::Tensor const& B_sf,
                               torch::Tensor const& alpha, int64_t m, int64_t n,
                               int64_t k, cudaStream_t stream) {
  // Bucket key: next_pow_2(m) with a floor of 16, so every m that lands in the
  // smallest bucket compares equal to 16.
  // NOTE(review): next_pow_2 is defined outside this view; presumably it
  // rounds m up to a power of two -- confirm against core/math.hpp.
  uint32_t const m_bucket = std::max(static_cast<uint32_t>(16), next_pow_2(m));

  if (m_bucket > 256) {
    // m in (256, inf)
    runGemm<Fp4GemmSm100<sm100_fp4_config_default, OutType>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
    return;
  }

  if (m_bucket > 16) {
    // m in (16, 256]
    runGemm<Fp4GemmSm100<sm100_fp4_config_M256, OutType>>(
        D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
    return;
  }

  // m in [1, 16]
  runGemm<Fp4GemmSm100<sm100_fp4_config_M16, OutType>>(
      D, A, B, A_sf, B_sf, alpha, m, n, k, stream);
}
#else
template
<
typename
T
>
void
runGemm
(
at
::
Tensor
&
D
,
at
::
Tensor
const
&
A
,
at
::
Tensor
const
&
B
,
at
::
Tensor
const
&
A_sf
,
at
::
Tensor
const
&
B_sf
,
at
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
template
<
typename
OutType
>
void
cutlass_fp4_gemm_dispatch
(
torch
::
Tensor
&
D
,
torch
::
Tensor
const
&
A
,
torch
::
Tensor
const
&
B
,
torch
::
Tensor
const
&
A_sf
,
torch
::
Tensor
const
&
B_sf
,
torch
::
Tensor
const
&
alpha
,
int64_t
m
,
int64_t
n
,
int64_t
k
,
cudaStream_t
stream
)
{
TORCH_CHECK
(
false
,
"Unsupported CUTLASS version. Set VLLM_CUTLASS_SRC_DIR to "
"a CUTLASS 3.8 source directory to enable support."
);
...
...
@@ -271,12 +305,13 @@ void cutlass_scaled_fp4_mm_sm100a(torch::Tensor& D, torch::Tensor const& A,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
(
A
.
get_device
());
if
(
out_dtype
==
at
::
ScalarType
::
Half
)
{
runGemm
<
cutlass
::
half_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
cutlass_fp4_gemm_dispatch
<
cutlass
::
half_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
BFloat16
)
{
runGemm
<
cutlass
::
bfloat16_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
if
(
out_dtype
==
at
::
ScalarType
::
Float
)
{
runGemm
<
float
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
cutlass_fp4_gemm_dispatch
<
cutlass
::
bfloat16_t
>
(
D
,
A
,
B
,
A_sf
,
B_sf
,
alpha
,
m
,
n
,
k
,
stream
);
}
else
{
TORCH_CHECK
(
false
,
"Unsupported output data type of nvfp4 mm"
);
TORCH_CHECK
(
false
,
"Unsupported output data type of nvfp4 mm ("
,
out_dtype
,
")"
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment