Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
376786fa
Unverified
Commit
376786fa
authored
May 08, 2025
by
Shu Wang
Committed by
GitHub
May 08, 2025
Browse files
Add cutlass support for blackwell fp8 blockwise gemm (#14383)
Signed-off-by:
Shu Wang
<
shuw@nvidia.com
>
parent
4f605a6d
Changes
11
Hide whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
332 additions
and
64 deletions
+332
-64
CMakeLists.txt
CMakeLists.txt
+1
-0
csrc/cutlass_extensions/common.hpp
csrc/cutlass_extensions/common.hpp
+10
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
...ization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
+27
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
...tlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
+205
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
+57
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+5
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
+5
-17
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
+5
-46
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+2
-0
tests/kernels/quantization/test_cutlass_scaled_mm.py
tests/kernels/quantization/test_cutlass_scaled_mm.py
+3
-1
vllm/model_executor/layers/quantization/utils/fp8_utils.py
vllm/model_executor/layers/quantization/utils/fp8_utils.py
+12
-0
No files found.
CMakeLists.txt
View file @
376786fa
...
@@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -418,6 +418,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set
(
SRCS
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu"
)
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
SRCS
"
${
SRCS
}
"
...
...
csrc/cutlass_extensions/common.hpp
View file @
376786fa
...
@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
...
@@ -59,3 +59,13 @@ struct enable_sm90_only : Kernel {
#endif
#endif
}
}
};
};
template
<
typename
Kernel
>
struct
enable_sm100_only
:
Kernel
{
template
<
typename
...
Args
>
CUTLASS_DEVICE
void
operator
()(
Args
&&
...
args
)
{
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ == 1000
Kernel
::
operator
()(
std
::
forward
<
Args
>
(
args
)...);
#endif
}
};
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8.cu
0 → 100644
View file @
376786fa
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
size
(
0
)
%
4
==
0
,
"Input tensor must have a number of rows that is a multiple of 4. "
,
"but got: "
,
a
.
size
(
0
),
" rows."
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
bfloat16_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
cutlass_gemm_blockwise_sm100_fp8_dispatch
<
cutlass
::
half_t
>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh
0 → 100644
View file @
376786fa
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/tile_scheduler_params.h"
#include "cutlass/epilogue/dispatch_policy.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/dispatch_policy.hpp"
#include "cutlass_extensions/gemm/collective/collective_builder.hpp"
#include "cutlass_gemm_caller.cuh"
namespace
vllm
{
using
namespace
cute
;
template
<
typename
OutType
,
typename
MmaTileShape
,
typename
ScalesPerTile
,
class
ClusterShape
,
typename
EpilogueScheduler
,
typename
MainloopScheduler
>
struct
cutlass_3x_gemm_fp8_blockwise
{
using
ElementAB
=
cutlass
::
float_e4m3_t
;
using
ElementA
=
ElementAB
;
using
LayoutA
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentA
=
128
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
using
ElementB
=
ElementAB
;
using
LayoutB
=
cutlass
::
layout
::
ColumnMajor
;
static
constexpr
int
AlignmentB
=
128
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
using
ElementC
=
void
;
using
ElementD
=
OutType
;
using
LayoutD
=
cutlass
::
layout
::
RowMajor
;
static
constexpr
int
AlignmentD
=
128
/
cutlass
::
sizeof_bits
<
ElementD
>::
value
;
using
LayoutC
=
LayoutD
;
static
constexpr
int
AlignmentC
=
AlignmentD
;
using
ElementAccumulator
=
float
;
using
ElementCompute
=
float
;
using
ElementBlockScale
=
float
;
// MMA and Cluster Tile Shapes
// Shape of the tile computed by tcgen05 MMA, could be across 2 SMs if Cluster
// Shape %2 == 0 using MmaTileShape_MNK = Shape<_128,_128,_128>;
static
constexpr
int
ScaleMsPerTile
=
size
<
0
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityM
=
size
<
0
>
(
MmaTileShape
{})
/
ScaleMsPerTile
;
static
constexpr
int
ScaleGranularityN
=
size
<
1
>
(
MmaTileShape
{})
/
size
<
1
>
(
ScalesPerTile
{});
static
constexpr
int
ScaleGranularityK
=
size
<
2
>
(
MmaTileShape
{})
/
size
<
2
>
(
ScalesPerTile
{});
// Shape of the threadblocks in a cluster
using
ClusterShape_MNK
=
ClusterShape
;
using
ScaleConfig
=
cutlass
::
detail
::
Sm100BlockwiseScaleConfig
<
ScaleGranularityM
,
ScaleGranularityN
,
ScaleGranularityK
,
cute
::
UMMA
::
Major
::
MN
,
cute
::
UMMA
::
Major
::
K
>
;
using
LayoutSFA
=
decltype
(
ScaleConfig
::
deduce_layoutSFA
());
using
LayoutSFB
=
decltype
(
ScaleConfig
::
deduce_layoutSFB
());
using
ArchTag
=
cutlass
::
arch
::
Sm100
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
static
constexpr
auto
RoundStyle
=
cutlass
::
FloatRoundStyle
::
round_to_nearest
;
using
ElementScalar
=
float
;
// clang-format off
using
DefaultOperation
=
cutlass
::
epilogue
::
fusion
::
LinearCombination
<
ElementD
,
ElementCompute
,
ElementC
,
ElementScalar
,
RoundStyle
>
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
MmaTileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAccumulator
,
ElementCompute
,
ElementC
,
LayoutC
,
AlignmentC
,
ElementD
,
LayoutD
,
AlignmentD
,
EpilogueScheduler
,
DefaultOperation
>::
CollectiveOp
;
using
StageCountType
=
cutlass
::
gemm
::
collective
::
StageCountAuto
;
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
ArchTag
,
OperatorClass
,
ElementA
,
cute
::
tuple
<
LayoutA
,
LayoutSFA
>
,
AlignmentA
,
ElementB
,
cute
::
tuple
<
LayoutB
,
LayoutSFB
>
,
AlignmentB
,
ElementAccumulator
,
MmaTileShape
,
ClusterShape
,
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
))
>
,
MainloopScheduler
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
enable_sm100_only
<
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
>>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_gemm_caller_blockwise
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
using
StrideD
=
typename
Gemm
::
GemmKernel
::
StrideD
;
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
LayoutSFA
=
typename
Gemm
::
LayoutSFA
;
using
LayoutSFB
=
typename
Gemm
::
LayoutSFB
;
using
ScaleConfig
=
typename
Gemm
::
ScaleConfig
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
),
n
=
b
.
size
(
1
),
k
=
a
.
size
(
1
);
auto
prob_shape
=
cute
::
make_shape
(
m
,
n
,
k
,
1
);
StrideA
a_stride
;
StrideB
b_stride
;
StrideC
c_stride
;
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
m
,
k
,
1
));
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
n
,
k
,
1
));
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
m
,
n
,
1
));
LayoutSFA
layout_SFA
=
ScaleConfig
::
tile_atom_to_shape_SFA
(
make_shape
(
m
,
n
,
k
,
1
));
LayoutSFB
layout_SFB
=
ScaleConfig
::
tile_atom_to_shape_SFB
(
make_shape
(
m
,
n
,
k
,
1
));
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_scales_ptr
=
static_cast
<
float
*>
(
a_scales
.
data_ptr
());
auto
b_scales_ptr
=
static_cast
<
float
*>
(
b_scales
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
,
a_scales_ptr
,
layout_SFA
,
b_scales_ptr
,
layout_SFB
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
c3x
::
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
}
template
<
typename
OutType
>
void
cutlass_gemm_blockwise_sm100_fp8_dispatch
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
auto
m
=
a
.
size
(
0
);
auto
k
=
a
.
size
(
1
);
auto
n
=
b
.
size
(
1
);
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
a
.
get_device
());
auto
should_use_2sm
=
[
&
sms
](
int
m
,
int
n
,
int
tile1SM
=
128
)
{
return
std
::
ceil
(
static_cast
<
float
>
(
m
)
/
tile1SM
)
*
std
::
ceil
(
static_cast
<
float
>
(
n
)
/
tile1SM
)
>=
sms
;
};
bool
use_2sm
=
should_use_2sm
(
m
,
n
);
if
(
use_2sm
)
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_256
,
_128
,
_128
>
,
Shape
<
_256
,
_1
,
_1
>
,
Shape
<
_2
,
_2
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized2Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise2SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
cutlass_gemm_caller_blockwise
<
cutlass_3x_gemm_fp8_blockwise
<
OutType
,
Shape
<
_128
,
_128
,
_128
>
,
Shape
<
_128
,
_1
,
_1
>
,
Shape
<
_1
,
_1
,
_1
>
,
cutlass
::
epilogue
::
TmaWarpSpecialized1Sm
,
cutlass
::
gemm
::
KernelTmaWarpSpecializedBlockwise1SmSm100
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_helper.hpp
0 → 100644
View file @
376786fa
#include <torch/all.h>
#include "cuda_utils.h"
template
<
typename
Fp8Func
,
typename
Int8Func
,
typename
BlockwiseFunc
>
void
dispatch_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
,
Fp8Func
fp8_func
,
Int8Func
int8_func
,
BlockwiseFunc
blockwise_func
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
fp8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
if
constexpr
(
!
std
::
is_same_v
<
Int8Func
,
std
::
nullptr_t
>
)
{
int8_func
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
false
,
"Int8 not supported for this architecture"
);
}
}
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
blockwise_func
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
376786fa
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
...
@@ -36,4 +36,9 @@ void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_blockwise_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
}
// namespace vllm
}
// namespace vllm
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm100.cu
View file @
376786fa
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm100 (Blackwell).
NVIDIA GPUs with sm100 (Blackwell).
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
...
@@ -15,20 +13,10 @@ void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
vllm
::
cutlass_scaled_mm_sm100_fp8
,
nullptr
,
// int8 not supported on SM100
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
vllm
::
cutlass_scaled_mm_blockwise_sm100_fp8
);
TORCH_CHECK
(
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)),
"Currently, block scaled fp8 gemm is not implemented for Blackwell"
);
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
,
"Currently, only fp8 gemm is implemented for Blackwell"
);
vllm
::
cutlass_scaled_mm_sm100_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
#endif
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_c3x_sm90.cu
View file @
376786fa
#include
<cudaTypedefs.h>
#include
"c3x/scaled_mm_helper.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "c3x/scaled_mm_kernels.hpp"
#include "cuda_utils.h"
/*
/*
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
This file defines quantized GEMM operations using the CUTLASS 3.x API, for
NVIDIA GPUs with sm90a (Hopper).
NVIDIA GPUs with sm90a (Hopper).
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -15,49 +13,10 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
std
::
optional
<
torch
::
Tensor
>
const
&
bias
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
dispatch_scaled_mm
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
,
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
vllm
::
cutlass_scaled_mm_sm90_fp8
,
vllm
::
cutlass_scaled_mm_sm90_int8
,
int
M
=
a
.
size
(
0
),
N
=
b
.
size
(
1
),
K
=
a
.
size
(
1
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
);
if
((
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
))
&&
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
)))
{
// Standard per-tensor/per-token/per-channel scaling
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
if
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)
{
vllm
::
cutlass_scaled_mm_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
vllm
::
cutlass_scaled_mm_sm90_int8
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
}
}
else
{
using
GroupShape
=
std
::
array
<
int64_t
,
2
>
;
auto
make_group_shape
=
[](
torch
::
Tensor
const
&
x
,
torch
::
Tensor
const
&
s
)
->
GroupShape
{
TORCH_CHECK
(
s
.
dim
()
==
2
,
"cutlass_scaled_mm group scales must be 2D"
);
return
{
cuda_utils
::
ceil_div
(
x
.
size
(
0
),
s
.
size
(
0
)),
cuda_utils
::
ceil_div
(
x
.
size
(
1
),
s
.
size
(
1
))};
};
GroupShape
a_scale_group_shape
=
make_group_shape
(
a
,
a_scales
);
GroupShape
b_scale_group_shape
=
make_group_shape
(
b
,
b_scales
);
// 1x128 per-token group scales for activations
// 128x128 blockwise scales for weights
TORCH_CHECK
((
a_scale_group_shape
==
GroupShape
{
1
,
128
}
&&
b_scale_group_shape
==
GroupShape
{
128
,
128
}
&&
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
&&
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
),
"cutlass_scaled_mm only supports datatype float8_e4m3fn.
\n
"
"a_scale_group_shape must be [1, 128]. Got: ["
,
a_scale_group_shape
[
0
],
", "
,
a_scale_group_shape
[
1
],
"]
\n
"
"b_scale_group_shape must be [128, 128]. Got: ["
,
b_scale_group_shape
[
0
],
", "
,
b_scale_group_shape
[
1
],
"]"
);
TORCH_CHECK
(
!
bias
,
"Bias not yet supported blockwise scaled_mm"
);
vllm
::
cutlass_scaled_mm_blockwise_sm90_fp8
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
376786fa
...
@@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
...
@@ -110,6 +110,8 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
#if defined CUDA_VERSION
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
return
CUDA_VERSION
>=
12000
;
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
100
)
{
return
CUDA_VERSION
>=
12080
;
}
}
#endif
#endif
...
...
tests/kernels/quantization/test_cutlass_scaled_mm.py
View file @
376786fa
...
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
...
@@ -95,7 +95,7 @@ def cutlass_fp8_gemm_helper(m: int,
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
5e-
2
)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1.
5e-
1
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
...
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
...
@@ -161,6 +161,8 @@ def test_cutlass_fp8_blockwise_scale_gemm(m: int, n: int, k: int,
return
return
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
if
m
%
a_scale_group_shape
[
0
]
!=
0
or
k
%
a_scale_group_shape
[
1
]
!=
0
:
return
return
if
m
%
4
!=
0
and
current_platform
.
has_device_capability
(
100
):
return
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
a_scale_group_shape
,
b_scale_group_shape
,
use_bias
)
use_bias
)
...
...
vllm/model_executor/layers/quantization/utils/fp8_utils.py
View file @
376786fa
...
@@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
...
@@ -57,6 +57,16 @@ def apply_w8a8_block_fp8_linear(
or
br
not
in
(
1
,
weight
.
shape
[
0
])):
or
br
not
in
(
1
,
weight
.
shape
[
0
])):
shape_supported_by_cutlass
=
False
shape_supported_by_cutlass
=
False
if
cutlass_block_fp8_supported
and
shape_supported_by_cutlass
:
if
cutlass_block_fp8_supported
and
shape_supported_by_cutlass
:
rows
,
cols
=
input_2d
.
shape
# Blackwell GPUs (SM100) require row dimensions to be multiple of 4 for
# optimal tensor core usage. Can be removed when targeting platforms
# without this constraint.
should_pad
=
current_platform
.
has_device_capability
(
100
)
and
rows
%
4
!=
0
if
should_pad
:
input_2d
=
torch
.
nn
.
functional
.
pad
(
input_2d
,
(
0
,
0
,
0
,
4
-
(
rows
%
4
)),
value
=
0
).
contiguous
()
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
block_size
[
1
],
column_major_scales
=
True
)
column_major_scales
=
True
)
...
@@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
...
@@ -65,6 +75,8 @@ def apply_w8a8_block_fp8_linear(
out_dtype
=
input
.
dtype
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
.
T
)
scale_b
=
weight_scale
.
T
)
if
should_pad
:
output
=
output
[:
rows
,
:]
else
:
else
:
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
q_input
,
x_scale
=
per_token_group_quant_fp8
(
input_2d
,
block_size
[
1
],
block_size
[
1
],
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment