Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
f89978ad
Unverified
Commit
f89978ad
authored
Mar 04, 2025
by
kushanam
Committed by
GitHub
Mar 04, 2025
Browse files
add cutlass support for blackwell fp8 gemm (#13798)
parent
b3cf368d
Changes
11
Show whitespace changes
Inline
Side-by-side
Showing
11 changed files
with
272 additions
and
65 deletions
+272
-65
CMakeLists.txt
CMakeLists.txt
+24
-8
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
+23
-25
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
+20
-12
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
+62
-6
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
+6
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
+24
-0
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
...ization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
+67
-0
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
+25
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+19
-2
csrc/quantization/machete/machete_mm_kernel.cuh
csrc/quantization/machete/machete_mm_kernel.cuh
+1
-6
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
+1
-6
No files found.
CMakeLists.txt
View file @
f89978ad
...
@@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
...
@@ -31,7 +31,7 @@ set(ignoreMe "${VLLM_PYTHON_PATH}")
set
(
PYTHON_SUPPORTED_VERSIONS
"3.9"
"3.10"
"3.11"
"3.12"
)
set
(
PYTHON_SUPPORTED_VERSIONS
"3.9"
"3.10"
"3.11"
"3.12"
)
# Supported NVIDIA architectures.
# Supported NVIDIA architectures.
set
(
CUDA_SUPPORTED_ARCHS
"7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0"
)
set
(
CUDA_SUPPORTED_ARCHS
"7.0;7.2;7.5;8.0;8.6;8.7;8.9;9.0
;10.0;10.1;12.0
"
)
# Supported AMD GPU architectures.
# Supported AMD GPU architectures.
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101"
)
set
(
HIP_SUPPORTED_ARCHS
"gfx906;gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101"
)
...
@@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -297,7 +297,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# Only build Marlin kernels if we are building for at least some compatible archs.
# Only build Marlin kernels if we are building for at least some compatible archs.
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# Keep building Marlin for 9.0 as there are some group sizes and shapes that
# are not supported by Machete yet.
# are not supported by Machete yet.
cuda_archs_loose_intersection
(
MARLIN_ARCHS
"8.0;8.6;8.7;8.9;9.0"
"
${
CUDA_ARCHS
}
"
)
cuda_archs_loose_intersection
(
MARLIN_ARCHS
"8.0;8.6;8.7;8.9;9.0
;10.0;10.1;12.0
"
"
${
CUDA_ARCHS
}
"
)
if
(
MARLIN_ARCHS
)
if
(
MARLIN_ARCHS
)
set
(
MARLIN_SRCS
set
(
MARLIN_SRCS
"csrc/quantization/fp8/fp8_marlin.cu"
"csrc/quantization/fp8/fp8_marlin.cu"
...
@@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -335,7 +335,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
# CUDA 12.0 or later (and only work on Hopper, 9.0a for now).
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a"
"
${
CUDA_ARCHS
}
"
)
cuda_archs_loose_intersection
(
SCALED_MM_3X_ARCHS
"9.0a
;10.0a;10.1a;12.0a
"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS
)
set
(
SRCS
set
(
SRCS
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu"
...
@@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -369,7 +369,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x)
# kernels for the remaining archs that are not already built for 3x.
# kernels for the remaining archs that are not already built for 3x.
cuda_archs_loose_intersection
(
SCALED_MM_2X_ARCHS
cuda_archs_loose_intersection
(
SCALED_MM_2X_ARCHS
"7.5;8.0;8.6;8.7;8.9;9.0"
"
${
CUDA_ARCHS
}
"
)
"7.5;8.0;8.6;8.7;8.9;9.0
;10.0;10.1;12.0
"
"
${
CUDA_ARCHS
}
"
)
# subtract out the archs that are already built for 3x
# subtract out the archs that are already built for 3x
list
(
REMOVE_ITEM SCALED_MM_2X_ARCHS
${
SCALED_MM_3X_ARCHS
}
)
list
(
REMOVE_ITEM SCALED_MM_2X_ARCHS
${
SCALED_MM_3X_ARCHS
}
)
if
(
SCALED_MM_2X_ARCHS
)
if
(
SCALED_MM_2X_ARCHS
)
...
@@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -394,7 +394,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
# 2:4 Sparse Kernels
# 2:4 Sparse Kernels
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# The 2:4 sparse kernels cutlass_scaled_sparse_mm and cutlass_compressor
# require CUDA 12.2 or later (and only work on Hopper
, 9.0a for now
).
# require CUDA 12.2 or later (and only work on Hopper
and Blackwell
).
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.2 AND SCALED_MM_3X_ARCHS
)
set
(
SRCS
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu"
)
set
(
SRCS
"csrc/sparse/cutlass/sparse_scaled_mm_c3x.cu"
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
...
@@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -419,8 +419,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND FP4_ARCHS
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND FP4_ARCHS
)
set
(
SRCS
set
(
SRCS
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_quant_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
"csrc/quantization/fp4/nvfp4_scaled_mm_kernels.cu"
)
)
set_gencode_flags_for_srcs
(
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
CUDA_ARCHS
"
${
FP4_ARCHS
}
"
)
...
@@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -433,6 +432,22 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
set
(
FP4_ARCHS
)
set
(
FP4_ARCHS
)
endif
()
endif
()
# FP8 Blackwell Archs
cuda_archs_loose_intersection
(
BLACKWELL_ARCHS
"10.0;10.1;12.0"
"
${
CUDA_ARCHS
}
"
)
if
(
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_GREATER 12.8 AND BLACKWELL_ARCHS
)
set
(
SRCS
"csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu"
)
set_gencode_flags_for_srcs
(
SRCS
"
${
SRCS
}
"
CUDA_ARCHS
"
${
BLACKWELL_ARCHS
}
"
)
list
(
APPEND VLLM_EXT_SRC
"
${
SRCS
}
"
)
message
(
STATUS
"Building FP8 for archs:
${
BLACKWELL_ARCHS
}
"
)
else
()
# clear BLACKWELL_ARCHS
set
(
BLACKWELL_ARCHS
)
endif
()
#
#
# Machete kernels
# Machete kernels
...
@@ -514,6 +529,7 @@ define_gpu_extension_target(
...
@@ -514,6 +529,7 @@ define_gpu_extension_target(
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
INCLUDE_DIRECTORIES
${
CUTLASS_INCLUDE_DIR
}
INCLUDE_DIRECTORIES
${
CUTLASS_INCLUDE_DIR
}
INCLUDE_DIRECTORIES
${
CUTLASS_TOOLS_UTIL_INCLUDE_DIR
}
USE_SABI 3
USE_SABI 3
WITH_SOABI
)
WITH_SOABI
)
...
@@ -537,7 +553,7 @@ set_gencode_flags_for_srcs(
...
@@ -537,7 +553,7 @@ set_gencode_flags_for_srcs(
CUDA_ARCHS
"
${
CUDA_ARCHS
}
"
)
CUDA_ARCHS
"
${
CUDA_ARCHS
}
"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
cuda_archs_loose_intersection
(
MARLIN_MOE_ARCHS
"8.0;8.6;8.7;8.9;9.0"
"
${
CUDA_ARCHS
}
"
)
cuda_archs_loose_intersection
(
MARLIN_MOE_ARCHS
"8.0;8.6;8.7;8.9;9.0
;10.0;10.1;12.0
"
"
${
CUDA_ARCHS
}
"
)
if
(
MARLIN_MOE_ARCHS
)
if
(
MARLIN_MOE_ARCHS
)
set
(
MARLIN_MOE_SRC
set
(
MARLIN_MOE_SRC
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
"csrc/moe/marlin_kernels/marlin_moe_kernel.h"
...
...
csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
View file @
f89978ad
...
@@ -22,7 +22,7 @@ struct identity {
...
@@ -22,7 +22,7 @@ struct identity {
T
operator
()(
T
lhs
)
const
{
return
lhs
;
}
T
operator
()(
T
lhs
)
const
{
return
lhs
;
}
};
};
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
TrivialEpilogue
{
struct
TrivialEpilogue
{
private:
private:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
...
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
...
@@ -44,32 +44,30 @@ struct TrivialEpilogue {
* This class provides the common load descriptors for the
* This class provides the common load descriptors for the
* ScaledEpilogue[...] classes
* ScaledEpilogue[...] classes
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBase
{
struct
ScaledEpilogueBase
{
protected:
protected:
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
template
<
typename
T
>
template
<
typename
T
>
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
using
ColOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
template
<
typename
T
>
template
<
typename
T
>
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
using
RowOrScalarLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowOrScalarBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
using
ColLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// Don't want to support nullptr by default
// Don't want to support nullptr by default
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
template
<
typename
T
,
bool
EnableNullPtr
=
false
>
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
using
RowLoad
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
T
,
T
,
0
/*Stages*/
,
TileShape
,
T
,
T
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>
,
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
128
/
sizeof_bits_v
<
T
>
,
EnableNullPtr
>
;
// This utility function constructs the arguments for the load descriptors
// This utility function constructs the arguments for the load descriptors
// from a tensor. It can handle both row and column, as well as row/column or
// from a tensor. It can handle both row and column, as well as row/column or
...
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
...
@@ -116,11 +114,11 @@ struct ScaledEpilogueBase {
the A and B operands respectively. These scales may be either per-tensor or
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
per row or column.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogue
struct
ScaledEpilogue
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -160,11 +158,11 @@ struct ScaledEpilogue
...
@@ -160,11 +158,11 @@ struct ScaledEpilogue
* The bias tensor must be per-output channel.
* The bias tensor must be per-output channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
* ScaleA and ScaleB can be per-tensor or per-token/per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBias
struct
ScaledEpilogueBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
...
@@ -203,11 +201,11 @@ struct ScaledEpilogueBias
* bias is a column vector instead of a row vector. Useful e.g. if we are
* bias is a column vector instead of a row vector. Useful e.g. if we are
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
* computing a GEMM via C^T += B^T A^T. This happens in the 2:4 sparse kernels.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueColumnBias
struct
ScaledEpilogueColumnBias
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
...
@@ -249,11 +247,11 @@ struct ScaledEpilogueColumnBias
*
*
* This epilogue also supports bias, which remains per-channel.
* This epilogue also supports bias, which remains per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzp
struct
ScaledEpilogueBiasAzp
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
@@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp
...
@@ -314,11 +312,11 @@ struct ScaledEpilogueBiasAzp
*
*
* This epilogue also supports bias, which remains per-channel.
* This epilogue also supports bias, which remains per-channel.
*/
*/
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
EpilogueDescriptor
>
template
<
typename
ElementAcc
,
typename
ElementD
,
typename
TileShape
>
struct
ScaledEpilogueBiasAzpToken
struct
ScaledEpilogueBiasAzpToken
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
{
:
private
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
{
private:
private:
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
SUPER
=
ScaledEpilogueBase
<
ElementAcc
,
ElementD
,
TileShape
>
;
using
Accum
=
typename
SUPER
::
Accum
;
using
Accum
=
typename
SUPER
::
Accum
;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleA
=
typename
SUPER
::
template
ColOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
using
ScaleB
=
typename
SUPER
::
template
RowOrScalarLoad
<
float
>;
...
...
csrc/quantization/cutlass_w8a8/c3x/cutlass_gemm_caller.cuh
View file @
f89978ad
...
@@ -16,6 +16,7 @@
...
@@ -16,6 +16,7 @@
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"
#include "core/math.hpp"
#include "core/math.hpp"
#include "cutlass_extensions/common.hpp"
#include "cutlass_extensions/common.hpp"
...
@@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -64,22 +65,28 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
b
,
EpilogueArgs
&&
...
epilogue_params
)
{
EpilogueArgs
&&
...
epilogue_params
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementC
=
typename
Gemm
::
ElementC
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
ElementD
=
typename
Gemm
::
ElementD
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
int64_t
lda
=
a
.
stride
(
0
);
using
StrideA
=
typename
Gemm
::
GemmKernel
::
StrideA
;
int64_t
ldb
=
b
.
stride
(
1
);
using
StrideB
=
typename
Gemm
::
GemmKernel
::
StrideB
;
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
typename
Gemm
::
GemmKernel
::
StrideC
;
using
StrideD
=
StrideC
;
using
StrideA
=
cute
::
Stride
<
int64_t
,
cute
::
Int
<
1
>
,
int64_t
>
;
using
StrideAux
=
StrideC
;
using
StrideB
=
cute
::
Stride
<
int64_t
,
cute
::
Int
<
1
>
,
int64_t
>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
cute
::
Int
<
1
>
{},
0
};
StrideB
b_stride
{
ldb
,
cute
::
Int
<
1
>
{},
0
};
StrideC
c_stride
{
ldc
,
cute
::
Int
<
1
>
{},
cute
::
Int
<
0
>
{}};
typename
GemmKernel
::
ProblemShape
prob_shape
=
get_problem_shape
(
a
,
b
);
typename
GemmKernel
::
ProblemShape
prob_shape
=
get_problem_shape
(
a
,
b
);
auto
[
M
,
N
,
K
,
L
]
=
prob_shape
;
StrideA
a_stride
=
cutlass
::
make_cute_packed_stride
(
StrideA
{},
cute
::
make_shape
(
M
,
K
,
L
));
StrideB
b_stride
=
cutlass
::
make_cute_packed_stride
(
StrideB
{},
cute
::
make_shape
(
N
,
K
,
L
));
StrideC
c_stride
=
cutlass
::
make_cute_packed_stride
(
StrideC
{},
cute
::
make_shape
(
M
,
N
,
L
));
StrideD
d_stride
=
cutlass
::
make_cute_packed_stride
(
StrideD
{},
cute
::
make_shape
(
M
,
N
,
L
));
StrideAux
aux_stride
=
d_stride
;
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
...
@@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
...
@@ -87,10 +94,11 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
b_stride
};
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
// auto d_ptr = static_cast<ElementC*>(out.data_ptr());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
Gemm
::
Epilogue
::
prepare_args
(
Gemm
::
Epilogue
::
prepare_args
(
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
std
::
forward
<
EpilogueArgs
>
(
epilogue_params
)...),
c_ptr
,
c_stride
,
c_ptr
,
c
_stride
};
c_ptr
,
c_stride
,
c_ptr
,
d
_stride
};
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
cutlass_gemm_caller
<
GemmKernel
>
(
a
.
device
(),
prob_shape
,
mainloop_args
,
epilogue_args
);
epilogue_args
);
...
...
csrc/quantization/cutlass_w8a8/c3x/scaled_mm.cuh
View file @
f89978ad
...
@@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
...
@@ -40,12 +40,7 @@ struct cutlass_3x_gemm {
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
float
>::
type
;
using
EpilogueDescriptor
=
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
TileShape
>
;
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
ElementC
=
void
;
...
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
...
@@ -88,4 +83,65 @@ struct cutlass_3x_gemm {
struct
GemmKernel
:
public
KernelType
{};
struct
GemmKernel
:
public
KernelType
{};
};
};
/**
 * Type bundle describing a CUTLASS 3.x GEMM built for cutlass::arch::Sm100.
 *
 * A is row-major, B is column-major, and C/D are row-major. No C operand is
 * read (ElementC is void). The epilogue visitor tree is supplied by
 * Epilogue_<ElementAcc, ElementD, TileShape>::EVTCompute.
 */
template <typename ElementAB_, typename ElementD_,
          template <typename, typename, typename> typename Epilogue_,
          typename TileShape, typename ClusterShape, typename KernelSchedule,
          typename EpilogueSchedule>
struct cutlass_3x_gemm_sm100 {
  // ---- Element types ------------------------------------------------------
  using ElementAB = ElementAB_;
  using ElementC = void;
  using ElementD = ElementD_;

  // ---- Layouts ------------------------------------------------------------
  using LayoutA = cutlass::layout::RowMajor;
  using LayoutB = cutlass::layout::ColumnMajor;
  using LayoutC = cutlass::layout::RowMajor;
  using LayoutD = cutlass::layout::RowMajor;

  // ---- Alignments: number of elements that fit in 128 bits ----------------
  static constexpr int AlignmentA =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentB =
      128 / cutlass::sizeof_bits<ElementAB>::value;
  static constexpr int AlignmentC =
      128 / cutlass::sizeof_bits<ElementD_>::value;
  static constexpr int AlignmentD = AlignmentC;

  // Accumulator type handed to the epilogue template: int32 for int8
  // operands, float for everything else.
  using ElementAcc =
      std::conditional_t<std::is_same_v<ElementAB, int8_t>, int32_t, float>;
  using Epilogue = Epilogue_<ElementAcc, ElementD, TileShape>;

  // MMA type
  using ElementAccumulator = float;

  // Epilogue types
  using ElementBias = cutlass::half_t;
  using ElementCompute = float;
  using ElementAux = ElementD;
  using LayoutAux = LayoutD;
  using ElementAmax = float;

  using EVTCompute = typename Epilogue::EVTCompute;

  // Epilogue collective, fused with the EVT provided by Epilogue_.
  using CollectiveEpilogue =
      typename cutlass::epilogue::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,  //
          TileShape, ClusterShape,
          cutlass::epilogue::collective::EpilogueTileAuto,  //
          ElementAccumulator, ElementCompute,               //
          ElementC, LayoutC, AlignmentC,                    //
          ElementD, LayoutD, AlignmentD,                    //
          EpilogueSchedule, EVTCompute>::CollectiveOp;

  // Mainloop collective. The stage count is chosen automatically after
  // carving out the bytes taken by the epilogue's SharedStorage.
  using CollectiveMainloop =
      typename cutlass::gemm::collective::CollectiveBuilder<
          cutlass::arch::Sm100, cutlass::arch::OpClassTensorOp,  //
          ElementAB, LayoutA, AlignmentA,                        //
          ElementAB, LayoutB, AlignmentB,                        //
          ElementAccumulator, TileShape, ClusterShape,
          cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
              sizeof(typename CollectiveEpilogue::SharedStorage))>,
          KernelSchedule>::CollectiveOp;

  // Problem shape is four runtime ints; the last GemmUniversal argument
  // (tile scheduler slot) is left as void.
  using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
      Shape<int, int, int, int>, CollectiveMainloop, CollectiveEpilogue, void>;
};
}
// namespace vllm
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_kernels.hpp
View file @
f89978ad
...
@@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
...
@@ -30,4 +30,10 @@ void cutlass_scaled_mm_blockwise_sm90_fp8(torch::Tensor& out,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_sm100_fp8
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
}
// namespace vllm
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8.cu
0 → 100644
View file @
f89978ad
#include "scaled_mm_kernels.hpp"
#include "scaled_mm_sm100_fp8_dispatch.cuh"
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
namespace
vllm
{
/**
 * Runs the SM100 fp8 scaled GEMM, choosing the epilogue by whether a bias
 * tensor was supplied.
 *
 * Both scale tensors must be contiguous. When a bias is present its dtype
 * must equal the dtype of `out`.
 */
void cutlass_scaled_mm_sm100_fp8(torch::Tensor& out, torch::Tensor const& a,
                                 torch::Tensor const& b,
                                 torch::Tensor const& a_scales,
                                 torch::Tensor const& b_scales,
                                 std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());

  // No bias: plain scaled epilogue.
  if (!bias.has_value()) {
    cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogue>(
        out, a, b, a_scales, b_scales);
    return;
  }

  // Bias path: the bias is forwarded as an extra epilogue argument.
  TORCH_CHECK(bias->dtype() == out.dtype(),
              "currently bias dtype must match output dtype ", out.dtype());
  cutlass_scaled_mm_sm100_fp8_epilogue<c3x::ScaledEpilogueBias>(
      out, a, b, a_scales, b_scales, *bias);
}
}
// namespace vllm
csrc/quantization/cutlass_w8a8/c3x/scaled_mm_sm100_fp8_dispatch.cuh
0 → 100644
View file @
f89978ad
#pragma once
#include "scaled_mm.cuh"
#include "cutlass_gemm_caller.cuh"
/**
* This file defines the Gemm kernel configuration for SM100 (fp8). Only a
* single default configuration exists so far; it is used for every Gemm shape.
*/
namespace
vllm
{
using
c3x
::
cutlass_gemm_caller
;
template
<
typename
InType
,
typename
OutType
,
template
<
typename
,
typename
,
typename
>
typename
Epilogue
>
struct
sm100_fp8_config_default
{
static_assert
(
std
::
is_same
<
InType
,
cutlass
::
float_e4m3_t
>
());
using
KernelSchedule
=
cutlass
::
gemm
::
collective
::
KernelScheduleAuto
;
using
EpilogueSchedule
=
cutlass
::
epilogue
::
collective
::
EpilogueScheduleAuto
;
using
TileShape
=
Shape
<
_256
,
_128
,
_64
>
;
using
ClusterShape
=
Shape
<
_2
,
_2
,
_1
>
;
using
Cutlass3xGemm
=
cutlass_3x_gemm_sm100
<
InType
,
OutType
,
Epilogue
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>
;
};
/**
 * Launches the SM100 fp8 GEMM with the default kernel configuration.
 *
 * `a` and `b` must both be torch::kFloat8_e4m3fn. Any extra arguments are
 * perfectly forwarded to cutlass_gemm_caller as epilogue arguments.
 */
template <typename InType, typename OutType,
          template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out,
                                            torch::Tensor const& a,
                                            torch::Tensor const& b,
                                            EpilogueArgs&&... args) {
  static_assert(std::is_same_v<InType, cutlass::float_e4m3_t>);
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  // Only one configuration exists, so it is used unconditionally.
  using DefaultConfig = sm100_fp8_config_default<InType, OutType, Epilogue>;
  using DefaultGemm = typename DefaultConfig::Cutlass3xGemm;

  cutlass_gemm_caller<DefaultGemm>(out, a, b,
                                   std::forward<EpilogueArgs>(args)...);
}
/**
 * Selects the output element type from `out.dtype()` and dispatches the
 * SM100 fp8 GEMM with the given epilogue template.
 *
 * `a` and `b` must be torch::kFloat8_e4m3fn; `out` must be either
 * torch::kBFloat16 or torch::kFloat16.
 */
template <template <typename, typename, typename> typename Epilogue,
          typename... EpilogueArgs>
void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out,
                                          torch::Tensor const& a,
                                          torch::Tensor const& b,
                                          EpilogueArgs&&... epilogue_args) {
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
  TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);

  // bfloat16 output.
  if (out.dtype() == torch::kBFloat16) {
    cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, cutlass::bfloat16_t,
                                    Epilogue>(
        out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
    return;
  }

  // Anything else must be float16.
  TORCH_CHECK(out.dtype() == torch::kFloat16);
  cutlass_gemm_sm100_fp8_dispatch<cutlass::float_e4m3_t, cutlass::half_t,
                                  Epilogue>(
      out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
// namespace vllm
\ No newline at end of file
csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu
View file @
f89978ad
...
@@ -71,3 +71,28 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
...
@@ -71,3 +71,28 @@ void cutlass_scaled_mm_azp_sm90(torch::Tensor& out, torch::Tensor const& a,
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
vllm
::
cutlass_scaled_mm_azp_sm90_int8
(
out
,
a
,
b
,
a_scales
,
b_scales
,
azp_adj
,
azp
,
bias
);
azp
,
bias
);
}
}
#if defined CUDA_VERSION && CUDA_VERSION >= 12080
/**
 * Entry point for the scaled GEMM on SM100. Validates the scale tensors and
 * the operand dtype, then forwards to vllm::cutlass_scaled_mm_sm100_fp8.
 *
 * @param c        Output tensor.
 * @param a        Left operand; must be torch::kFloat8_e4m3fn.
 * @param b        Right operand.
 * @param a_scales float32, contiguous; either 1 element or a.size(0) elements.
 * @param b_scales float32, contiguous; either 1 element or b.size(1) elements.
 * @param bias     Optional bias, passed through unchanged.
 *
 * Any other scale shape (i.e. block scaling) is rejected.
 */
void cutlass_scaled_mm_sm100(torch::Tensor& c, torch::Tensor const& a,
                             torch::Tensor const& b,
                             torch::Tensor const& a_scales,
                             torch::Tensor const& b_scales,
                             std::optional<torch::Tensor> const& bias) {
  TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
  TORCH_CHECK(b_scales.dtype() == torch::kFloat32);

  // The previously declared locals M, N, K were never read (and narrowed
  // int64_t sizes to int), so the sizes are queried directly instead.
  TORCH_CHECK(
      (a_scales.numel() == 1 || a_scales.numel() == a.size(0)) &&
          (b_scales.numel() == 1 || b_scales.numel() == b.size(1)),
      "Currently, block scaled fp8 gemm is not implemented for Blackwell");

  // Standard per-tensor/per-token/per-channel scaling
  TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
  TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn,
              "Currently, only fp8 gemm is implemented for Blackwell");
  vllm::cutlass_scaled_mm_sm100_fp8(c, a, b, a_scales, b_scales, bias);
}
#endif
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
f89978ad
...
@@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -29,6 +29,11 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
void
cutlass_scaled_mm_sm100
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
#endif
#endif
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm_azp_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
...
@@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
...
@@ -86,7 +91,7 @@ bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability) {
// and at least SM90 (Hopper)
// and at least SM90 (Hopper)
#if defined CUDA_VERSION
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
)
{
if
(
cuda_device_capability
>=
90
&&
cuda_device_capability
<
100
)
{
return
CUDA_VERSION
>=
12000
;
return
CUDA_VERSION
>=
12000
;
}
}
#endif
#endif
...
@@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
...
@@ -120,10 +125,22 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
// Guard against compilation issues for sm90 kernels
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X
if
(
version_num
>=
90
)
{
#if defined CUDA_VERSION && CUDA_VERSION < 12080
if
(
version_num
>=
90
&&
version_num
<
100
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
#else
if
(
version_num
>=
90
&&
version_num
<
100
)
{
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
cutlass_scaled_mm_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
return
;
}
else
if
(
version_num
>=
100
)
{
cutlass_scaled_mm_sm100
(
c
,
a
,
b
,
a_scales
,
b_scales
,
bias
);
return
;
}
}
#endif
#endif
#endif
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X
...
...
csrc/quantization/machete/machete_mm_kernel.cuh
View file @
f89978ad
...
@@ -126,15 +126,10 @@ struct MacheteKernelTemplate {
...
@@ -126,15 +126,10 @@ struct MacheteKernelTemplate {
std
::
is_same_v
<
ElementSChannel
,
ElementSToken
>
),
std
::
is_same_v
<
ElementSChannel
,
ElementSToken
>
),
"Currently token and channel scales (if present) must be the same type"
);
"Currently token and channel scales (if present) must be the same type"
);
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
// Currently only supports float scales
// Currently only supports float scales
using
ChTokScalesEpilogue
=
using
ChTokScalesEpilogue
=
typename
vllm
::
c3x
::
ScaledEpilogue
<
ElementAccumulator
,
ElementD
,
typename
vllm
::
c3x
::
ScaledEpilogue
<
ElementAccumulator
,
ElementD
,
EpilogueDescriptor
>
;
TileShape
>
;
static_assert
((
with_channel_scales
||
with_token_scales
)
||
static_assert
((
with_channel_scales
||
with_token_scales
)
||
(
std
::
is_same_v
<
ElementSChannel
,
float
>
&&
(
std
::
is_same_v
<
ElementSChannel
,
float
>
&&
std
::
is_same_v
<
ElementSToken
,
float
>
),
std
::
is_same_v
<
ElementSToken
,
float
>
),
...
...
csrc/sparse/cutlass/sparse_scaled_mm_c3x.cuh
View file @
f89978ad
...
@@ -65,12 +65,7 @@ struct cutlass_sparse_3x_gemm {
...
@@ -65,12 +65,7 @@ struct cutlass_sparse_3x_gemm {
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
float
>::
type
;
using
EpilogueDescriptor
=
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
TileShape
>
;
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Epilogue
=
Epilogue_
<
ElementAcc
,
ElementD
,
EpilogueDescriptor
>
;
using
ElementC
=
void
;
using
ElementC
=
void
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
using
LayoutC
=
cutlass
::
layout
::
RowMajor
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment