Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3303f134
Unverified
Commit
3303f134
authored
Aug 07, 2025
by
Junhao Li
Committed by
GitHub
Aug 07, 2025
Browse files
[Kernel] Add support for block FP8 on SM120 (NVIDIA 5090 and RTX PRO 6000) (#22131)
Signed-off-by:
Junhao Li
<
junhao@ubicloud.com
>
parent
b2c8ce57
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
229 additions
and
18 deletions
+229
-18
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+10
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
+23
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
+183
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+6
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
+6
-18
No files found.
CMakeLists.txt
View file @
3303f134
...
...
@@ -427,6 +427,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm120_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
...
...
csrc/cutlass_extensions/common.hpp
View file @
3303f134
...
...
@@ -60,3 +60,13 @@ struct enable_sm100_only : Kernel {
#endif
}
};
template
<
typename
Kernel
>
struct
enable_sm120_only
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1200
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8.cu
0 → 100644
View file @
3303f134
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm120_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm120_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm120_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh
0 → 100644
View file @
3303f134
#pragma once
#include "cuda_utils.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
// clang-format off
template
<
class
OutType
,
int
ScaleGranularityM
,
int
ScaleGranularityN
,
int
ScaleGranularityK
,
class
MmaTileShape
,
class
ClusterShape
,
class
EpilogueScheduler
,
class
MainloopScheduler
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
using
LayoutA_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutA
>::
type
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
// ColumnMajor is used for B to match the CUTLASS convention.
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
using
LayoutB_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutB
>::
type
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementD
=
OutType
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
using
LayoutD_Transpose
=
typename
cutlass
::
layout
::
LayoutTranspose
<
LayoutD
>::
type
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
ElementC
=
void
;
// TODO: support bias
using
LayoutC
=
LayoutD
;
using
LayoutC_Transpose
=
LayoutD_Transpose
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm120BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
// layout_SFA and layout_SFB cannot be swapped since they are deduced.
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm120
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
MmaTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
;
using
KernelType
=
enable_sm120_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutSFA
=
typename
Gemm
::
LayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
LayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
ScaleConfig
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
StrideA
a_stride
;
StrideB
b_stride
;
StrideC
c_stride
;
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
auto
mainloop_args
=
[
&
](){
return
typename
GemmKernel
::
MainloopArguments
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
}();
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm120_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
// TODO: better heuristics
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
1
,
128
,
128
,
Shape
<
_128
,
_128
,
_128
>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
,
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
3303f134
...
...
@@ -47,4 +47,10 @@ void cutlass_scaled_mm_blockwise_sm100_fp8(torch::Tensor& out,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_blockwise_sm120_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm120.cu
View file @
3303f134
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm120 (Blackwell
Geforce
).
NVIDIA GPUs with sm120 (Blackwell).
*/
#if defined ENABLE_SCALED_MM_SM120 && ENABLE_SCALED_MM_SM120
...
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm120(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm120_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
vllm
::
cutlass_scaled_mm_sm120_fp8
,
nullptr
,
// int8 not supported on SM120
vllm
::
cutlass_scaled_mm_blockwise_sm120_fp8
);
}
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment