Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2060e936
Unverified
Commit
2060e936
authored
May 16, 2024
by
Tyler Michael Smith
Committed by
GitHub
May 16, 2024
Browse files
[Kernel] Add w8a8 CUTLASS kernels (#4749)
parent
8435b207
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
1197 additions
and
2 deletions
+1197
-2
CMakeLists.txt
CMakeLists.txt
+26
-1
csrc/ops.h
csrc/ops.h
+8
-0
csrc/pybind.cpp
csrc/pybind.cpp
+1
-0
csrc/quantization/cutlass_w8a8/common.hpp
csrc/quantization/cutlass_w8a8/common.hpp
+12
-0
csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp
...on/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp
+340
-0
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+296
-0
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+240
-0
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
+65
-0
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+192
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+17
-1
No files found.
CMakeLists.txt
View file @
2060e936
...
@@ -173,6 +173,16 @@ set(VLLM_EXT_SRC
...
@@ -173,6 +173,16 @@ set(VLLM_EXT_SRC
"csrc/pybind.cpp"
)
"csrc/pybind.cpp"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
if
(
VLLM_GPU_LANG STREQUAL
"CUDA"
)
include
(
FetchContent
)
SET
(
CUTLASS_ENABLE_HEADERS_ONLY=ON
)
FetchContent_Declare
(
cutlass
GIT_REPOSITORY https://github.com/nvidia/cutlass.git
# CUTLASS 3.5.0
GIT_TAG 7d49e6c7e2f8896c47f586706e67e1fb215529dc
)
FetchContent_MakeAvailable
(
cutlass
)
list
(
APPEND VLLM_EXT_SRC
list
(
APPEND VLLM_EXT_SRC
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/aqlm/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
"csrc/quantization/awq/gemm_kernels.cu"
...
@@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
...
@@ -180,7 +190,21 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/quantization/gptq_marlin/gptq_marlin_repack.cu"
"csrc/custom_all_reduce.cu"
)
"csrc/custom_all_reduce.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu"
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
)
#
# The CUTLASS kernels for Hopper require sm90a to be enabled.
# This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a.
# That adds an extra 17MB to compiled binary, so instead we selectively enable it.
set_source_files_properties
(
"csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu"
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a"
)
endif
()
endif
()
define_gpu_extension_target
(
define_gpu_extension_target
(
...
@@ -190,6 +214,7 @@ define_gpu_extension_target(
...
@@ -190,6 +214,7 @@ define_gpu_extension_target(
SOURCES
${
VLLM_EXT_SRC
}
SOURCES
${
VLLM_EXT_SRC
}
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
COMPILE_FLAGS
${
VLLM_GPU_FLAGS
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
ARCHITECTURES
${
VLLM_GPU_ARCHES
}
INCLUDE_DIRECTORIES
${
CUTLASS_INCLUDE_DIR
}
;
${
CUTLASS_TOOLS_UTIL_INCLUDE_DIR
}
WITH_SOABI
)
WITH_SOABI
)
#
#
...
...
csrc/ops.h
View file @
2060e936
...
@@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack(
...
@@ -155,6 +155,14 @@ torch::Tensor gptq_marlin_repack(
int64_t
size_k
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_n
,
int64_t
num_bits
);
int64_t
num_bits
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#endif
#endif
void
squeezellm_gemm
(
void
squeezellm_gemm
(
...
...
csrc/pybind.cpp
View file @
2060e936
...
@@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -71,6 +71,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."
);
#endif
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
...
...
csrc/quantization/cutlass_w8a8/common.hpp
0 → 100644
View file @
2060e936
#pragma once
#include "cutlass/cutlass.h"
/**
* Helper function for checking CUTLASS errors
*/
#define CUTLASS_CHECK(status) \
{ \
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}
csrc/quantization/cutlass_w8a8/cutlass_visitor_2x_broadcast_epilogue.hpp
0 → 100644
View file @
2060e936
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights
*reserved. SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice,
*this list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
*ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE
*LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
*CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
*SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
*INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
*CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
*ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
*POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
//
// This file is a modified excerpt of
// include/cutlass/epilogue/fusion/visitor_load.hpp from
// https://github.com/NVIDIA/cutlass It's beem modified to support either
// row/column or scalar broadcasting, like is already supported in CUTLASS 3.x.
// Important because this saves us a factor 4x on the number of kernels
// compiled.
//
#pragma once
// clang-format off
#include "cutlass/epilogue/threadblock/fusion/visitor_2x.hpp"
#include "cute/tensor.hpp"
// clang-format on
namespace
cutlass
::
epilogue
::
threadblock
{
using
namespace
cute
;
using
namespace
detail
;
template
<
class
ThreadMap
,
class
Element
,
class
StrideMNL
>
struct
VisitorRowOrScalarBroadcast
{
struct
Arguments
{
Element
const
*
ptr_row
=
nullptr
;
Element
null_default
=
Element
(
0
);
StrideMNL
dRow
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
struct
SharedStorage
{};
// Global load type
static
int
constexpr
vec_bits
=
ThreadMap
::
kElementsPerAccess
*
sizeof_bits
<
Element
>::
value
;
using
VecType
=
uint_bit_t
<
cute
::
min
(
128
,
vec_bits
)
>
;
static
int
constexpr
VecLength
=
sizeof
(
VecType
)
/
sizeof
(
Element
);
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
VisitorRowOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params_ptr
(
&
params
)
{
}
Params
const
*
params_ptr
;
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
Callbacks
:
EmptyCallbacks
{
CUTLASS_DEVICE
Callbacks
(
GTensor
&&
tC_gRow
,
RTensor
&&
tC_rRow
,
CTensor
&&
tC_cRow
,
ProblemShape
problem_shape
,
Params
const
*
params_ptr
)
:
tC_gRow
(
cute
::
forward
<
GTensor
>
(
tC_gRow
)),
tC_rRow
(
cute
::
forward
<
RTensor
>
(
tC_rRow
)),
tC_cRow
(
cute
::
forward
<
CTensor
>
(
tC_cRow
)),
n
(
get
<
1
>
(
problem_shape
)),
params_ptr
(
params_ptr
)
{
}
GTensor
tC_gRow
;
RTensor
tC_rRow
;
CTensor
tC_cRow
;
Params
const
*
params_ptr
;
int
n
;
// This function is modified from VisitorRowBroadcast
CUTLASS_DEVICE
void
begin_epilogue
()
{
clear
(
tC_rRow
);
auto
src_v
=
filter
(
tC_gRow
);
auto
coord_v
=
filter
(
tC_cRow
);
auto
dst_v
=
filter
(
tC_rRow
);
if
(
params_ptr
->
ptr_row
)
{
// In this case we are loading from a row vector and broadcasting
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
bool
guard
=
get
<
1
>
(
coord_v
(
i
))
<
n
;
cutlass
::
arch
::
global_load
<
VecType
,
sizeof
(
VecType
)
>
(
dst_v
(
i
),
(
void
const
*
)
&
src_v
(
i
),
guard
);
}
}
else
{
// In this case we are loading from a scalar and broadcasting
VecType
filled_vec
;
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
VecLength
;
i
++
)
{
reinterpret_cast
<
Element
*>
(
&
filled_vec
)[
i
]
=
params_ptr
->
null_default
;
}
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
src_v
);
++
i
)
{
if
(
get
<
1
>
(
coord_v
(
i
))
<
n
)
{
dst_v
(
i
)
=
filled_vec
;
}
}
}
}
template
<
class
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
auto
// returns an Array
visit
(
int
iter_idx
,
int
row_idx
,
int
column_idx
,
int
frg_idx
,
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
)
{
Tensor
rRow_frg
=
recast
<
Array
<
Element
,
FragmentSize
>>
(
coalesce
(
tC_rRow
));
return
rRow_frg
(
column_idx
);
}
};
template
<
class
ProblemShape
>
CUTLASS_DEVICE
auto
get_callbacks
(
gemm
::
GemmCoord
threadblock_tile_offset
,
int
thread_idx
,
ProblemShape
problem_shape
)
{
Tensor
mRow
=
make_tensor
(
make_gmem_ptr
(
params_ptr
->
ptr_row
),
problem_shape
,
params_ptr
->
dRow
);
// VECTOR, FRAGMENT_COLUMN
Tensor
tC_gRow
=
recast
<
VecType
>
(
ThreadMap
::
partition
(
mRow
,
thread_idx
,
threadblock_tile_offset
)
)(
_
,
_
,
_0
{},
_0
{},
_0
{},
_0
{});
Tensor
tC_rRow
=
make_tensor_like
(
tC_gRow
);
// Generate the pred tensor
Tensor
cRow
=
make_identity_tensor
(
mRow
.
shape
());
Tensor
tC_cRow
=
outer_partition
(
ThreadMap
::
partition
(
cRow
,
thread_idx
,
threadblock_tile_offset
)(
_
,
_
,
_0
{},
_0
{},
_0
{},
_0
{}),
Shape
<
Int
<
VecLength
>>
{},
(
_0
{})
);
return
Callbacks
<
decltype
(
tC_gRow
),
decltype
(
tC_rRow
),
decltype
(
tC_cRow
),
ProblemShape
>
(
cute
::
move
(
tC_gRow
),
cute
::
move
(
tC_rRow
),
cute
::
move
(
tC_cRow
),
problem_shape
,
params_ptr
);
}
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// Column vector broadcast
template
<
class
ThreadMap
,
class
Element
,
class
StrideMNL
=
Stride
<
_1
,
_0
,
_0
>
>
struct
VisitorColOrScalarBroadcast
{
struct
Arguments
{
Element
const
*
ptr_col
=
nullptr
;
Element
null_default
=
Element
(
0
);
StrideMNL
dCol
=
{};
};
using
Params
=
Arguments
;
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
return
args
;
}
template
<
class
ProblemShape
>
static
size_t
get_workspace_size
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
)
{
return
0
;
}
struct
SharedStorage
{
};
// Global load type
static
int
constexpr
vec_bits
=
ThreadMap
::
kElementsPerAccess
*
sizeof_bits
<
Element
>::
value
;
using
VecType
=
uint_bit_t
<
cute
::
min
(
128
,
vec_bits
)
>
;
static
int
constexpr
VecLength
=
sizeof
(
VecType
)
/
sizeof
(
Element
);
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast
()
{
}
CUTLASS_HOST_DEVICE
VisitorColOrScalarBroadcast
(
Params
const
&
params
,
SharedStorage
const
&
shared_storage
)
:
params_ptr
(
&
params
)
{
}
Params
const
*
params_ptr
;
template
<
class
GTensor
,
class
RTensor
,
class
CTensor
,
class
ProblemShape
>
struct
Callbacks
:
EmptyCallbacks
{
CUTLASS_DEVICE
Callbacks
(
GTensor
&&
tC_gCol
,
RTensor
&&
tC_rCol
,
CTensor
&&
tC_cCol
,
ProblemShape
problem_shape
,
Params
const
*
params_ptr
)
:
tC_gCol
(
cute
::
forward
<
GTensor
>
(
tC_gCol
)),
tC_rCol
(
cute
::
forward
<
RTensor
>
(
tC_rCol
)),
tC_cCol
(
cute
::
forward
<
CTensor
>
(
tC_cCol
)),
m
(
get
<
0
>
(
problem_shape
)),
params_ptr
(
params_ptr
)
{
}
GTensor
tC_gCol
;
RTensor
tC_rCol
;
CTensor
tC_cCol
;
Params
const
*
params_ptr
;
int
m
;
// This function is modified from VisitorColBroadcast
CUTLASS_DEVICE
void
begin_epilogue
()
{
clear
(
tC_rCol
);
Tensor
pred
=
make_tensor
<
bool
>
(
shape
(
tC_gCol
));
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
pred
);
++
i
)
{
pred
(
i
)
=
get
<
0
>
(
tC_cCol
(
i
))
<
m
;
}
if
(
params_ptr
->
ptr_col
)
{
// In this case we are loading from a column vector and broadcasting
copy_if
(
pred
,
tC_gCol
,
tC_rCol
);
}
else
{
// In this case we are loading from a scalar and broadcasting
auto
dst_v
=
filter
(
tC_rCol
);
CUTLASS_PRAGMA_UNROLL
for
(
int
i
=
0
;
i
<
size
(
dst_v
);
++
i
)
{
if
(
pred
(
i
)){
dst_v
(
i
)
=
params_ptr
->
null_default
;
}
}
}
}
template
<
class
ElementAccumulator
,
int
FragmentSize
>
CUTLASS_DEVICE
auto
// returns an Array
visit
(
int
iter_idx
,
int
row_idx
,
int
column_idx
,
int
frg_idx
,
Array
<
ElementAccumulator
,
FragmentSize
>
const
&
frg_acc
)
{
Array
<
Element
,
FragmentSize
>
frg_col
;
frg_col
.
fill
(
tC_rCol
(
row_idx
,
iter_idx
));
return
frg_col
;
}
};
template
<
class
ProblemShape
>
CUTLASS_DEVICE
auto
get_callbacks
(
gemm
::
GemmCoord
threadblock_tile_offset
,
int
thread_idx
,
ProblemShape
problem_shape
)
{
Tensor
mCol
=
make_tensor
(
make_gmem_ptr
(
params_ptr
->
ptr_col
),
problem_shape
,
params_ptr
->
dCol
);
// VECTOR, FRAGMENT_COLUMN, FRAGMENT_ROW, ITERATION_ROW, ITERATION_GROUP, ITERATION_CLUSTER
Tensor
tC_gCol
=
group_modes
<
1
,
4
>
(
ThreadMap
::
partition
(
mCol
,
thread_idx
,
threadblock_tile_offset
)(
_0
{},
_0
{},
_
,
_
,
_
,
_
));
Tensor
tC_rCol
=
make_tensor_like
(
tC_gCol
);
// Generate the pred tensor
Tensor
cCol
=
make_identity_tensor
(
mCol
.
shape
());
Tensor
tC_cCol
=
group_modes
<
1
,
4
>
(
ThreadMap
::
partition
(
cCol
,
thread_idx
,
threadblock_tile_offset
)(
_0
{},
_0
{},
_
,
_
,
_
,
_
));
return
Callbacks
<
decltype
(
tC_gCol
),
decltype
(
tC_rCol
),
decltype
(
tC_cCol
),
ProblemShape
>
(
cute
::
move
(
tC_gCol
),
cute
::
move
(
tC_rCol
),
cute
::
move
(
tC_cCol
),
problem_shape
,
params_ptr
);
}
};
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
0 → 100644
View file @
2060e936
#include <stddef.h>
#include <torch/extension.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "cutlass_visitor_2x_broadcast_epilogue.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 2.x API, and is used for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
namespace
{
template
<
typename
Arch
,
typename
ElementAB_
,
typename
ElementD_
,
typename
TileShape
,
typename
WarpShape
,
typename
InstructionShape
,
int32_t
MainLoopStages
>
struct
cutlass_2x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
Operator
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
cutlass
::
arch
::
OpMultiplyAddSaturate
,
cutlass
::
arch
::
OpMultiplyAdd
>::
type
;
using
OutputTileThreadMap
=
cutlass
::
epilogue
::
threadblock
::
OutputTileThreadLayout
<
TileShape
,
WarpShape
,
float
,
4
,
1
/* epilogue stages */
>
;
using
Accum
=
cutlass
::
epilogue
::
threadblock
::
VisitorAccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
threadblock
::
VisitorColOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleB
=
cutlass
::
epilogue
::
threadblock
::
VisitorRowOrScalarBroadcast
<
OutputTileThreadMap
,
float
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
threadblock
::
VisitorCompute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute1
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
D
=
cutlass
::
epilogue
::
threadblock
::
VisitorAuxStore
<
OutputTileThreadMap
,
ElementD
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
,
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>>
;
using
EVTD
=
cutlass
::
epilogue
::
threadblock
::
Sm80EVT
<
D
,
EVTCompute1
>
;
// clang-format off
using
RowMajor
=
typename
cutlass
::
layout
::
RowMajor
;
using
ColumnMajor
=
typename
cutlass
::
layout
::
ColumnMajor
;
using
KernelType
=
typename
cutlass
::
gemm
::
kernel
::
DefaultGemmWithVisitor
<
ElementAB
,
RowMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
ElementAB
,
ColumnMajor
,
cutlass
::
ComplexTransform
::
kNone
,
16
,
float
,
cutlass
::
layout
::
RowMajor
,
4
,
ElementAcc
,
float
,
cutlass
::
arch
::
OpClassTensorOp
,
Arch
,
TileShape
,
WarpShape
,
InstructionShape
,
EVTD
,
cutlass
::
gemm
::
threadblock
::
ThreadblockSwizzleStreamK
,
MainLoopStages
,
Operator
,
1
/* epilogue stages */
>::
GemmKernel
;
// clang-format on
using
Op
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
KernelType
>
;
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
cutlass
::
gemm
::
GemmCoord
problem_size
{
m
,
n
,
k
};
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
a_scales_ptr
=
a_scales
.
data_ptr
<
float
>
();
auto
b_scales_ptr
=
b_scales
.
data_ptr
<
float
>
();
// If A and B are quantized per-tensor, then these scale tensors are scalars,
// and they are passed in via the second argument.
using
ScaleAArgs
=
typename
Gemm
::
ScaleA
::
Arguments
;
ScaleAArgs
a_args
=
a_scales
.
numel
()
==
1
?
ScaleAArgs
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleAArgs
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
using
ScaleBArgs
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleBArgs
b_args
=
b_scales
.
numel
()
==
1
?
ScaleBArgs
{
nullptr
,
b_scales
.
item
<
float
>
(),
{}}
:
ScaleBArgs
{
b_scales
.
data_ptr
<
float
>
(),
{},
{}};
typename
Gemm
::
EVTCompute0
::
Arguments
evt0_compute_args
{
b_args
};
typename
Gemm
::
EVTCompute1
::
Arguments
evt1_compute_args
{
a_args
,
evt0_compute_args
};
typename
Gemm
::
D
::
Arguments
d_args
{
c_ptr
,
c_stride
};
typename
Gemm
::
EVTD
::
Arguments
epilogue_args
{
evt1_compute_args
,
d_args
,
};
typename
Gemm
::
Op
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemmSplitKParallel
,
// universal mode
problem_size
,
// problem size
1
,
// batch count
epilogue_args
,
a_ptr
,
b_ptr
,
nullptr
,
nullptr
,
0
,
0
,
0
,
0
,
lda
,
ldb
,
ldc
,
ldc
};
// Launch the CUTLASS GEMM kernel.
typename
Gemm
::
Op
gemm_op
;
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
cutlass
::
device_memory
::
allocation
<
uint8_t
>
workspace
(
workspace_size
);
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
cutlass
::
Status
status
=
gemm_op
(
args
,
workspace
.
get
());
CUTLASS_CHECK
(
status
);
}
}
// namespace
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
8
,
8
,
16
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm75
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
2
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm80
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
assert
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
int8_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_2x_gemm
<
cutlass
::
arch
::
Sm89
,
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
TileShape
,
WarpShape
,
InstructionShape
,
5
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
0 → 100644
View file @
2060e936
#include <torch/extension.h>
#include <iostream>
#include <sstream>
#include <vector>
// clang-format will break include orders
// clang-format off
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "common.hpp"
// clang-format on
using
namespace
cute
;
/*
This defines a quantized GEMM operation with dequantized output, similar to
torch._scaled_mm. It is defined using the CUTLASS 3.x API, and is used for
NVIDIA GPUs with sm90a (Hopper) or later.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
namespace
{
template
<
typename
ElementAB_
,
typename
ElementD_
,
typename
TileShape
,
typename
ClusterShape
,
typename
KernelSchedule
,
typename
EpilogueSchedule
>
struct
cutlass_3x_gemm
{
using
ElementAB
=
ElementAB_
;
using
ElementD
=
ElementD_
;
using
ElementAcc
=
typename
std
::
conditional
<
std
::
is_same_v
<
ElementAB
,
int8_t
>
,
int32_t
,
float
>::
type
;
using
EpilogueDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
EpilogueDescriptor
<
TileShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementD
,
ElementD
,
EpilogueSchedule
>
;
using
Accum
=
cutlass
::
epilogue
::
fusion
::
Sm90AccFetch
;
using
ScaleA
=
cutlass
::
epilogue
::
fusion
::
Sm90ColBroadcast
<
0
/*Stages*/
,
typename
EpilogueDescriptor
::
TileShape
,
float
,
Stride
<
Int
<
1
>
,
Int
<
0
>
,
Int
<
0
>>>
;
using
ScaleBDescriptor
=
cutlass
::
epilogue
::
collective
::
detail
::
RowBroadcastDescriptor
<
EpilogueDescriptor
,
float
>
;
using
ScaleB
=
cutlass
::
epilogue
::
fusion
::
Sm90RowBroadcast
<
ScaleBDescriptor
::
Stages
,
typename
EpilogueDescriptor
::
TileShape
,
typename
ScaleBDescriptor
::
Element
,
Stride
<
Int
<
0
>
,
Int
<
1
>
,
Int
<
0
>>>
;
using
Compute0
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
float
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute0
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute0
,
ScaleB
,
Accum
>
;
using
Compute1
=
cutlass
::
epilogue
::
fusion
::
Sm90Compute
<
cutlass
::
multiplies
,
ElementD
,
float
,
cutlass
::
FloatRoundStyle
::
round_to_nearest
>
;
using
EVTCompute1
=
cutlass
::
epilogue
::
fusion
::
Sm90EVT
<
Compute1
,
ScaleA
,
EVTCompute0
>
;
using
StrideD
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
ElementC
=
void
;
using
StrideC
=
StrideD
;
using
CollectiveEpilogue
=
typename
cutlass
::
epilogue
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
TileShape
,
ClusterShape
,
cutlass
::
epilogue
::
collective
::
EpilogueTileAuto
,
ElementAcc
,
float
,
ElementC
,
StrideC
,
4
,
ElementD
,
StrideD
,
4
,
EpilogueSchedule
,
EVTCompute1
>::
CollectiveOp
;
static
constexpr
size_t
CEStorageSize
=
sizeof
(
typename
CollectiveEpilogue
::
SharedStorage
);
using
Stages
=
typename
cutlass
::
gemm
::
collective
::
StageCountAutoCarveout
<
static_cast
<
int
>
(
CEStorageSize
)
>
;
// clang-format off
using
CollectiveMainloop
=
typename
cutlass
::
gemm
::
collective
::
CollectiveBuilder
<
cutlass
::
arch
::
Sm90
,
cutlass
::
arch
::
OpClassTensorOp
,
ElementAB
,
cutlass
::
layout
::
RowMajor
,
16
,
ElementAB
,
cutlass
::
layout
::
ColumnMajor
,
16
,
ElementAcc
,
TileShape
,
ClusterShape
,
Stages
,
KernelSchedule
>::
CollectiveOp
;
// clang-format on
using
KernelType
=
cutlass
::
gemm
::
kernel
::
GemmUniversal
<
cute
::
Shape
<
int
,
int
,
int
,
int
>
,
CollectiveMainloop
,
CollectiveEpilogue
,
cutlass
::
gemm
::
PersistentScheduler
>
;
struct
GemmKernel
:
public
KernelType
{};
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
int32_t
m
=
a
.
size
(
0
);
int32_t
n
=
b
.
size
(
1
);
int32_t
k
=
a
.
size
(
1
);
int64_t
lda
=
a
.
stride
(
0
);
int64_t
ldb
=
b
.
stride
(
1
);
int64_t
ldc
=
out
.
stride
(
0
);
using
StrideA
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
StrideB
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
using
StrideC
=
typename
Gemm
::
StrideC
;
StrideA
a_stride
{
lda
,
Int
<
1
>
{},
Int
<
0
>
{}};
StrideB
b_stride
{
ldb
,
Int
<
1
>
{},
Int
<
0
>
{}};
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
typename
GemmKernel
::
Arguments
args
{
cutlass
::
gemm
::
GemmUniversalMode
::
kGemm
,
prob_shape
,
mainloop_args
,
epilogue_args
};
using
ScaleA_Args
=
typename
Gemm
::
ScaleA
::
Arguments
;
using
ScaleB_Args
=
typename
Gemm
::
ScaleB
::
Arguments
;
ScaleA_Args
a_args
=
a_scales
.
numel
()
==
1
?
ScaleA_Args
{
nullptr
,
a_scales
.
item
<
float
>
(),
{}}
:
ScaleA_Args
{
a_scales
.
data_ptr
<
float
>
(),
{},
{}};
ScaleB_Args
b_args
=
b_scales
.
numel
()
==
1
?
ScaleB_Args
{
nullptr
,
b_scales
.
item
<
float
>
(),
{}}
:
ScaleB_Args
{
b_scales
.
data_ptr
<
float
>
(),
{},
{}};
args
.
epilogue
.
thread
=
{
a_args
,
{
b_args
}};
// Launch the CUTLASS GEMM kernel.
using
GemmOp
=
cutlass
::
gemm
::
device
::
GemmUniversalAdapter
<
GemmKernel
>
;
GemmOp
gemm_op
;
CUTLASS_CHECK
(
gemm_op
.
can_implement
(
args
));
size_t
workspace_size
=
gemm_op
.
get_workspace_size
(
args
);
TORCH_CHECK
(
workspace_size
==
0
);
cutlass
::
Status
status
=
gemm_op
.
run
(
args
);
CUTLASS_CHECK
(
status
);
}
}
// namespace
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
if
(
a
.
dtype
()
==
torch
::
kInt8
)
{
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelTmaWarpSpecializedPingpong
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecialized
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
int8_t
,
cutlass
::
bfloat16_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
int8_t
,
cutlass
::
half_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
else
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kFloat8_e4m3fn
);
using
TileShape
=
Shape
<
_128
,
_128
,
_128
>
;
using
ClusterShape
=
Shape
<
_1
,
_2
,
_1
>
;
using
KernelSchedule
=
typename
cutlass
::
gemm
::
KernelCpAsyncWarpSpecializedCooperative
;
using
EpilogueSchedule
=
typename
cutlass
::
epilogue
::
TmaWarpSpecializedCooperative
;
if
(
out
.
dtype
()
==
torch
::
kBFloat16
)
{
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
cutlass
::
float_e4m3_t
,
cutlass
::
bfloat16_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
TORCH_CHECK
(
out
.
dtype
()
==
torch
::
kFloat16
);
return
cutlass_scaled_mm_dq_dispatcher
<
cutlass_3x_gemm
<
cutlass
::
float_e4m3_t
,
cutlass
::
half_t
,
TileShape
,
ClusterShape
,
KernelSchedule
,
EpilogueSchedule
>>
(
out
,
a
,
b
,
a_scales
,
b_scales
);
}
}
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
0 → 100644
View file @
2060e936
#include <c10/cuda/CUDAGuard.h>
#include <cuda_runtime.h>
#include <torch/extension.h>
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
0
);
cudaDeviceGetAttribute
(
&
minor_capability
,
cudaDevAttrComputeCapabilityMinor
,
0
);
int32_t
version_num
=
major_capability
*
10
+
minor_capability
;
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
if
(
version_num
>=
90
)
{
// Hopper
cutlass_scaled_mm_dq_sm90
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
if
(
version_num
==
89
)
{
// Ada Lovelace
cutlass_scaled_mm_dq_sm89
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
if
(
version_num
>=
80
)
{
// Ampere
cutlass_scaled_mm_dq_sm80
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
else
{
// Turing
TORCH_CHECK
(
version_num
>=
75
);
cutlass_scaled_mm_dq_sm75
(
c
,
a
,
b
,
a_scales
,
b_scales
);
}
}
tests/kernels/test_cutlass.py
0 → 100644
View file @
2060e936
"""Tests for cutlass kernels
Run `pytest tests/kernels/test_cutlass.py`.
"""
from
typing
import
Type
import
pytest
import
torch
from
vllm
import
_custom_ops
as
ops
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
]
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
def
to_fp8
(
tensor
:
torch
.
tensor
):
finfo
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
return
torch
.
round
(
tensor
.
clamp
(
min
=
finfo
.
min
,
max
=
finfo
.
max
)).
to
(
dtype
=
torch
.
float8_e4m3fn
)
def
to_int8
(
tensor
:
torch
.
tensor
):
return
torch
.
round
(
tensor
.
clamp
(
min
=-
128
,
max
=
127
)).
to
(
dtype
=
torch
.
int8
)
def
cutlass_fp8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a
=
to_fp8
(
torch
.
randn
((
m
,
k
),
device
=
device
))
b
=
to_fp8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
())
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
out
=
ops
.
cutlass_scaled_mm_dq
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
out_dtype
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-2
,
atol
=
1e-1
)
def
cutlass_int8_gemm_helper
(
m
:
int
,
n
:
int
,
k
:
int
,
per_token_act_quant
:
bool
,
per_out_channel_weight_quant
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]
=
torch
.
bfloat16
,
device
:
str
=
"cuda"
):
# Test for a cutlass kernel with per-token activation quantization
# and per-output channel weight quantization.
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
device
)
*
5
)
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
device
).
t
()
*
5
)
m_a_scales
=
m
if
per_token_act_quant
else
1
n_b_scales
=
n
if
per_out_channel_weight_quant
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
device
,
dtype
=
torch
.
float32
)
/
10
)
out
=
ops
.
cutlass_scaled_mm_dq
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
out_dtype
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
]):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
torch
.
bfloat16
,
device
)
# For the following two tests:
# N and K correspond to the size of the weight matrix and likely to be multiples
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
capability
<
89
,
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
)
# Test working with a subset of A and B
def
test_cutlass_subset
():
big_m
,
big_n
,
big_k
=
1024
,
1024
,
1024
m
,
n
,
k
=
512
,
512
,
512
whole_a
=
to_int8
(
torch
.
randn
((
big_m
,
big_k
),
device
=
"cuda"
)
*
5
)
whole_b
=
to_int8
(
torch
.
randn
((
big_n
,
big_k
),
device
=
"cuda"
).
t
()
*
5
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
b
=
whole_b
[
0
:
k
,
0
:
n
]
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
out
=
ops
.
cutlass_scaled_mm_dq
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
=
torch
.
bfloat16
)
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
dtype
=
torch
.
bfloat16
)
assert
torch
.
allclose
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
vllm/_custom_ops.py
View file @
2060e936
from
typing
import
Optional
,
Tuple
from
typing
import
Optional
,
Tuple
,
Type
import
torch
import
torch
...
@@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -163,6 +163,22 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
size_k
)
size_k
)
# cutlass
def
cutlass_scaled_mm_dq
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
assert
(
b
.
shape
[
0
]
%
16
==
0
and
b
.
shape
[
1
]
%
16
==
0
)
assert
(
out_dtype
is
torch
.
bfloat16
or
out_dtype
is
torch
.
float16
)
m
=
a
.
shape
[
0
]
n
=
b
.
shape
[
1
]
out
=
torch
.
empty
((
m
,
n
),
dtype
=
out_dtype
,
device
=
a
.
device
)
vllm_ops
.
cutlass_scaled_mm_dq
(
out
,
a
,
b
,
a_scales
,
b_scales
)
return
out
# aqlm
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment