Commit e7c1b7f3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.4-dtk24.04.1'

parents 7462218e 04c62b93
......@@ -273,8 +273,6 @@ __global__ void Code2x8Dequant(
}
__syncthreads();
float res = 0;
int iters = (prob_k / 8 - 1) / (8 * 32) + 1;
while (iters--) {
if (pred && a_gl_rd < a_gl_end) {
......
......@@ -95,6 +95,7 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source) {
return result;
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
} // namespace awq
......
......@@ -17,14 +17,6 @@ Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
namespace vllm {
namespace awq {
// Pack two half values.
static inline __device__ __host__ unsigned __pack_half2(const half x,
const half y) {
unsigned v0 = *((unsigned short*)&x);
unsigned v1 = *((unsigned short*)&y);
return (v1 << 16) | v0;
}
template <int N>
__global__ void __launch_bounds__(64)
gemm_forward_4bit_cuda_m16nXk32(int G, int split_k_iters,
......@@ -42,11 +34,7 @@ __global__ void __launch_bounds__(64)
__shared__ half A_shared[16 * (32 + 8)];
__shared__ half B_shared[32 * (N + 8)];
__shared__ half scaling_factors_shared[N];
__shared__ half zeros_shared[N];
int j_factors1 = ((OC + N - 1) / N);
int blockIdx_x = 0;
int blockIdx_y = blockIdx.x % ((M + 16 - 1) / 16 * j_factors1);
int blockIdx_z = blockIdx.x / ((M + 16 - 1) / 16 * j_factors1);
......@@ -60,7 +48,6 @@ __global__ void __launch_bounds__(64)
static constexpr int row_stride_warp = 32 * 8 / 32;
static constexpr int row_stride = 2 * 32 * 8 / N;
bool ld_zero_flag = (threadIdx.y * 32 + threadIdx.x) * 8 < N;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool ld_A_flag =
(blockIdx_y / j_factors1 * 16 + threadIdx.y * row_stride_warp +
......@@ -145,11 +132,7 @@ __global__ void __launch_bounds__(64)
uint32_t B_loaded =
*(uint32_t*)(B_ptr_local + ax0_ax1_fused_0 * row_stride * (OC / 8));
uint4 B_loaded_fp16 = dequantize_s4_to_fp16x2(B_loaded);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
......@@ -367,17 +350,11 @@ __global__ void __launch_bounds__(64)
__global__ void __launch_bounds__(64)
dequantize_weights(int* __restrict__ B, half* __restrict__ scaling_factors,
int* __restrict__ zeros, half* __restrict__ C, int G) {
int j_factors1 = 4;
int row_stride2 = 4;
int split_k_iters = 1;
static constexpr uint32_t ZERO = 0x0;
half B_shared[32 * (128 + 8)];
half* B_shared_ptr2 = B_shared;
half B_shared_warp[32];
int OC = 512;
int N = blockDim.x * gridDim.x; // 2
int col = (blockIdx.x * blockDim.x + threadIdx.x);
int row = blockIdx.y * blockDim.y + threadIdx.y;
......
......@@ -64,8 +64,6 @@ using namespace detail;
// Row vector broadcast
template<
// Row bcast reuses the mbarriers from the epilogue subtile load pipeline, so this must be at least
// ceil_div(StagesC, epi tiles per CTA tile) + 1 to ensure no data races
int Stages,
class CtaTileShapeMNK,
class Element,
......@@ -73,14 +71,12 @@ template<
int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
static_assert(Alignment * sizeof_bits_v<Element> % 128 == 0, "sub-16B alignment not supported yet");
static_assert(
(cute::is_same_v<StrideMNL, Stride<_0,_1, _0>>) || // row vector broadcast, e.g. per-col alpha/bias
(cute::is_same_v<StrideMNL, Stride<_0,_1,int>>)); // batched row vector broadcast
static_assert(Stages == 0, "Row broadcast doesn't support smem usage");
static_assert(is_static_v<decltype(take<0,2>(StrideMNL{}))>); // batch stride can be dynamic or static
static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{});
// Accumulator doesn't distribute row elements evenly amongst threads so we must buffer in smem
struct SharedStorage {
alignas(16) array_aligned<Element, size<1>(CtaTileShapeMNK{}) * Stages> smem_row;
struct SharedStorage {
array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
};
// This struct has been modified to have a bool indicating that ptr_row is a
......@@ -100,6 +96,12 @@ struct Sm90RowOrScalarBroadcast {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
......@@ -118,15 +120,15 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_HOST_DEVICE
Sm90RowOrScalarBroadcast(Params const& params, SharedStorage const& shared_storage)
: params(params),
smem_row(const_cast<Element*>(shared_storage.smem_row.data())) { }
: params(params)
, smem(const_cast<Element*>(shared_storage.smem.data())) { }
Params params;
Element* smem_row;
Element *smem = nullptr;
CUTLASS_DEVICE bool
is_producer_load_needed() const {
return true;
return false;
}
CUTLASS_DEVICE bool
......@@ -139,78 +141,76 @@ struct Sm90RowOrScalarBroadcast {
return (!params.row_broadcast && *(params.ptr_row) == Element(0));
}
template <int EpiTiles, class GTensor, class STensor>
struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks {
CUTLASS_DEVICE
ProducerLoadCallbacks(GTensor&& gRow, STensor&& sRow, Params const& params)
: gRow(cute::forward<GTensor>(gRow)),
sRow(cute::forward<STensor>(sRow)),
params(params) {}
GTensor gRow; // (CTA_M,CTA_N)
STensor sRow; // (CTA_M,CTA_N,PIPE)
Params const& params;
CUTLASS_DEVICE void
begin(uint64_t* full_mbarrier_ptr, int load_iteration, bool issue_tma_load) {
if (params.ptr_row == nullptr) {
return;
}
if (issue_tma_load) {
// Increment the expect-tx count of the first subtile's mbarrier by the row vector's byte-size
constexpr uint32_t copy_bytes = size<1>(CtaTileShapeMNK{}) * sizeof_bits_v<Element> / 8;
cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes);
// Issue the TMA bulk copy
auto bulk_copy = Copy_Atom<SM90_BULK_COPY_AUTO, Element>{}.with(*full_mbarrier_ptr);
// Filter so we don't issue redundant copies over stride-0 modes
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy(bulk_copy, filter(gRow), filter(sRow(_,_,bcast_pipe_index)));
}
}
};
template <class... Args>
CUTLASS_DEVICE auto
get_producer_load_callbacks(ProducerLoadArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ProducerLoadCallbacks<EpiTiles, decltype(gRow), decltype(sRow)>(
cute::move(gRow), cute::move(sRow), params);
return EmptyProducerLoadCallbacks{};
}
template <int EpiTiles, class RTensor, class STensor>
template <class GS_GTensor, class GS_STensor, class GS_CTensor, class Tiled_G2S, class SR_STensor, class SR_RTensor, class CTensor, class ThrResidue, class ThrNum>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(RTensor&& tCrRow, STensor&& tCsRow, Params const& params)
: tCrRow(cute::forward<RTensor>(tCrRow)),
tCsRow(cute::forward<STensor>(tCsRow)),
params(params) {}
RTensor tCrRow; // (CPY,CPY_M,CPY_N)
STensor tCsRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
ConsumerStoreCallbacks(
GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_,
GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_,
SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_,
CTensor tCcRow_, ThrResidue residue_tCcRow_, ThrNum thr_num_, Params const& params_)
: tGS_gRow(tGS_gRow_)
, tGS_sRow(tGS_sRow_)
, tGS_cRow(tGS_cRow_)
, tiled_G2S(tiled_g2s_)
, tSR_sRow(tSR_sRow_)
, tSR_rRow(tSR_rRow_)
, tCcRow(tCcRow_)
, residue_tCcRow(residue_tCcRow_)
, params(params_) {}
GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N)
GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N)
GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N)
Tiled_G2S tiled_G2S;
SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
CTensor tCcRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
ThrResidue residue_tCcRow; // (m, n)
ThrNum thr_num;
Params const& params;
CUTLASS_DEVICE void
previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) {
begin() {
if (!params.row_broadcast) {
fill(tCrRow, *(params.ptr_row));
fill(tSR_rRow, *(params.ptr_row));
return;
}
auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };
Tensor tGS_gRow_flt = filter_zeros(tGS_gRow);
Tensor tGS_sRow_flt = filter_zeros(tGS_sRow);
Tensor tGS_cRow_flt = make_tensor(tGS_cRow.data(), make_layout(tGS_gRow_flt.shape(), tGS_cRow.stride()));
for (int i = 0; i < size(tGS_gRow_flt); ++i) {
if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) {
continue; // OOB of SMEM,
}
if (elem_less(tGS_cRow_flt(i), make_coord(get<0>(residue_tCcRow), get<1>(residue_tCcRow)))) {
tGS_sRow_flt(i) = tGS_gRow_flt(i);
}
else {
tGS_sRow_flt(i) = Element(0); // Set to Zero when OOB so LDS could be issue without any preds.
}
}
synchronize();
}
CUTLASS_DEVICE void
begin_loop(int epi_m, int epi_n) {
if (epi_m == 0) { // Assumes M-major subtile loop
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
int bcast_pipe_index = (load_iteration / EpiTiles) % Stages;
copy_aligned(filter(tCsRow(_,_,_,epi_m,epi_n,bcast_pipe_index)), filter(tCrRow));
if (!params.row_broadcast) return; // Do not issue LDS when row is scalar
Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n));
Tensor tSR_rRow_flt = filter_zeros(tSR_rRow);
copy(tSR_sRow_flt, tSR_rRow_flt);
}
}
......@@ -221,7 +221,7 @@ struct Sm90RowOrScalarBroadcast {
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < FragmentSize; ++i) {
frg_row[i] = tCrRow(epi_v * FragmentSize + i);
frg_row[i] = tSR_rRow(epi_v * FragmentSize + i);
}
return frg_row;
......@@ -234,17 +234,41 @@ struct Sm90RowOrScalarBroadcast {
>
CUTLASS_DEVICE auto
get_consumer_store_callbacks(ConsumerStoreArgs<Args...> const& args) {
auto [M, N, K, L] = args.problem_shape_mnkl;
auto [m, n, k, l] = args.tile_coord_mnkl;
using ThreadCount = decltype(size(args.tiled_copy));
Tensor sRow = make_tensor(make_smem_ptr(smem_row), // (CTA_M,CTA_N,PIPE)
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{}), Stages),
make_stride(_0{},_1{},size<1>(CtaTileShapeMNK{})));
Tensor tCsRow = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,PIPE)
sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrRow = make_tensor_like(take<0,3>(tCsRow)); // (CPY,CPY_M,CPY_N)
constexpr int EpiTiles = decltype(size<1>(zipped_divide(make_layout(take<0,2>(args.tile_shape_mnk)), args.epi_tile)))::value;
return ConsumerStoreCallbacks<EpiTiles, decltype(tCrRow), decltype(tCsRow)>(
cute::move(tCrRow), cute::move(tCsRow), params);
Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow);
Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N)
Tensor sRow = make_tensor(make_smem_ptr(smem),
make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N)
//// G2S: Gmem to Smem
auto tiled_g2s = make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
Layout< Shape<_1, ThreadCount>,
Stride<_0, _1>>{},
Layout<_1>{});
auto thr_g2s = tiled_g2s.get_slice(args.thread_idx);
Tensor tGS_gRow = thr_g2s.partition_S(gRow);
Tensor tGS_sRow = thr_g2s.partition_D(sRow);
//// G2S: Coord
auto cRow = make_identity_tensor(make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})));
Tensor tGS_cRow = thr_g2s.partition_S(cRow);
//// S2R: Smem to Reg
Tensor tSR_sRow = sm90_partition_for_epilogue<ReferenceSrc>(sRow, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N)
return ConsumerStoreCallbacks<decltype(tGS_gRow), decltype(tGS_sRow), decltype(tGS_cRow), decltype(tiled_g2s), decltype(tSR_sRow), decltype(tSR_rRow), decltype(args.tCcD), decltype(args.residue_cD), ThreadCount>(
tGS_gRow,
tGS_sRow,
tGS_cRow, tiled_g2s,
tSR_sRow,
tSR_rRow,
args.tCcD,
args.residue_cD,
ThreadCount{},
params);
}
};
......@@ -285,6 +309,12 @@ struct Sm90ColOrScalarBroadcast {
return args;
}
template <class ProblemShape>
static bool
can_implement(ProblemShape const& problem_shape, Arguments const& args) {
return true;
}
template <class ProblemShape>
static size_t
get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) {
......@@ -328,20 +358,36 @@ struct Sm90ColOrScalarBroadcast {
return EmptyProducerLoadCallbacks{};
}
template<class GTensor, class RTensor>
template<class GTensor, class RTensor, class CTensor, class ProblemShape>
struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
CUTLASS_DEVICE
ConsumerStoreCallbacks(GTensor&& tCgCol, RTensor&& tCrCol, Params const& params)
: tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
params(params) {}
ConsumerStoreCallbacks(
GTensor&& tCgCol,
RTensor&& tCrCol,
CTensor&& tCcCol,
ProblemShape problem_shape,
Params const& params
):
tCgCol(cute::forward<GTensor>(tCgCol)),
tCrCol(cute::forward<RTensor>(tCrCol)),
tCcCol(cute::forward<CTensor>(tCcCol)),
m(get<0>(problem_shape)),
params(params) {}
GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
RTensor tCrCol;
CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
Params const& params;
int m;
CUTLASS_DEVICE void
begin() {
Tensor pred = make_tensor<bool>(shape(tCgCol));
CUTLASS_PRAGMA_UNROLL
for (int i = 0; i < size(pred); ++i) {
pred(i) = get<0>(tCcCol(i)) < m;
}
if (!params.col_broadcast) {
fill(tCrCol, *(params.ptr_col));
return;
......@@ -349,7 +395,7 @@ struct Sm90ColOrScalarBroadcast {
// Filter so we don't issue redundant copies over stride-0 modes
// (only works if 0-strides are in same location, which is by construction)
copy_aligned(filter(tCgCol), filter(tCrCol));
copy_if(pred, filter(tCgCol), filter(tCrCol));
}
template <typename ElementAccumulator, int FragmentSize>
......@@ -381,8 +427,20 @@ struct Sm90ColOrScalarBroadcast {
mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
return ConsumerStoreCallbacks<decltype(tCgCol), decltype(tCrCol)>(
cute::move(tCgCol), cute::move(tCrCol), params);
// Generate an identity tensor matching the shape of the global tensor and
// partition the same way, this will be used to generate the predicate
// tensor for loading
Tensor cCol = make_identity_tensor(mCol.shape());
Tensor tCcCol = sm90_partition_for_epilogue<ReferenceSrc>( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N)
cCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx);
return ConsumerStoreCallbacks(
cute::move(tCgCol),
cute::move(tCrCol),
cute::move(tCcCol),
args.problem_shape_mnkl,
params
);
}
};
......
#pragma once
#include "cutlass/cutlass.h"
#include <climits>
/**
* Helper function for checking CUTLASS errors
......@@ -10,3 +11,17 @@
TORCH_CHECK(status == cutlass::Status::kSuccess, \
cutlassGetStatusString(status)) \
}
inline uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) {
int max_shared_mem_per_block_opt_in = 0;
cudaDeviceGetAttribute(&max_shared_mem_per_block_opt_in,
cudaDevAttrMaxSharedMemoryPerBlockOptin,
device);
return max_shared_mem_per_block_opt_in;
}
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
#include "scaled_mm_c2x.cuh"
#include "scaled_mm_c2x_sm75_dispatch.cuh"
#include "scaled_mm_c2x_sm80_dispatch.cuh"
#include "scaled_mm_c2x_sm89_fp8_dispatch.cuh"
#include "scaled_mm_c2x_sm89_int8_dispatch.cuh"
/*
This file defines quantized GEMM operations using the CUTLASS 2.x API, for
NVIDIA GPUs with SM versions prior to sm90 (Hopper).
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue {
private:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm75_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
if (out.dtype() == torch::kBFloat16) {
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return vllm::cutlass_gemm_sm75_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
cutlass::arch::OpMultiplyAdd>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm, typename... EpilogueArgs>
void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.get(), stream);
CUTLASS_CHECK(status);
}
} // namespace
void cutlass_scaled_mm_sm75(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
out, a, b, a_scales, b_scales);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm75, enable_sm75_to_sm80, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 2>>(
return cutlass_scaled_mm_sm75_epilogue<vllm::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm80_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm80, enable_sm80_to_sm89, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm80_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
void cutlass_scaled_mm_sm80(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm80_epilogue<vllm::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
template <template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm89_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::bfloat16_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
assert(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, int8_t, cutlass::half_t,
ScaledEpilogue, TileShape, WarpShape, InstructionShape, 5>>(
out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm89_int8_dispatch<int8_t, cutlass::half_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::bfloat16_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm89_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<cutlass_2x_gemm<
cutlass::arch::Sm89, enable_sm89_to_sm90, cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue, TileShape, WarpShape,
InstructionShape, 5>>(out, a, b, a_scales, b_scales);
return vllm::cutlass_gemm_sm89_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm89(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == out.dtype(),
"currently bias dtype must match output dtype ", out.dtype());
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogueBias>(
out, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm89_epilogue<vllm::ScaledEpilogue>(
out, a, b, a_scales, b_scales);
}
}
#pragma once
#include <stddef.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
// clang-format will break include orders
// clang-format off
#include "cute/tensor.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm_coord.h"
#include "cutlass/arch/mma_sm75.h"
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/epilogue/threadblock/fusion/visitors.hpp"
#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h"
#include "broadcast_load_epilogue_c2x.hpp"
#include "common.hpp"
// clang-format on
using namespace cute;
/*
Epilogue functions can be defined to post-process the output before it is
written to GPU memory.
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
namespace vllm {
// Wrappers for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
// __CUDA_ARCH__ is not defined in host code, so this lets us smuggle the ifdef
// into code that will be executed on the device where it is defined.
template <typename Kernel>
struct enable_sm75_to_sm80 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 750 && __CUDA_ARCH__ < 800
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm80_to_sm89 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 890
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
template <typename Kernel>
struct enable_sm89_to_sm90 : Kernel {
template <typename... Args>
CUTLASS_DEVICE static void invoke(Args&&... args) {
#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 890 && __CUDA_ARCH__ < 900
Kernel::invoke(std::forward<Args>(args)...);
#endif
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::threadblock::VisitorAccFetch;
using ScaleA = cutlass::epilogue::threadblock::VisitorColOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch._scaled_mm.
A and B may be both either int8 or fp8_e4m3. A can be quantized per-tensor or
per-row. B can be quantized per-tensor or per-column.
Any combination of per-tensor and per-row or column is supported.
A and B must have symmetric quantization (zero point == 0).
So the GEMM operation is D = (a_scales * A) (b_scales * B), where the
scales are applied elementwise with numpy-style broadcasting.
ScaleA and ScaleB define the epilogue functions that apply the scales for
the A and B operands respectively. These scales may be either per-tensor or
per row or column.
*/
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
public:
using EVTCompute =
cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA, EVTCompute0>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args};
return evt_compute_args;
}
};
template <typename ElementD, typename OutputTileThreadMap>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementD, OutputTileThreadMap> {
private:
using SUPER = ScaledEpilogueBase<ElementD, OutputTileThreadMap>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::threadblock::Sm80EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::threadblock::VisitorCompute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast<
OutputTileThreadMap, ElementD, Stride<Int<0>, Int<1>, Int<0>>>;
public:
using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT<Compute1, ScaleA,
EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleAArgs = typename ScaleA::Arguments;
using ScaleBArgs = typename ScaleB::Arguments;
using BiasArgs = typename Bias::Arguments;
ScaleBArgs b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
ScaleAArgs a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
BiasArgs bias_args{static_cast<ElementD*>(bias.data_ptr()), {}};
typename EVTCompute0::Arguments evt0_compute_args{b_args};
typename EVTCompute::Arguments evt_compute_args{a_args, evt0_compute_args,
bias_args};
return evt_compute_args;
}
};
template <typename Arch, template <typename> typename ArchGuard,
typename ElementAB_, typename ElementD_,
template <typename, typename> typename Epilogue_, typename TileShape,
typename WarpShape, typename InstructionShape, int32_t MainLoopStages,
typename FP8MathOperator = cutlass::arch::OpMultiplyAdd>
struct cutlass_2x_gemm {
using ElementAB = ElementAB_;
using ElementD = ElementD_;
using ElementAcc =
typename std::conditional<std::is_same_v<ElementAB, int8_t>, int32_t,
float>::type;
using Operator =
typename std::conditional<std::is_same_v<ElementAB, int8_t>,
cutlass::arch::OpMultiplyAddSaturate,
FP8MathOperator>::type;
using OutputTileThreadMap =
cutlass::epilogue::threadblock::OutputTileThreadLayout<
TileShape, WarpShape, float, 4, 1 /* epilogue stages */
>;
using Epilogue = Epilogue_<ElementD, OutputTileThreadMap>;
using EVTCompute = typename Epilogue::EVTCompute;
using D = cutlass::epilogue::threadblock::VisitorAuxStore<
OutputTileThreadMap, ElementD, cutlass::FloatRoundStyle::round_to_nearest,
Stride<int64_t, Int<1>, Int<0>>>;
using EVTD = cutlass::epilogue::threadblock::Sm80EVT<D, EVTCompute>;
// clang-format off
using RowMajor = typename cutlass::layout::RowMajor;
using ColumnMajor = typename cutlass::layout::ColumnMajor;
using KernelType =
ArchGuard<typename cutlass::gemm::kernel::DefaultGemmWithVisitor<
ElementAB, RowMajor, cutlass::ComplexTransform::kNone, 16,
ElementAB, ColumnMajor, cutlass::ComplexTransform::kNone, 16,
float, cutlass::layout::RowMajor, 4,
ElementAcc, float, cutlass::arch::OpClassTensorOp,
Arch,
TileShape, WarpShape, InstructionShape,
EVTD,
cutlass::gemm::threadblock::ThreadblockSwizzleStreamK,
MainLoopStages, Operator,
1 /* epilogue stages */
>::GemmKernel>;
// clang-format on
using Op = cutlass::gemm::device::GemmUniversalAdapter<KernelType>;
};
template <typename Gemm, typename... EpilogueArgs>
inline void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_params) {
using ElementAB = typename Gemm::ElementAB;
using ElementD = typename Gemm::ElementD;
int32_t m = a.size(0);
int32_t n = b.size(1);
int32_t k = a.size(1);
cutlass::gemm::GemmCoord problem_size{m, n, k};
int64_t lda = a.stride(0);
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideC = Stride<int64_t, Int<1>, Int<0>>;
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
auto a_ptr = static_cast<ElementAB const*>(a.data_ptr());
auto b_ptr = static_cast<ElementAB const*>(b.data_ptr());
auto c_ptr = static_cast<ElementD*>(out.data_ptr());
typename Gemm::D::Arguments d_args{c_ptr, c_stride};
using Epilogue = typename Gemm::Epilogue;
auto evt_args =
Epilogue::prepare_args(std::forward<EpilogueArgs>(epilogue_params)...);
typename Gemm::EVTD::Arguments epilogue_args{
evt_args,
d_args,
};
typename Gemm::Op::Arguments args{
cutlass::gemm::GemmUniversalMode::kGemmSplitKParallel, // universal mode
problem_size, // problem size
1, // batch count
epilogue_args,
a_ptr,
b_ptr,
nullptr,
nullptr,
0,
0,
0,
0,
lda,
ldb,
ldc,
ldc};
// Launch the CUTLASS GEMM kernel.
typename Gemm::Op gemm_op;
size_t workspace_size = gemm_op.get_workspace_size(args);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
CUTLASS_CHECK(gemm_op.can_implement(args));
cutlass::Status status = gemm_op(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename Gemm, typename FallbackGemm, typename... EpilogueArgs>
inline void fallback_cutlass_gemm_caller(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
// In some cases, the GPU isn't able to accommodate the
// shared memory requirements of the Gemm. In such cases, use
// the FallbackGemm instead.
static const int max_shared_mem_per_block_opt_in =
get_cuda_max_shared_memory_per_block_opt_in(0);
size_t const gemm_shared_mem_size =
sizeof(typename Gemm::KernelType::SharedStorage);
size_t const fallback_gemm_shared_mem_size =
sizeof(typename FallbackGemm::KernelType::SharedStorage);
if (gemm_shared_mem_size <= max_shared_mem_per_block_opt_in) {
return cutlass_gemm_caller<Gemm>(out, a, b,
std::forward<EpilogueArgs>(args)...);
} else {
TORCH_CHECK(fallback_gemm_shared_mem_size <=
max_shared_mem_per_block_opt_in);
return cutlass_gemm_caller<FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM75 based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_default {
// This config is used in 2 cases,
// - M in (256, inf]
// - M in (64, 128]
// Shared memory required by this Gemm 32768
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M256 {
// M in (128, 256]
// Shared memory required by this Gemm 65536
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M64 {
// M in (32, 64]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm75_config_M32 {
// M in [1, 32]
// Shared memory required by this Gemm 49152
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<8, 8, 16>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm75, enable_sm75_to_sm80, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 2>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm75_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass2xGemmDefault =
typename sm75_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM256 =
typename sm75_config_M256<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128 = Cutlass2xGemmDefault;
using Cutlass2xGemmM64 =
typename sm75_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm75_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm75_config_default has the least shared-memory requirements.
using FallbackGemm = Cutlass2xGemmDefault;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// M in [1, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM256, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM80 based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_default {
// This config is used in 2 cases,
// - M in (128, inf)
// - M in (64, 128] and N >= 8192
// Shared Memory required by this Gemm - 81920 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M64 {
// This config is used in 2 cases,
// - M in (32, 64]
// - M in (64, 128] and N < 8192
// Shared Memory required by this Gemm - 122880 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M32 {
// M in (16, 32]
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm80_config_M16 {
// M in [1, 16]
// Shared Memory required by this Gemm - 51200 bytes
static_assert(std::is_same<InType, int8_t>());
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm80, enable_sm80_to_sm89, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm80_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass2xGemmDefault =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128BigN =
typename sm80_config_default<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM128SmallN =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM64 =
typename sm80_config_M64<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM32 =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
using Cutlass2xGemmM16 =
typename sm80_config_M16<InType, OutType, Epilogue>::Cutlass2xGemm;
// Due to shared memory requirements, some Gemms may fail to run on some
// GPUs. As the name indicates, the Fallback Gemm is used as an alternative
// in such cases.
// sm80_config_M16 has the least shared-memory requirement. However,
// based on some profiling, we select sm80_config_M32 as a better alternative
// performance wise.
using FallbackGemm =
typename sm80_config_M32<InType, OutType, Epilogue>::Cutlass2xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(16), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM16, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM32, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return fallback_cutlass_gemm_caller<Cutlass2xGemmM64, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
uint32_t const n = out.size(1);
bool const small_n = n < 8192;
if (small_n) {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128SmallN,
FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return fallback_cutlass_gemm_caller<Cutlass2xGemmM128BigN, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else {
// M in (128, inf)
return fallback_cutlass_gemm_caller<Cutlass2xGemmDefault, FallbackGemm>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
#pragma once
#include "scaled_mm_c2x.cuh"
#include "cutlass/float8.h"
/**
* This file defines Gemm kernel configurations for SM89 (FP8) based on the Gemm
* shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_fp8_fallback_gemm {
// Shared Memory required by this Gemm - 61440 bytes
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 64>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5,
FP8MathOperator>;
};
struct sm89_fp8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M128 {
// M in (64, 128]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<128, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8196) {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAdd;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = typename cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5, FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_fp8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
using FP8MathOperator = typename cutlass::arch::OpMultiplyAddFastAccum;
static const int32_t MainLoopStages = 5;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
using FallbackGemm =
typename sm89_fp8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = typename cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 24576) {
using TileShape = typename cutlass::gemm::GemmShape<16, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = typename cutlass::gemm::GemmShape<32, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, MainLoopStages,
FP8MathOperator>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_fp8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_fp8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_fp8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_fp8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_fp8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_fp8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_fp8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
#pragma once
#include "scaled_mm_c2x.cuh"
/**
* This file defines Gemm kernel configurations for SM89 (int8) based on the
* Gemm shape.
*/
namespace vllm {
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue>
struct sm89_int8_fallback_gemm {
// Shared mem requirement : 61440
static_assert(std::is_same<InType, int8_t>());
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
static int32_t const MainLoopStages = 5;
using Cutlass2xGemm =
cutlass_2x_gemm<cutlass::arch::Sm89, enable_sm89_to_sm90, InType, OutType,
Epilogue, TileShape, WarpShape, InstructionShape, 5>;
};
struct sm89_int8_config_default {
// M in (256, inf)
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M256 {
// M in (128, 256]
using WarpShape = typename cutlass::gemm::GemmShape<64, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 4096) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<256, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M128 {
// M in (64, 128]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (np2 <= 16384) {
using TileShape = cutlass::gemm::GemmShape<128, 128, 64>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M64 {
// M in (32, 64]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<64, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<64, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 3>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M32 {
// M in (16, 32]
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<32, 64, 128>;
using WarpShape = cutlass::gemm::GemmShape<16, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<32, 128, 128>;
using WarpShape = cutlass::gemm::GemmShape<32, 64, 64>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
struct sm89_int8_config_M16 {
// M in [1, 16]
using WarpShape = typename cutlass::gemm::GemmShape<16, 64, 64>;
using InstructionShape = typename cutlass::gemm::GemmShape<16, 8, 32>;
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
static void dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b, EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
using FallbackGemm =
typename sm89_int8_fallback_gemm<InType, OutType,
Epilogue>::Cutlass2xGemm;
uint32_t const n = out.size(1);
uint32_t const np2 = next_pow_2(n);
if (np2 <= 8192) {
using TileShape = cutlass::gemm::GemmShape<16, 64, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 5>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
using TileShape = cutlass::gemm::GemmShape<16, 128, 128>;
return vllm::fallback_cutlass_gemm_caller<
vllm::cutlass_2x_gemm<cutlass::arch::Sm89, vllm::enable_sm89_to_sm90,
InType, OutType, Epilogue, TileShape, WarpShape,
InstructionShape, 4>,
FallbackGemm>(out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
};
template <typename InType, typename OutType,
template <typename, typename> typename Epilogue,
typename... EpilogueArgs>
inline void cutlass_gemm_sm89_int8_dispatch(torch::Tensor& out,
torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 16) {
// M in [1, 16]
return sm89_int8_config_M16::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 32) {
// M in (16, 32]
return sm89_int8_config_M32::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 64) {
// M in (32, 64]
return sm89_int8_config_M64::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// M in (64, 128]
return sm89_int8_config_M128::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 256) {
// M in (128, 256]
return sm89_int8_config_M256::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// M in (256, inf)
return sm89_int8_config_default::dispatch<InType, OutType, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
} // namespace vllm
......@@ -18,8 +18,6 @@
#include "cute/atom/mma_atom.hpp"
#include "cutlass/numeric_types.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
......@@ -44,11 +42,6 @@ using namespace cute;
namespace {
uint32_t next_pow_2(uint32_t const num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
// A wrapper for the GEMM kernel that is used to guard against compilation on
// architectures that will never use the kernel. The purpose of this is to
// reduce the size of the compiled binary.
......@@ -64,6 +57,24 @@ struct enable_sm90_or_later : Kernel {
}
};
/*
* This class provides the common ScaleA and ScaleB descriptors for the
* ScaledEpilogue and ScaledEpilogueBias classes.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBase {
protected:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<0>, Int<1>, Int<0>>>;
};
/*
This epilogue function defines a quantized GEMM operation similar to
torch.scaled_mm_.
......@@ -81,21 +92,13 @@ struct enable_sm90_or_later : Kernel {
per row or column.
*/
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogue {
struct ScaledEpilogue
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using Accum = cutlass::epilogue::fusion::Sm90AccFetch;
using ScaleA = cutlass::epilogue::fusion::Sm90ColOrScalarBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, float,
Stride<Int<1>, Int<0>, Int<0>>>;
using ScaleBDescriptor =
cutlass::epilogue::collective::detail::RowBroadcastDescriptor<
EpilogueDescriptor, float>;
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
ScaleBDescriptor::Stages, typename EpilogueDescriptor::TileShape,
typename ScaleBDescriptor::Element, Stride<Int<0>, Int<1>, Int<0>>>;
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
......@@ -125,6 +128,50 @@ struct ScaledEpilogue {
}
};
template <typename ElementAcc, typename ElementD, typename EpilogueDescriptor>
struct ScaledEpilogueBias
: private ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor> {
private:
using SUPER = ScaledEpilogueBase<ElementAcc, ElementD, EpilogueDescriptor>;
using Accum = typename SUPER::Accum;
using ScaleA = typename SUPER::ScaleA;
using ScaleB = typename SUPER::ScaleB;
using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiplies, float, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using EVTCompute0 =
cutlass::epilogue::fusion::Sm90EVT<Compute0, ScaleB, Accum>;
using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
cutlass::multiply_add, ElementD, float,
cutlass::FloatRoundStyle::round_to_nearest>;
using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
0 /*Stages*/, typename EpilogueDescriptor::TileShape, ElementD,
Stride<Int<0>, Int<1>, Int<0>>, 128 / sizeof_bits_v<ElementD>, false>;
public:
using EVTCompute =
cutlass::epilogue::fusion::Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
using ArgumentType = typename EVTCompute::Arguments;
static ArgumentType prepare_args(torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
torch::Tensor const& bias) {
using ScaleA_Args = typename ScaleA::Arguments;
using ScaleB_Args = typename ScaleB::Arguments;
using Bias_Args = typename Bias::Arguments;
ScaleA_Args a_args{a_scales.data_ptr<float>(), a_scales.numel() != 1, {}};
ScaleB_Args b_args{b_scales.data_ptr<float>(), b_scales.numel() != 1, {}};
Bias_Args bias_args{static_cast<ElementD*>(bias.data_ptr())};
return ArgumentType{a_args, {b_args}, bias_args};
}
};
template <typename ElementAB_, typename ElementD_,
template <typename, typename, typename> typename Epilogue_,
typename TileShape, typename ClusterShape, typename KernelSchedule,
......@@ -194,12 +241,12 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
int64_t ldb = b.stride(1);
int64_t ldc = out.stride(0);
using StrideA = Stride<int64_t, Int<1>, Int<0>>;
using StrideB = Stride<int64_t, Int<1>, Int<0>>;
using StrideA = Stride<int64_t, Int<1>, int64_t>;
using StrideB = Stride<int64_t, Int<1>, int64_t>;
using StrideC = typename Gemm::StrideC;
StrideA a_stride{lda, Int<1>{}, Int<0>{}};
StrideB b_stride{ldb, Int<1>{}, Int<0>{}};
StrideA a_stride{lda, Int<1>{}, 0};
StrideB b_stride{ldb, Int<1>{}, 0};
StrideC c_stride{ldc, Int<1>{}, Int<0>{}};
using GemmKernel = typename Gemm::GemmKernel;
......@@ -225,24 +272,26 @@ void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a,
CUTLASS_CHECK(gemm_op.can_implement(args));
size_t workspace_size = gemm_op.get_workspace_size(args);
cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
auto const workspace_options =
torch::TensorOptions().dtype(torch::kUInt8).device(a.device());
auto workspace = torch::empty(workspace_size, workspace_options);
auto stream = at::cuda::getCurrentCUDAStream(a.get_device());
cutlass::Status status = gemm_op.run(args, workspace.get(), stream);
cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream);
CUTLASS_CHECK(status);
}
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue, int32_t M>
struct sm90_fp8_config {
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config_default {
// M in (128, inf)
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
......@@ -250,14 +299,14 @@ struct sm90_fp8_config {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
struct sm90_fp8_config_M128 {
// M in (64, 128]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
......@@ -265,7 +314,8 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 128> {
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
struct sm90_fp8_config_M64 {
// M in [1, 64]
static_assert(std::is_same<InType, cutlass::float_e4m3_t>());
using KernelSchedule =
cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum;
......@@ -278,6 +328,78 @@ struct sm90_fp8_config<InType, OutType, Epilogue, 64> {
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_default {
// For M > 128 and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M128 {
// For M in (64, 128] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _128>;
using ClusterShape = Shape<_2, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M64 {
// For M in (32, 64] and any N
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _1, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NBig {
// For M in [1, 32] and N >= 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _128, _256>;
using ClusterShape = Shape<_1, _4, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue>
struct sm90_int8_config_M32_NSmall {
// For M in [1, 32] and N < 8192
static_assert(std::is_same<InType, int8_t>());
using KernelSchedule = typename cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
using TileShape = Shape<_64, _64, _256>;
using ClusterShape = Shape<_1, _8, _1>;
using Cutlass3xGemm =
cutlass_3x_gemm<InType, OutType, Epilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>;
};
} // namespace
template <typename InType, typename OutType,
......@@ -291,11 +413,12 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
using Cutlass3xGemmDefault =
typename sm90_fp8_config<InType, OutType, Epilogue, 0>::Cutlass3xGemm;
typename sm90_fp8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_fp8_config<InType, OutType, Epilogue, 64>::Cutlass3xGemm;
typename sm90_fp8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_fp8_config<InType, OutType, Epilogue, 128>::Cutlass3xGemm;
typename sm90_fp8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
uint32_t const m = a.size(0);
uint32_t const mp2 =
......@@ -316,49 +439,111 @@ void cutlass_gemm_sm90_fp8_dispatch(torch::Tensor& out, torch::Tensor const& a,
}
}
void cutlass_scaled_mm_sm90(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
template <typename InType, typename OutType,
template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_gemm_sm90_int8_dispatch(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... args) {
static_assert(std::is_same<InType, int8_t>());
TORCH_CHECK(a.dtype() == torch::kInt8);
TORCH_CHECK(b.dtype() == torch::kInt8);
using Cutlass3xGemmDefault =
typename sm90_int8_config_default<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM128 =
typename sm90_int8_config_M128<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM64 =
typename sm90_int8_config_M64<InType, OutType, Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NBig =
typename sm90_int8_config_M32_NBig<InType, OutType,
Epilogue>::Cutlass3xGemm;
using Cutlass3xGemmM32NSmall =
typename sm90_int8_config_M32_NSmall<InType, OutType,
Epilogue>::Cutlass3xGemm;
uint32_t const n = out.size(1);
bool const is_small_n = n < 8192;
uint32_t const m = a.size(0);
uint32_t const mp2 =
std::max(static_cast<uint32_t>(32), next_pow_2(m)); // next power of 2
if (mp2 <= 32) {
// m in [1, 32]
if (is_small_n) {
return cutlass_gemm_caller<Cutlass3xGemmM32NSmall>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
return cutlass_gemm_caller<Cutlass3xGemmM32NBig>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
} else if (mp2 <= 64) {
// m in (32, 64]
return cutlass_gemm_caller<Cutlass3xGemmM64>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else if (mp2 <= 128) {
// m in (64, 128]
return cutlass_gemm_caller<Cutlass3xGemmM128>(
out, a, b, std::forward<EpilogueArgs>(args)...);
} else {
// m in (128, inf)
return cutlass_gemm_caller<Cutlass3xGemmDefault>(
out, a, b, std::forward<EpilogueArgs>(args)...);
}
}
template <template <typename, typename, typename> typename Epilogue,
typename... EpilogueArgs>
void cutlass_scaled_mm_sm90_epilogue(torch::Tensor& out, torch::Tensor const& a,
torch::Tensor const& b,
EpilogueArgs&&... epilogue_args) {
if (a.dtype() == torch::kInt8) {
TORCH_CHECK(b.dtype() == torch::kInt8);
using TileShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _2, _1>;
using KernelSchedule =
typename cutlass::gemm::KernelTmaWarpSpecializedPingpong;
using EpilogueSchedule = typename cutlass::epilogue::TmaWarpSpecialized;
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_caller<cutlass_3x_gemm<
int8_t, cutlass::bfloat16_t, ScaledEpilogue, TileShape, ClusterShape,
KernelSchedule, EpilogueSchedule>>(out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::bfloat16_t,
Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_caller<
cutlass_3x_gemm<int8_t, cutlass::half_t, ScaledEpilogue, TileShape,
ClusterShape, KernelSchedule, EpilogueSchedule>>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_int8_dispatch<int8_t, cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
} else {
TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn);
TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn);
if (out.dtype() == torch::kBFloat16) {
return cutlass_gemm_sm90_fp8_dispatch<
cutlass::float_e4m3_t, cutlass::bfloat16_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::bfloat16_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
} else {
TORCH_CHECK(out.dtype() == torch::kFloat16);
return cutlass_gemm_sm90_fp8_dispatch<cutlass::float_e4m3_t,
cutlass::half_t, ScaledEpilogue>(
out, a, b, a_scales, b_scales);
cutlass::half_t, Epilogue>(
out, a, b, std::forward<EpilogueArgs>(epilogue_args)...);
}
}
}
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
TORCH_CHECK(a_scales.dtype() == torch::kFloat32);
TORCH_CHECK(b_scales.dtype() == torch::kFloat32);
if (bias) {
TORCH_CHECK(bias->dtype() == c.dtype(),
"currently bias dtype must match output dtype ", c.dtype());
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogueBias>(
c, a, b, a_scales, b_scales, *bias);
} else {
return cutlass_scaled_mm_sm90_epilogue<ScaledEpilogue>(c, a, b, a_scales,
b_scales);
}
}
#endif
......@@ -6,28 +6,49 @@
void cutlass_scaled_mm_sm75(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm80(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b,
torch::Tensor const& a_scales,
torch::Tensor const& b_scales);
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias);
#endif
bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability) {
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if (cuda_device_capability >= 90) {
return CUDA_VERSION >= 12000;
} else if (cuda_device_capability >= 89) {
return CUDA_VERSION >= 12040;
}
#endif
return false;
}
void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
torch::Tensor const& b, torch::Tensor const& a_scales,
torch::Tensor const& b_scales) {
torch::Tensor const& b_scales,
c10::optional<torch::Tensor> const& bias) {
int32_t major_capability;
int32_t minor_capability;
cudaDeviceGetAttribute(&major_capability, cudaDevAttrComputeCapabilityMajor,
......@@ -50,6 +71,11 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
b.stride(1) % 16 == 0); // 16 Byte Alignment
TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous());
if (bias) {
TORCH_CHECK(bias->numel() == b.size(1) && bias->is_contiguous() &&
bias->dim() == 1);
}
at::cuda::OptionalCUDAGuard const device_guard(device_of(a));
if (version_num >= 90) {
......@@ -57,19 +83,19 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a,
// Guard against compilation issues for sm90 kernels
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias);
#else
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
#endif
} else if (version_num == 89) {
// Ada Lovelace
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias);
} else if (version_num >= 80) {
// Ampere
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias);
} else {
// Turing
TORCH_CHECK(version_num >= 75);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales);
cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias);
}
}
......@@ -526,6 +526,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
......@@ -536,6 +537,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert(false);
return {}; // Squash missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
......
......@@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "../../reduction_utils.cuh"
namespace vllm {
__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
......@@ -21,10 +23,16 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template <typename scalar_t>
template <bool is_scale_inverted>
__device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
const scalar_t val, const float inverted_scale) {
float x = static_cast<float>(val) * inverted_scale;
float const val, float const scale) {
float x = 0.0f;
if constexpr (is_scale_inverted) {
x = val * scale;
} else {
x = val / scale;
}
float r = fmax(-FP8_E4M3_MAX, fmin(x, FP8_E4M3_MAX));
return static_cast<c10::Float8_e4m3fn>(r);
}
......@@ -40,7 +48,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
const scalar_t* __restrict__ input,
int64_t num_elems) {
__shared__ float cache[1024];
int i = blockDim.x * blockIdx.x + threadIdx.x;
int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
// First store maximum for all values processes by
// the current thread in cache[threadIdx.x]
......@@ -87,6 +95,70 @@ typedef struct __align__(4) {
}
float8x4_t;
template <typename scalar_t>
__device__ float thread_max_vec(scalar_t const* __restrict__ input,
int64_t const num_elems, int const tid,
int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
int64_t const num_vec_elems = num_elems >> 2;
float absmax_val = 0.0f;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
absmax_val = max(absmax_val, fabs(in_vec.x));
absmax_val = max(absmax_val, fabs(in_vec.y));
absmax_val = max(absmax_val, fabs(in_vec.z));
absmax_val = max(absmax_val, fabs(in_vec.w));
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
absmax_val = max(absmax_val, fabs(input[i]));
}
return absmax_val;
}
template <typename scalar_t, bool is_scale_inverted>
__device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
scalar_t const* __restrict__ input,
float const scale,
int64_t const num_elems,
int const tid, int const step) {
// Vectorized input/output to better utilize memory bandwidth.
vec4_t<scalar_t> const* vectorized_in =
reinterpret_cast<vec4_t<scalar_t> const*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
int64_t const num_vec_elems = num_elems >> 2;
#pragma unroll 4
for (int64_t i = tid; i < num_vec_elems; i += step) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
out_vec.x = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.x), scale);
out_vec.y = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.y), scale);
out_vec.z = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.z), scale);
out_vec.w = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(in_vec.w), scale);
vectorized_out[i] = out_vec;
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int64_t i = num_vec_elems * 4 + tid; i < num_elems; i += step) {
out[i] = scaled_fp8_conversion<is_scale_inverted>(
static_cast<float>(input[i]), scale);
}
}
template <typename scalar_t>
__global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
const scalar_t* __restrict__ input,
......@@ -97,38 +169,68 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
// Invert the scale so that we can use multiplications to avoid expensive
// division.
const float inverted_scale = 1.0f / (*scale);
scaled_fp8_conversion_vec<scalar_t, true>(
out, input, inverted_scale, num_elems, tid, blockDim.x * gridDim.x);
}
// Vectorized input/output to better utilize memory bandwidth.
const vec4_t<scalar_t>* vectorized_in =
reinterpret_cast<const vec4_t<scalar_t>*>(input);
float8x4_t* vectorized_out = reinterpret_cast<float8x4_t*>(out);
template <typename scalar_t>
__global__ void dynamic_per_token_scaled_fp8_quant_kernel(
c10::Float8_e4m3fn* __restrict__ out, float* __restrict__ scale,
scalar_t const* __restrict__ input, float const* __restrict__ scale_ub,
const int hidden_size) {
float const min_scaling_factor = 1.0f / (FP8_E4M3_MAX * 512.f);
int num_vec_elems = num_elems >> 2;
int const tid = threadIdx.x;
int const token_idx = blockIdx.x;
#pragma unroll 4
for (int i = tid; i < num_vec_elems; i += blockDim.x * gridDim.x) {
vec4_t<scalar_t> in_vec = vectorized_in[i];
float8x4_t out_vec;
scalar_t const* __restrict__ token_input = &input[token_idx * hidden_size];
c10::Float8_e4m3fn* __restrict__ token_output = &out[token_idx * hidden_size];
out_vec.x = scaled_fp8_conversion(in_vec.x, inverted_scale);
out_vec.y = scaled_fp8_conversion(in_vec.y, inverted_scale);
out_vec.z = scaled_fp8_conversion(in_vec.z, inverted_scale);
out_vec.w = scaled_fp8_conversion(in_vec.w, inverted_scale);
vectorized_out[i] = out_vec;
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
bool const can_vectorize = hidden_size % 4 == 0;
float absmax_val = 0.0f;
if (can_vectorize) {
absmax_val = thread_max_vec(token_input, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
float const x = static_cast<float>(token_input[i]);
absmax_val = max(absmax_val, fabs(x));
}
}
// Handle the remaining elements if num_elems is not divisible by 4
for (int i = num_vec_elems * 4 + tid; i < num_elems;
i += blockDim.x * gridDim.x) {
out[i] = scaled_fp8_conversion(input[i], inverted_scale);
float const block_absmax_val_maybe = blockReduceMax(absmax_val);
__shared__ float token_scale;
if (tid == 0) {
if (scale_ub) {
token_scale = min(block_absmax_val_maybe, *scale_ub);
} else {
token_scale = block_absmax_val_maybe;
}
// token scale computation
token_scale = max(token_scale / FP8_E4M3_MAX, min_scaling_factor);
scale[token_idx] = token_scale;
}
__syncthreads();
// Note that we don't use inverted scales so we can match FBGemm impl.
if (can_vectorize) {
scaled_fp8_conversion_vec<scalar_t, false>(
token_output, token_input, token_scale, hidden_size, tid, blockDim.x);
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
token_output[i] = scaled_fp8_conversion<false>(
static_cast<float>(token_input[i]), token_scale);
}
}
}
} // namespace vllm
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor const& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
......@@ -144,9 +246,9 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
});
}
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor& input, // [..., d]
torch::Tensor& scale) // [1]
void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scale) // [1]
{
int64_t num_tokens = input.numel() / input.size(-1);
int64_t num_elems = input.numel();
......@@ -163,3 +265,28 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
scale.data_ptr<float>(), num_elems);
});
}
void dynamic_per_token_scaled_fp8_quant(
torch::Tensor& out, // [..., d]
torch::Tensor const& input, // [..., d]
torch::Tensor& scales, std::optional<at::Tensor> const& scale_ub) {
TORCH_CHECK(input.is_contiguous());
TORCH_CHECK(out.is_contiguous());
int const hidden_size = input.size(-1);
int const num_tokens = input.numel() / hidden_size;
dim3 const grid(num_tokens);
dim3 const block(std::min(hidden_size, 1024));
const at::cuda::OptionalCUDAGuard device_guard(device_of(input));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(
input.scalar_type(), "dynamic_per_token_scaled_fp8_quant_kernel", [&] {
vllm::dynamic_per_token_scaled_fp8_quant_kernel<scalar_t>
<<<grid, block, 0, stream>>>(
out.data_ptr<c10::Float8_e4m3fn>(), scales.data_ptr<float>(),
input.data_ptr<scalar_t>(),
scale_ub.has_value() ? scale_ub->data_ptr<float>() : nullptr,
hidden_size);
});
}
/*
* Modified by Neural Magic
* Copyright (C) Marlin.2024 Elias Frantar
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/*
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "../gptq_marlin/marlin.cuh"
#include "../gptq_marlin/marlin_dtypes.cuh"
using namespace marlin;
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
"only float16 and bfloat16 is supported");
template <typename T>
inline std::string str(T x) {
return std::to_string(x);
}
namespace fp8_marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {}
} // namespace fp8_marlin
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
// m16n8k16 tensor core mma instruction with fp16 inputs and fp32
// output/accumulation.
template <typename scalar_t>
__device__ inline void mma(const typename ScalarType<scalar_t>::FragA& a_frag,
const typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragC& frag_c) {
const uint32_t* a = reinterpret_cast<const uint32_t*>(&a_frag);
const uint32_t* b = reinterpret_cast<const uint32_t*>(&frag_b);
float* c = reinterpret_cast<float*>(&frag_c);
if constexpr (std::is_same<scalar_t, half>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else if constexpr (std::is_same<scalar_t, nv_bfloat16>::value) {
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 "
"{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n"
: "=f"(c[0]), "=f"(c[1]), "=f"(c[2]), "=f"(c[3])
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1]),
"f"(c[0]), "f"(c[1]), "f"(c[2]), "f"(c[3]));
} else {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
}
// Instruction for loading a full 16x16 matrix fragment of operand A from shared
// memory, directly in tensor core layout.
template <typename scalar_t>
__device__ inline void ldsm4(typename ScalarType<scalar_t>::FragA& frag_a,
const void* smem_ptr) {
uint32_t* a = reinterpret_cast<uint32_t*>(&frag_a);
uint32_t smem = static_cast<uint32_t>(__cvta_generic_to_shared(smem_ptr));
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0,%1,%2,%3}, [%4];\n"
: "=r"(a[0]), "=r"(a[1]), "=r"(a[2]), "=r"(a[3])
: "r"(smem));
}
// Fast FP8ToFp16/FP8ToBf16: Efficiently dequantize 8bit fp8_e4m3 values to fp16
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
// Constants for FP8 (E4M3) and FP16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, FP16_EXPONENT = 5;
constexpr int RIGHT_SHIFT = FP16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to FP16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (FP16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
const half2 bias_reg = __float2half2_rn(float(1 << BIAS_OFFSET));
// Convert to half2 and apply bias
typename ScalarType<half>::FragB frag_b;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const half2*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const half2*>(&Out2), bias_reg);
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
// Constants for FP8 (E4M3) and BF16 formats
constexpr int FP8_EXPONENT = 4, FP8_MANTISSA = 3, BF16_EXPONENT = 8;
constexpr int RIGHT_SHIFT = BF16_EXPONENT - FP8_EXPONENT;
// Calculate MASK for extracting mantissa and exponent
constexpr int MASK1 = 0x80000000;
constexpr int MASK2 = MASK1 >> (FP8_EXPONENT + FP8_MANTISSA);
constexpr int MASK3 = MASK2 & 0x7fffffff;
constexpr int MASK = MASK3 | (MASK3 >> 16);
// Final MASK value: 0x7F007F00
// Extract and shift FP8 values to BF16 format
int Out1 = (q & 0x80008000) | ((q & MASK) >> RIGHT_SHIFT);
int Out2 = ((q << 8) & 0x80008000) | (((q << 8) & MASK) >> RIGHT_SHIFT);
// Construct and apply exponent bias
constexpr int BIAS_OFFSET =
(1 << (BF16_EXPONENT - 1)) - (1 << (FP8_EXPONENT - 1));
// Add 127 (float exponent bias) to BIAS_OFFSET and shift to float exponent
// position
constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23;
const nv_bfloat162 bias_reg =
__float2bfloat162_rn(*reinterpret_cast<const float*>(&BIAS));
// Convert to bfloat162 and apply bias
typename ScalarType<nv_bfloat16>::FragB frag_b;
// Note: reverse indexing is intentional because weights are permuted
frag_b[1] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out1), bias_reg);
frag_b[0] = __hmul2(*reinterpret_cast<const nv_bfloat162*>(&Out2), bias_reg);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
__device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::FragS& frag_s,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 s =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_s)[i]);
frag_b[0] = __hmul2(frag_b[0], s);
frag_b[1] = __hmul2(frag_b[1], s);
}
// Given 2 floats multiply by 2 scales (halves)
template <typename scalar_t>
__device__ inline void scale_float(float* c,
typename ScalarType<scalar_t>::FragS& s) {
scalar_t* s_ptr = reinterpret_cast<scalar_t*>(&s);
c[0] = __fmul_rn(c[0], ScalarType<scalar_t>::num2float(s_ptr[0]));
c[1] = __fmul_rn(c[1], ScalarType<scalar_t>::num2float(s_ptr[1]));
}
// Wait until barrier reaches `count`, then lock for current threadblock.
__device__ inline void barrier_acquire(int* lock, int count) {
if (threadIdx.x == 0) {
int state = -1;
do
// Guarantee that subsequent writes by this threadblock will be visible
// globally.
asm volatile("ld.global.acquire.gpu.b32 %0, [%1];\n"
: "=r"(state)
: "l"(lock));
while (state != count);
}
__syncthreads();
}
// Release barrier and increment visitation count.
__device__ inline void barrier_release(int* lock, bool reset = false) {
__syncthreads();
if (threadIdx.x == 0) {
if (reset) {
lock[0] = 0;
return;
}
int val = 1;
// Make sure that all writes since acquiring this barrier are visible
// globally, while releasing the barrier.
asm volatile("fence.acq_rel.gpu;\n");
asm volatile("red.relaxed.gpu.global.add.s32 [%0], %1;\n"
:
: "l"(lock), "r"(val));
}
}
template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
// threadblock
const int thread_n_blocks, // same for n dimension (output)
const int thread_k_blocks, // same for k dimension (reduction)
const int stages, // number of stages for the async global->shared
// fetch pipeline
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
__global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
// `thread_n_blocks`). Stripes are defined as shown in the 3x3 matrix 5 SM
// example:
// 0 1 3
// 0 2 3
// 1 2 4
// While this kind of partitioning makes things somewhat more complicated, it
// ensures good utilization of all SMs for many kinds of shape and GPU
// configurations, while requiring as few slow global cross-threadblock
// reductions as possible.
using Dtype = ScalarType<scalar_t>;
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
using FragA = typename ScalarType<scalar_t>::FragA;
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
constexpr int pack_factor = 32 / num_bits;
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
int parallel = 1;
if (prob_m > 16 * thread_m_blocks) {
parallel = prob_m / (16 * thread_m_blocks);
prob_m = 16 * thread_m_blocks;
}
int k_tiles = prob_k / 16 / thread_k_blocks;
int n_tiles = prob_n / 16 / thread_n_blocks;
int iters = div_ceil(k_tiles * n_tiles * parallel, gridDim.x);
int slice_row = (iters * blockIdx.x) % k_tiles;
int slice_col_par = (iters * blockIdx.x) / k_tiles;
int slice_col = slice_col_par;
int slice_iters; // number of threadblock tiles in the current slice
int slice_count =
0; // total number of active threadblocks in the current slice
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
A += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_k / 8;
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
}
// Compute all information about the current slice which is required for
// synchronization.
auto init_slice = [&]() {
slice_iters =
iters * (blockIdx.x + 1) - (k_tiles * slice_col_par + slice_row);
if (slice_iters < 0 || slice_col_par >= n_tiles * parallel) slice_iters = 0;
if (slice_iters == 0) return;
if (slice_row + slice_iters > k_tiles) slice_iters = k_tiles - slice_row;
slice_count = 1;
slice_idx = 0;
int col_first = iters * div_ceil(k_tiles * slice_col_par, iters);
if (col_first <= k_tiles * (slice_col_par + 1)) {
int col_off = col_first - k_tiles * slice_col_par;
slice_count = div_ceil(k_tiles - col_off, iters);
if (col_off > 0) slice_count++;
int delta_first = iters * blockIdx.x - col_first;
if (delta_first < 0 || (col_off == 0 && delta_first == 0))
slice_idx = slice_count - 1;
else {
slice_idx = slice_count - 1 - delta_first / iters;
if (col_off > 0) slice_idx--;
}
}
if (slice_col == n_tiles) {
A += 16 * thread_m_blocks * prob_k / 8;
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
}
};
init_slice();
// A sizes/strides
// stride of the A matrix in global memory
int a_gl_stride = prob_k / 8;
// stride of an A matrix tile in shared memory
constexpr int a_sh_stride = 16 * thread_k_blocks / 8;
// delta between subsequent A tiles in global memory
constexpr int a_gl_rd_delta_o = 16 * thread_k_blocks / 8;
// between subsequent accesses within a tile
int a_gl_rd_delta_i = a_gl_stride * (threads / a_gl_rd_delta_o);
// between shared memory writes
constexpr int a_sh_wr_delta = a_sh_stride * (threads / a_gl_rd_delta_o);
// between shared memory tile reads
constexpr int a_sh_rd_delta_o = 2 * ((threads / 32) / (thread_n_blocks / 4));
// within a shared memory tile
constexpr int a_sh_rd_delta_i = a_sh_stride * 16;
// overall size of a tile
constexpr int a_sh_stage = a_sh_stride * (16 * thread_m_blocks);
// number of shared write iterations for a tile
constexpr int a_sh_wr_iters = div_ceil(a_sh_stage, a_sh_wr_delta);
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
int b_gl_rd_delta_i = b_gl_stride * (threads / b_sh_stride_threads);
constexpr int b_sh_wr_delta = threads * b_thread_vecs;
constexpr int b_sh_rd_delta = threads * b_thread_vecs;
constexpr int b_sh_stage = b_sh_stride * thread_k_blocks;
constexpr int b_sh_wr_iters = b_sh_stage / b_sh_wr_delta;
// Scale sizes/strides without act_order
int s_gl_stride = prob_n / 8;
constexpr int s_sh_stride = 16 * thread_n_blocks / 8;
// Scale size/strides with act_order
constexpr int tb_k = 16 * thread_k_blocks;
constexpr int g_idx_stage = 0;
// constexpr int act_s_row_stride = 1;
// int act_s_col_stride = act_s_row_stride * num_groups;
int act_s_col_stride = 1;
int act_s_col_warp_stride = act_s_col_stride * 8;
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
a_gl_rd += a_gl_rd_delta_o * slice_row;
// Shared write index of current thread.
int a_sh_wr = a_sh_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
// Shared read index.
int a_sh_rd =
a_sh_stride * ((threadIdx.x % 32) % 16) + (threadIdx.x % 32) / 16;
a_sh_rd += 2 * ((threadIdx.x / 32) / (thread_n_blocks / 4));
int b_gl_rd = b_gl_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads) * b_thread_vecs;
b_gl_rd += b_sh_stride * slice_col;
b_gl_rd += b_gl_rd_delta_o * slice_row;
int b_sh_wr = threadIdx.x * b_thread_vecs;
int b_sh_rd = threadIdx.x * b_thread_vecs;
// For act_order
int slice_k_start = tb_k * slice_row;
int slice_k_start_shared_fetch = slice_k_start;
int slice_n_offset = act_s_col_tb_stride * slice_col;
// No act_order
int s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// We scale a `half2` tile in row-major layout for column-wise quantization.
int s_sh_rd =
8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) + (threadIdx.x % 32) % 4;
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
bool a_sh_wr_pred[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_pred[i] = a_sh_wr_delta * i + a_sh_wr < a_sh_stride * prob_m;
// To ensure that writing and reading A tiles to/from shared memory, the
// latter in fragment format, is fully bank conflict free, we need to use a
// rather fancy XOR-based layout. The key here is that neither reads nor
// writes of the 16-byte `int4` blocks of 8 consecutive threads involve the
// same shared memory banks. Further, it seems (based on NSight-Compute) that
// each warp must also write a consecutive memory segment?
auto transform_a = [&](int i) {
int row = i / a_gl_rd_delta_o;
return a_gl_rd_delta_o * row + (i % a_gl_rd_delta_o) ^ row;
};
// Since the computation of this remapping is non-trivial and, due to our main
// loop unrolls, all shared memory accesses are static, we simply precompute
// both transformed reads and writes.
int a_sh_wr_trans[a_sh_wr_iters];
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++)
a_sh_wr_trans[i] = transform_a(a_sh_wr_delta * i + a_sh_wr);
int a_sh_rd_trans[b_sh_wr_iters][thread_m_blocks];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < thread_m_blocks; j++)
a_sh_rd_trans[i][j] =
transform_a(a_sh_rd_delta_o * i + a_sh_rd_delta_i * j + a_sh_rd);
}
// Since B-accesses have non-constant stride they have to be computed at
// runtime; we break dependencies between subsequent accesses with a tile by
// maintining multiple pointers (we have enough registers), a tiny
// optimization.
const int4* B_ptr[b_sh_wr_iters];
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] = B + b_gl_rd_delta_i * i + b_gl_rd;
extern __shared__ int4 sh[];
// Shared memory storage for global fetch pipelines.
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4];
// Zero accumulators.
auto zero_accums = [&]() {
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4 * 2 * 4; i++)
reinterpret_cast<float*>(frag_c)[i] = 0;
};
int sh_first_group_id = -1;
int sh_num_groups = -1;
constexpr int sh_max_num_groups = 32;
auto fetch_scales_to_shared = [&](bool is_async, int first_group_id,
int last_group_id) {
sh_first_group_id = first_group_id;
sh_num_groups = last_group_id - first_group_id + 1;
if (sh_num_groups < sh_max_num_groups) {
sh_num_groups = sh_max_num_groups;
}
if (sh_first_group_id + sh_num_groups > num_groups) {
sh_num_groups = num_groups - sh_first_group_id;
}
int row_offset = first_group_id * s_gl_stride;
if (is_async) {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
cp_async4_pred(&sh_s[(i * s_sh_stride) + threadIdx.x],
&scales_ptr[row_offset + (i * s_gl_stride) +
slice_n_offset + threadIdx.x]);
}
}
} else {
for (int i = 0; i < sh_num_groups; i++) {
if (threadIdx.x < s_sh_stride) {
sh_s[(i * s_sh_stride) + threadIdx.x] =
scales_ptr[row_offset + (i * s_gl_stride) + slice_n_offset +
threadIdx.x];
}
}
}
};
// Asynchronously fetch the next A, B and s tile from global to the next
// shared memory pipeline location.
auto fetch_to_shared = [&](int pipe, int a_off, bool pred = true) {
if (pred) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < a_sh_wr_iters; i++) {
cp_async4_pred(
&sh_a_stage[a_sh_wr_trans[i]],
&A[a_gl_rd_delta_i * i + a_gl_rd + a_gl_rd_delta_o * a_off],
a_sh_wr_pred[i]);
}
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) {
#pragma unroll
for (int j = 0; j < b_thread_vecs; j++) {
cp_async4(&sh_b_stage[b_sh_wr_delta * i + b_sh_wr + j], B_ptr[i] + j);
}
B_ptr[i] += b_gl_rd_delta_o;
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
// waiting is also correct at this point.
cp_async_fence();
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<stages - 2>();
__syncthreads();
};
// Load the next sub-tile from the current location in the shared memory pipe
// into the current register buffer.
auto fetch_to_registers = [&](int k, int pipe) {
int4* sh_a_stage = sh_a + a_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++)
ldsm4<scalar_t>(frag_a[k % 2][i],
&sh_a_stage[a_sh_rd_trans[k % b_sh_wr_iters][i]]);
int4* sh_b_stage = sh_b + b_sh_stage * pipe;
#pragma unroll
for (int i = 0; i < b_thread_vecs; i++) {
frag_b_quant[k % 2][i] = *reinterpret_cast<I4*>(
&sh_b_stage[b_sh_rd_delta * (k % b_sh_wr_iters) + b_sh_rd + i]);
}
};
bool is_same_group[stages];
int same_group_id[stages];
auto init_same_group = [&](int pipe) {
is_same_group[pipe] = false;
same_group_id[pipe] = 0;
return;
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
mma<scalar_t>(frag_a[k % 2][i], frag_b0, frag_c[i][j][0]);
mma<scalar_t>(frag_a[k % 2][i], frag_b1, frag_c[i][j][1]);
}
}
};
// Since we slice across the k dimension of a tile in order to increase the
// number of warps while keeping the n dimension of a tile reasonable, we have
// multiple warps that accumulate their partial sums of the same output
// location; which we have to reduce over in the end. We do in shared memory.
auto thread_block_reduce = [&]() {
constexpr int red_off = threads / b_sh_stride_threads / 2;
if (red_off >= 1) {
int red_idx = threadIdx.x / b_sh_stride_threads;
constexpr int red_sh_stride = b_sh_stride_threads * 4 * 2;
constexpr int red_sh_delta = b_sh_stride_threads;
int red_sh_rd = red_sh_stride * (threadIdx.x / b_sh_stride_threads) +
(threadIdx.x % b_sh_stride_threads);
// Parallel logarithmic shared memory reduction. We make sure to avoid any
// unnecessary read or write iterations, e.g., for two warps we write only
// once by warp 1 and read only once by warp 0.
#pragma unroll
for (int m_block = 0; m_block < thread_m_blocks; m_block++) {
#pragma unroll
for (int i = red_off; i > 0; i /= 2) {
if (i <= red_idx && red_idx < 2 * i) {
#pragma unroll
for (int j = 0; j < 4 * 2; j++) {
int red_sh_wr =
red_sh_delta * j + (red_sh_rd - red_sh_stride * i);
if (i < red_off) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * j + red_sh_rd]);
float* c_wr = reinterpret_cast<float*>(&sh[red_sh_wr]);
#pragma unroll
for (int k = 0; k < 4; k++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + j][k] +=
c_rd[k] + c_wr[k];
}
sh[red_sh_wr] =
reinterpret_cast<int4*>(&frag_c)[4 * 2 * m_block + j];
}
}
__syncthreads();
}
if (red_idx == 0) {
#pragma unroll
for (int i = 0; i < 4 * 2; i++) {
float* c_rd =
reinterpret_cast<float*>(&sh[red_sh_delta * i + red_sh_rd]);
#pragma unroll
for (int j = 0; j < 4; j++)
reinterpret_cast<FragC*>(frag_c)[4 * 2 * m_block + i][j] +=
c_rd[j];
}
}
__syncthreads();
}
}
};
// Since multiple threadblocks may process parts of the same column slice, we
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
constexpr int active_threads = 32 * thread_n_blocks / 4;
if (threadIdx.x < active_threads) {
int c_gl_stride = prob_n / 8;
int c_gl_wr_delta_o = 8 * c_gl_stride;
int c_gl_wr_delta_i = 4 * (active_threads / 32);
int c_gl_wr = c_gl_stride * ((threadIdx.x % 32) / 4) +
4 * (threadIdx.x / 32) + threadIdx.x % 4;
c_gl_wr += (2 * thread_n_blocks) * slice_col;
constexpr int c_sh_wr_delta = active_threads;
int c_sh_wr = threadIdx.x;
int row = (threadIdx.x % 32) / 4;
if (!first) {
// Interestingly, doing direct global accesses here really seems to mess up
// the compiler and lead to slowdowns, hence we also use async-copies even
// though these fetches are not actually asynchronous.
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
cp_async4_pred(
&sh[c_sh_wr + c_sh_wr_delta * i],
&C[c_gl_wr + c_gl_wr_delta_o * (i / 2) +
c_gl_wr_delta_i * (i % 2)],
i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m);
}
cp_async_fence();
cp_async_wait<0>();
}
#pragma unroll
for (int i = 0; i < thread_m_blocks * 4; i++) {
if (i < (thread_m_blocks - 1) * 4 || 8 * (i / 2) + row < prob_m) {
if (!first) {
int4 c_red = sh[c_sh_wr + i * c_sh_wr_delta];
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)] +=
Dtype::num2float(reinterpret_cast<scalar_t*>(&c_red)[j]);
}
}
if (!last) {
int4 c;
#pragma unroll
for (int j = 0; j < 2 * 4; j++) {
reinterpret_cast<scalar_t*>(&c)[j] =
Dtype::float2num(reinterpret_cast<float*>(
&frag_c)[4 * 2 * 4 * (i / 4) + 4 * j + (i % 4)]);
}
C[c_gl_wr + c_gl_wr_delta_o * (i / 2) + c_gl_wr_delta_i * (i % 2)] =
c;
}
}
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
auto write_result = [&]() {
int c_gl_stride = prob_n / 8;
constexpr int c_sh_stride = 2 * thread_n_blocks + 1;
int c_gl_wr_delta = c_gl_stride * (threads / (2 * thread_n_blocks));
constexpr int c_sh_rd_delta =
c_sh_stride * (threads / (2 * thread_n_blocks));
int c_gl_wr = c_gl_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
c_gl_wr += (2 * thread_n_blocks) * slice_col;
int c_sh_wr =
(4 * c_sh_stride) * ((threadIdx.x % 32) / 4) + (threadIdx.x % 32) % 4;
c_sh_wr += 32 * (threadIdx.x / 32);
int c_sh_rd = c_sh_stride * (threadIdx.x / (2 * thread_n_blocks)) +
(threadIdx.x % (2 * thread_n_blocks));
int c_gl_wr_end = c_gl_stride * prob_m;
// We first reorder in shared memory to guarantee the most efficient final
// global write patterns
auto write = [&](int idx, float c0, float c1, FragS& s) {
scalar_t2 res =
Dtype::nums2num2(Dtype::float2num(c0), Dtype::float2num(c1));
((scalar_t2*)sh)[idx] = res;
};
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
int wr = c_sh_wr + 8 * j;
write(wr + (4 * c_sh_stride) * 0 + 0, frag_c[i][j][0][0],
frag_c[i][j][0][1], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 8 + 0, frag_c[i][j][0][2],
frag_c[i][j][0][3], frag_s[j / 2][2 * (j % 2) + 0]);
write(wr + (4 * c_sh_stride) * 0 + 4, frag_c[i][j][1][0],
frag_c[i][j][1][1], frag_s[j / 2][2 * (j % 2) + 1]);
write(wr + (4 * c_sh_stride) * 8 + 4, frag_c[i][j][1][2],
frag_c[i][j][1][3], frag_s[j / 2][2 * (j % 2) + 1]);
}
c_sh_wr += 16 * (4 * c_sh_stride);
}
}
__syncthreads();
#pragma unroll
for (int i = 0;
i < div_ceil(16 * thread_m_blocks, threads / (2 * thread_n_blocks));
i++) {
if (c_gl_wr < c_gl_wr_end) {
C[c_gl_wr] = sh[c_sh_rd];
c_gl_wr += c_gl_wr_delta;
c_sh_rd += c_sh_rd_delta;
}
}
};
// Start global fetch and register load pipelines.
auto start_pipes = [&]() {
#pragma unroll
for (int i = 0; i < stages - 1; i++) {
fetch_to_shared(i, i, i < slice_iters);
}
zero_accums();
wait_for_stage();
init_same_group(0);
fetch_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
if (slice_iters) {
start_pipes();
}
// Main loop.
while (slice_iters) {
// We unroll over both the global fetch and the register load pipeline to
// ensure all shared memory accesses are static. Note that both pipelines
// have even length meaning that the next iteration will always start at
// index 0.
#pragma unroll
for (int pipe = 0; pipe < stages;) {
#pragma unroll
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
pipe++;
wait_for_stage();
init_same_group(pipe % stages);
}
matmul(k);
}
slice_iters--;
if (slice_iters == 0) {
break;
}
}
a_gl_rd += a_gl_rd_delta_o * stages;
slice_k_start += tb_k * stages;
slice_k_start_shared_fetch += tb_k * stages;
// Process results and, if necessary, proceed to the next column slice.
// While this pattern may not be the most readable, other ways of writing
// the loop seemed to noticeably worse performance after compilation.
if (slice_iters == 0) {
cp_async_wait<0>();
bool last = slice_idx == slice_count - 1;
// For per-column scales, we only fetch them here in the final step before
// write-out
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
cp_async_fence();
thread_block_reduce();
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
reinterpret_cast<int4*>(&frag_s)[0] = sh_s[s_sh_rd + 0];
reinterpret_cast<int4*>(&frag_s)[1] = sh_s[s_sh_rd + 4];
}
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
#pragma unroll
for (int j = 0; j < 4; j++) {
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][0]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][0][2]),
frag_s[j / 2][2 * (j % 2) + 0]);
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][0]),
frag_s[j / 2][2 * (j % 2) + 1]);
scale_float<scalar_t>(reinterpret_cast<float*>(&frag_c[i][j][1][2]),
frag_s[j / 2][2 * (j % 2) + 1]);
}
}
}
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
write_result();
slice_row = 0;
slice_col_par++;
slice_col++;
init_slice();
if (slice_iters) {
a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++)
B_ptr[i] += b_sh_stride - b_gl_rd_delta_o * k_tiles;
if (slice_col == 0) {
#pragma unroll
for (int i = 0; i < b_sh_wr_iters; i++) B_ptr[i] -= b_gl_stride;
}
// Update slice k/n for scales loading
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
start_pipes();
}
}
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, GROUP_BLOCKS, NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, num_groups, prob_m, prob_n, prob_k, \
locks); \
}
typedef struct {
int thread_k;
int thread_n;
int num_threads;
} thread_config_t;
typedef struct {
int max_m_blocks;
thread_config_t tb_cfg;
} exec_config_t;
thread_config_t small_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{128, 128, 256},
{64, 128, 128},
{128, 64, 128},
};
thread_config_t large_batch_thread_configs[] = {
// Ordered by priority
// thread_k, thread_n, num_threads
{64, 256, 256},
{64, 128, 128},
{128, 64, 128},
};
int get_scales_cache_size(thread_config_t const& th_config, int prob_m,
int prob_n, int prob_k, int num_bits,
int group_size) {
int tb_n = th_config.thread_n;
// Get max scale groups per thread-block
// Fixed for channelwise
int tb_groups = 1;
int tb_scales = tb_groups * tb_n * 2;
return tb_scales * pipe_stages;
}
bool is_valid_cache_size(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int scales_cache_size, int max_shared_mem) {
int pack_factor = 32 / num_bits;
// Get B size
int tb_k = th_config.thread_k;
int tb_n = th_config.thread_n;
int b_size = (tb_k * tb_n / pack_factor) * 4;
// Get A size
int m_blocks = div_ceil(prob_m, 16);
int tb_max_m = 16;
while (true) {
if (m_blocks >= max_m_blocks) {
tb_max_m *= max_m_blocks;
break;
}
max_m_blocks--;
if (max_m_blocks == 0) {
TORCH_CHECK(false, "Unexpected m_blocks = ", m_blocks);
}
}
int a_size = (tb_max_m * tb_k) * 2;
float pipe_size = (a_size + b_size) * pipe_stages;
TORCH_CHECK(max_shared_mem / 2 > scales_cache_size); // Sanity
return pipe_size < 0.95f * (max_shared_mem - scales_cache_size);
}
bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
int prob_m, int prob_n, int prob_k, int num_bits,
int group_size, int max_shared_mem) {
// Sanity
if (th_config.thread_k == -1 || th_config.thread_n == -1 ||
th_config.num_threads == -1) {
return false;
}
// Verify K/N are divisible by thread K/N
if (prob_k % th_config.thread_k != 0 || prob_n % th_config.thread_n != 0) {
return false;
}
// Verify min for thread K/N
if (th_config.thread_n < min_thread_n || th_config.thread_k < min_thread_k) {
return false;
}
// num_threads must be at least 128 (= 4 warps)
if (th_config.num_threads < 128) {
return false;
}
// Determine cache for scales
int scales_cache_size = get_scales_cache_size(th_config, prob_m, prob_n,
prob_k, num_bits, group_size);
// Check that pipeline fits into cache
if (!is_valid_cache_size(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, scales_cache_size, max_shared_mem)) {
return false;
}
return true;
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
int max_shared_mem) {
int max_m_blocks = 4;
while (max_m_blocks > 0) {
if (prob_m <= 16) {
for (auto th_config : small_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
} else {
for (auto th_config : large_batch_thread_configs) {
if (is_valid_config(th_config, max_m_blocks, prob_m, prob_n, prob_k,
num_bits, group_size, max_shared_mem)) {
return exec_config_t{max_m_blocks, th_config};
}
}
}
max_m_blocks--; // Process less M blocks per invocation to reduce cache
// usage
}
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, -1, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
int num_groups, int group_size, int dev,
cudaStream_t stream, int thread_k, int thread_n, int sms,
int max_par) {
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
int tot_m = prob_m;
int tot_m_blocks = div_ceil(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
if (sms == -1) {
cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, dev);
}
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
// Set thread config
exec_config_t exec_cfg;
if (thread_k != -1 && thread_n != -1) {
// User-defined config
exec_cfg =
exec_config_t{4, thread_config_t{thread_k, thread_n, default_threads}};
} else {
// Auto config
exec_cfg = determine_thread_config(prob_m, prob_n, prob_k, num_bits,
group_size, max_shared_mem);
}
TORCH_CHECK(
exec_cfg.max_m_blocks > 0 &&
is_valid_config(exec_cfg.tb_cfg, exec_cfg.max_m_blocks, prob_m,
prob_n, prob_k, num_bits, group_size, max_shared_mem),
"Invalid thread config: max_m_blocks = ", exec_cfg.max_m_blocks,
", thread_k = ", exec_cfg.tb_cfg.thread_k,
", thread_n = ", exec_cfg.tb_cfg.thread_n,
", num_threads = ", exec_cfg.tb_cfg.num_threads, " for MKN = [", prob_m,
", ", prob_k, ", ", prob_n, "] and num_bits = ", num_bits,
", group_size = ", group_size, ", max_shared_mem = ", max_shared_mem);
int num_threads = exec_cfg.tb_cfg.num_threads;
thread_k = exec_cfg.tb_cfg.thread_k;
thread_n = exec_cfg.tb_cfg.thread_n;
int thread_k_blocks = thread_k / 16;
int thread_n_blocks = thread_n / 16;
int blocks = sms;
TORCH_CHECK(prob_n % thread_n == 0, "prob_n = ", prob_n,
" is not divisible by thread_n = ", thread_n);
TORCH_CHECK(prob_k % thread_k == 0, "prob_k = ", prob_k,
" is not divisible by thread_k = ", thread_k);
int group_blocks = -1;
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
const int4* s_ptr = (const int4*)s;
int* locks = (int*)workspace;
// Main loop
for (int i = 0; i < tot_m_blocks; i += exec_cfg.max_m_blocks) {
int thread_m_blocks = tot_m_blocks - i;
prob_m = tot_m - 16 * i;
int par = 1;
if (thread_m_blocks > exec_cfg.max_m_blocks) {
// Note that parallel > 1 currently only works for inputs without any
// padding
par = (16 * thread_m_blocks - pad) / (16 * exec_cfg.max_m_blocks);
if (par > max_par) par = max_par;
prob_m = (16 * exec_cfg.max_m_blocks) * par;
i += exec_cfg.max_m_blocks * (par - 1);
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
if (false) {
}
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
C_ptr += 16 * thread_m_blocks * (prob_n / 8) * par;
}
}
} // namespace fp8_marlin
torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k) {
// Verify num_bits
TORCH_CHECK(num_bits == 8, "num_bits must be 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
", size_m = ", size_m);
TORCH_CHECK(a.size(1) == size_k, "Shape mismatch: a.size(1) = ", a.size(1),
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
// Verify device and strides
TORCH_CHECK(a.device().is_cuda(), "A is not on GPU");
TORCH_CHECK(a.is_contiguous(), "A is not contiguous");
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
// thread_n: `n` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_n = -1;
// sms: number of SMs to use for the kernel (can usually be left as auto -1)
int sms = -1;
// Detect groupsize and act_order
int num_groups = -1;
int group_size = -1;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
" is not size_n = ", size_n);
// Channelwise only for FP8
TORCH_CHECK(b_scales.size(0) == 1)
num_groups = b_scales.size(0);
// Verify workspace size
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
fp8_marlin::marlin_mm_f16i4<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
b_scales.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
marlin::max_par);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
fp8_marlin::marlin_mm_f16i4<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(), size_m,
size_n, size_k, workspace.data_ptr(), num_bits, num_groups, group_size,
dev, at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
marlin::max_par);
} else {
TORCH_CHECK(false, "fp8_marlin_gemm only supports bfloat16 and float16");
}
return c;
}
#endif
......@@ -475,6 +475,7 @@ __inline__ __device__ uint8_t scaled_vec_conversion<uint8_t, __nv_bfloat16>(
__NV_SATFINITE, fp8_type);
return (uint8_t)res;
#endif
__builtin_unreachable(); // Suppress missing return statement warning
}
// float -> fp8
......@@ -508,6 +509,7 @@ __inline__ __device__ Tout convert(const Tin& x) {
}
#endif
assert(false);
__builtin_unreachable(); // Suppress missing return statement warning
}
template <typename Tout, typename Tin, Fp8KVCacheDataType kv_dt>
......@@ -520,6 +522,7 @@ __inline__ __device__ Tout scaled_convert(const Tin& x, const float scale) {
}
#endif
assert(false);
__builtin_unreachable(); // Suppress missing return statement warning
}
// The following macro is used to dispatch the conversion function based on
......
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
namespace marlin {
template <int const num_threads, int const num_bits>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
constexpr int pack_factor = 32 / num_bits;
int k_tiles = size_k / tile_k_size;
int n_tiles = size_n / tile_n_size;
int block_k_tiles = div_ceil(k_tiles, gridDim.x);
int start_k_tile = blockIdx.x * block_k_tiles;
if (start_k_tile >= k_tiles) {
return;
}
int finish_k_tile = min(start_k_tile + block_k_tiles, k_tiles);
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
// and can only issue the next fetch when it is guaranteed that the previous
// shared memory load is fully complete (as it may otherwise be
// overwritten).
cp_async_wait<repack_stages - 2>();
__syncthreads();
};
extern __shared__ int4 sh[];
constexpr int tile_n_ints = tile_n_size / pack_factor;
constexpr int stage_n_threads = tile_n_ints / 4;
constexpr int stage_k_threads = tile_k_size;
constexpr int stage_size = stage_k_threads * stage_n_threads;
auto fetch_to_shared = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
cp_async_fence();
return;
}
int first_n = n_tile_id * tile_n_size;
int first_n_packed = first_n / pack_factor;
int4* sh_ptr = sh + stage_size * pipe;
if (threadIdx.x < stage_size) {
int k_id = threadIdx.x / stage_n_threads;
int n_id = threadIdx.x % stage_n_threads;
int first_k = k_tile_id * tile_k_size;
cp_async4(&sh_ptr[k_id * stage_n_threads + n_id],
reinterpret_cast<int4 const*>(
&(b_q_weight_ptr[(first_k + k_id) * (size_n / pack_factor) +
first_n_packed + (n_id * 4)])));
}
cp_async_fence();
};
auto repack_tile = [&](int pipe, int k_tile_id, int n_tile_id) {
if (n_tile_id >= n_tiles) {
return;
}
int warp_id = threadIdx.x / 32;
int th_id = threadIdx.x % 32;
if (warp_id >= 4) {
return;
}
int tc_col = th_id / 4;
int tc_row = (th_id % 4) * 2;
constexpr int tc_offsets[4] = {0, 1, 8, 9};
int cur_n = warp_id * 16 + tc_col;
int cur_n_packed = cur_n / pack_factor;
int cur_n_pos = cur_n % pack_factor;
constexpr int sh_stride = tile_n_ints;
constexpr uint32_t mask = (1 << num_bits) - 1;
int4* sh_stage_ptr = sh + stage_size * pipe;
uint32_t* sh_stage_int_ptr = reinterpret_cast<uint32_t*>(sh_stage_ptr);
// Undo interleaving
int cur_n_pos_unpacked;
if constexpr (num_bits == 4) {
constexpr int undo_pack[8] = {0, 4, 1, 5, 2, 6, 3, 7};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
} else {
constexpr int undo_pack[4] = {0, 2, 1, 3};
cur_n_pos_unpacked = undo_pack[cur_n_pos];
}
uint32_t vals[8];
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int packed_src_0 = sh_stage_int_ptr[cur_n_packed + sh_stride * cur_elem];
int packed_src_1 = sh_stage_int_ptr[cur_n_packed + (8 / pack_factor) +
sh_stride * cur_elem];
vals[i] = (packed_src_0 >> (cur_n_pos_unpacked * num_bits)) & mask;
vals[4 + i] = (packed_src_1 >> (cur_n_pos_unpacked * num_bits)) & mask;
}
constexpr int tile_size = tile_k_size * tile_n_size / pack_factor;
int out_offset = (k_tile_id * n_tiles + n_tile_id) * tile_size;
// Result of:
// https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
if constexpr (num_bits == 4) {
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
out_ptr[out_offset + th_id * 4 + warp_id] = res;
} else {
constexpr int pack_idx[4] = {0, 2, 1, 3};
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
}
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 0] = res1;
out_ptr[out_offset + th_id * 8 + (warp_id * 2) + 1] = res2;
}
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
repack_tile(pipe, k_tile_id, n_tile_id + pipe);
wait_for_stage();
}
n_tile_id += repack_stages;
}
}
}
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int const pack_factor = 32 / num_bits;
// Verify B
TORCH_CHECK(b_q_weight.size(0) == size_k,
"b_q_weight.size(0) = ", b_q_weight.size(0),
" is not size_k = ", size_k);
TORCH_CHECK((size_n / pack_factor) == b_q_weight.size(1),
"Shape mismatch: b_q_weight.size(1) = ", b_q_weight.size(1),
", size_n = ", size_n, ", pack_factor = ", pack_factor);
// Verify device and strides
TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU");
TORCH_CHECK(b_q_weight.is_contiguous(), "b_q_weight is not contiguous");
TORCH_CHECK(b_q_weight.dtype() == at::kInt, "b_q_weight type is not kInt");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(b_q_weight));
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Get ptrs
uint32_t const* b_q_weight_ptr =
reinterpret_cast<uint32_t const*>(b_q_weight.data_ptr());
uint32_t* out_ptr = reinterpret_cast<uint32_t*>(out.data_ptr());
// Get dev info
int dev = b_q_weight.get_device();
cudaStream_t stream = at::cuda::getCurrentCUDAStream(dev);
int blocks;
cudaDeviceGetAttribute(&blocks, cudaDevAttrMultiProcessorCount, dev);
int max_shared_mem = 0;
cudaDeviceGetAttribute(&max_shared_mem,
cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
TORCH_CHECK(max_shared_mem > 0);
if (false) {
}
CALL_IF(4)
CALL_IF(8)
else {
TORCH_CHECK(false, "Unsupported repack config: num_bits = ", num_bits);
}
return out;
}
#endif
......@@ -19,8 +19,9 @@
* Adapted from https://github.com/IST-DASLab/marlin
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include "marlin.cuh"
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
......@@ -32,7 +33,7 @@ inline std::string str(T x) {
return std::to_string(x);
}
namespace gptq_marlin {
namespace marlin {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
......@@ -59,23 +60,27 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) {}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp) {
TORCH_CHECK_NOT_IMPLEMENTED(false,
"marlin_gemm(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
......@@ -264,6 +269,114 @@ dequant_8bit<nv_bfloat16>(int q) {
return frag_b;
}
// Zero-point dequantizers
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
uint32_t lo = prmt<start_byte_for_fp16, mask_for_elt_01>(q);
uint32_t hi = prmt<start_byte_for_fp16, mask_for_elt_23>(q);
static constexpr uint32_t I8s_TO_F16s_MAGIC_NUM = 0x64006400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
frag_b[1] = __hsub2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&I8s_TO_F16s_MAGIC_NUM));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
uint32_t* fp32_intermediates_casted =
reinterpret_cast<uint32_t*>(fp32_intermediates);
static constexpr uint32_t fp32_base = 0x4B000000;
fp32_intermediates_casted[0] = __byte_perm(q, fp32_base, 0x7650);
fp32_intermediates_casted[1] = __byte_perm(q, fp32_base, 0x7652);
fp32_intermediates_casted[2] = __byte_perm(q, fp32_base, 0x7651);
fp32_intermediates_casted[3] = __byte_perm(q, fp32_base, 0x7653);
fp32_intermediates[0] -= 8388608.f;
fp32_intermediates[1] -= 8388608.f;
fp32_intermediates[2] -= 8388608.f;
fp32_intermediates[3] -= 8388608.f;
uint32_t* bf16_result_ptr = reinterpret_cast<uint32_t*>(&frag_b);
bf16_result_ptr[0] = __byte_perm(fp32_intermediates_casted[0],
fp32_intermediates_casted[1], 0x7632);
bf16_result_ptr[1] = __byte_perm(fp32_intermediates_casted[2],
fp32_intermediates_casted[3], 0x7632);
return frag_b;
}
// Multiply dequantized values by the corresponding quantization scale; used
// only for grouped quantization.
template <typename scalar_t>
......@@ -277,6 +390,17 @@ __device__ inline void scale(typename ScalarType<scalar_t>::FragB& frag_b,
frag_b[1] = __hmul2(frag_b[1], s);
}
template <typename scalar_t>
__device__ inline void sub_zp(typename ScalarType<scalar_t>::FragB& frag_b,
typename ScalarType<scalar_t>::scalar_t2& frag_zp,
int i) {
using scalar_t2 = typename ScalarType<scalar_t>::scalar_t2;
scalar_t2 zp =
ScalarType<scalar_t>::num2num2(reinterpret_cast<scalar_t*>(&frag_zp)[i]);
frag_b[0] = __hsub2(frag_b[0], zp);
frag_b[1] = __hsub2(frag_b[1], zp);
}
// Same as above, but for act_order (each K is multiplied individually)
template <typename scalar_t>
__device__ inline void scale4(typename ScalarType<scalar_t>::FragB& frag_b,
......@@ -404,6 +528,7 @@ template <typename scalar_t, // compute dtype, half or nv_float16
const int stages, // number of stages for the async global->shared
// fetch pipeline
const bool has_act_order, // whether act_order is enabled
const bool has_zp, // whether zero-points are enabled
const int group_blocks = -1 // number of consecutive 16x16 blocks
// with a separate quantization scale
>
......@@ -411,14 +536,18 @@ __global__ void Marlin(
const int4* __restrict__ A, // fp16 input matrix of shape mxk
const int4* __restrict__ B, // 4bit quantized weight matrix of shape kxn
int4* __restrict__ C, // fp16 output buffer of shape mxn
int4* __restrict__ C_tmp, // fp32 tmp output buffer (for reduce)
const int4* __restrict__ scales_ptr, // fp16 quantization scales of shape
// (k/groupsize)xn
const int4* __restrict__ zp_ptr, // 4bit packed zero-points of shape
// (k/groupsize)x(n/pack_factor)
const int* __restrict__ g_idx, // int32 group indices of shape k
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks // extra global storage for barrier synchronization
int num_groups, // number of scale groups per output channel
int prob_m, // batch dimension m
int prob_n, // output dimension n
int prob_k, // reduction dimension k
int* locks, // extra global storage for barrier synchronization
bool use_fp32_reduce // whether to use fp32 global reduce
) {
// Each threadblock processes one "stripe" of the B matrix with (roughly) the
// same size, which might involve multiple column "slices" (of width 16 *
......@@ -437,6 +566,7 @@ __global__ void Marlin(
using FragB = typename ScalarType<scalar_t>::FragB;
using FragC = typename ScalarType<scalar_t>::FragC;
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
constexpr int pack_factor = 32 / num_bits;
......@@ -471,6 +601,8 @@ __global__ void Marlin(
int slice_idx; // index of threadblock in current slice; numbered bottom to
// top
int par_id = 0;
// We can easily implement parallel problem execution by just remapping
// indices and advancing global pointers
if (slice_col_par >= n_tiles) {
......@@ -478,6 +610,7 @@ __global__ void Marlin(
C += (slice_col_par / n_tiles) * 16 * thread_m_blocks * prob_n / 8;
locks += (slice_col_par / n_tiles) * n_tiles;
slice_col = slice_col_par % n_tiles;
par_id = slice_col_par / n_tiles;
}
// Compute all information about the current slice which is required for
......@@ -508,6 +641,7 @@ __global__ void Marlin(
C += 16 * thread_m_blocks * prob_n / 8;
locks += n_tiles;
slice_col = 0;
par_id++;
}
};
init_slice();
......@@ -566,6 +700,13 @@ __global__ void Marlin(
int tb_n_warps = thread_n_blocks / 4;
int act_s_col_tb_stride = act_s_col_warp_stride * tb_n_warps;
// Zero-points sizes/strides
int zp_gl_stride = (prob_n / pack_factor) / 4;
constexpr int zp_sh_stride = ((16 * thread_n_blocks) / pack_factor) / 4;
constexpr int zp_tb_groups = s_tb_groups;
constexpr int zp_sh_stage = has_zp ? zp_tb_groups * zp_sh_stride : 0;
int zp_gl_rd_delta = zp_gl_stride;
// Global A read index of current thread.
int a_gl_rd = a_gl_stride * (threadIdx.x / a_gl_rd_delta_o) +
(threadIdx.x % a_gl_rd_delta_o);
......@@ -605,6 +746,19 @@ __global__ void Marlin(
int s_sh_wr = threadIdx.x;
bool s_sh_wr_pred = threadIdx.x < s_sh_stride;
// Zero-points
int zp_gl_rd;
if constexpr (has_zp) {
if constexpr (group_blocks == -1) {
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
} else {
zp_gl_rd = zp_gl_stride * ((thread_k_blocks * slice_row) / group_blocks) +
zp_sh_stride * slice_col + threadIdx.x;
}
}
int zp_sh_wr = threadIdx.x;
bool zp_sh_wr_pred = threadIdx.x < zp_sh_stride;
// We use a different scale layout for grouped and column-wise quantization as
// we scale a `half2` tile in column-major layout in the former and in
// row-major in the latter case.
......@@ -616,6 +770,18 @@ __global__ void Marlin(
s_sh_rd = 8 * ((threadIdx.x / 32) % (thread_n_blocks / 4)) +
(threadIdx.x % 32) % 4;
// Zero-points have the same read layout as the scales
// (without column-wise case)
constexpr int num_col_threads = 8;
constexpr int num_row_threads = 4;
constexpr int num_ints_per_thread = 8 / pack_factor;
int zp_sh_rd;
if constexpr (has_zp) {
zp_sh_rd = num_ints_per_thread * num_col_threads *
((threadIdx.x / 32) % (thread_n_blocks / 4)) +
num_ints_per_thread * ((threadIdx.x % 32) / num_row_threads);
}
// Precompute which thread should not read memory in which iterations; this is
// needed if there are more threads than required for a certain tilesize or
// when the batchsize is not a multiple of 16.
......@@ -664,14 +830,17 @@ __global__ void Marlin(
int4* sh_a = sh;
int4* sh_b = sh_a + (stages * a_sh_stage);
int4* sh_g_idx = sh_b + (stages * b_sh_stage);
int4* sh_s = sh_g_idx + (stages * g_idx_stage);
int4* sh_zp = sh_g_idx + (stages * g_idx_stage);
int4* sh_s = sh_zp + (stages * zp_sh_stage);
// Register storage for double buffer of shared memory reads.
FragA frag_a[2][thread_m_blocks];
I4 frag_b_quant[2][b_thread_vecs];
FragC frag_c[thread_m_blocks][4][2];
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
FragS frag_s[2][4]; // No act-order
FragS act_frag_s[2][4][4]; // For act-order
int frag_qzp[2][num_ints_per_thread]; // Zero-points
FragZP frag_zp; // Zero-points in fp16
// Zero accumulators.
auto zero_accums = [&]() {
......@@ -777,6 +946,28 @@ __global__ void Marlin(
}
}
}
if constexpr (has_zp && group_blocks != -1) {
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
if constexpr (group_blocks >= thread_k_blocks) {
// Only fetch zero-points if this tile starts a new group
if (pipe % (group_blocks / thread_k_blocks) == 0) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
} else {
for (int i = 0; i < zp_tb_groups; i++) {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp_stage[i * zp_sh_stride + zp_sh_wr],
&zp_ptr[zp_gl_rd]);
}
zp_gl_rd += zp_gl_rd_delta;
}
}
}
}
}
// Insert a fence even when we are winding down the pipeline to ensure that
......@@ -784,6 +975,12 @@ __global__ void Marlin(
cp_async_fence();
};
auto fetch_zp_to_shared = [&]() {
if (zp_sh_wr_pred) {
cp_async4(&sh_zp[zp_sh_wr], &zp_ptr[zp_gl_rd]);
}
};
// Wait until the next thread tile has been loaded to shared memory.
auto wait_for_stage = [&]() {
// We only have `stages - 2` active fetches since we are double buffering
......@@ -932,8 +1129,82 @@ __global__ void Marlin(
}
};
auto fetch_zp_to_registers = [&](int k, int full_pipe) {
// This code does not handle group_blocks == 0,
// which signifies act_order.
// has_zp implies AWQ, which doesn't have act_order,
static_assert(!has_zp || group_blocks != 0);
if constexpr (has_zp) {
int pipe = full_pipe % stages;
if constexpr (group_blocks == -1) {
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] = (reinterpret_cast<int*>(sh_zp))[zp_sh_rd + i];
}
} else if constexpr (group_blocks >= thread_k_blocks) {
int4* sh_zp_stage =
sh_zp + zp_sh_stage * ((group_blocks / thread_k_blocks) *
(pipe / (group_blocks / thread_k_blocks)));
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
} else {
int warp_id = threadIdx.x / 32;
int n_warps = thread_n_blocks / 4;
int warp_row = warp_id / n_warps;
int cur_k = warp_row * 16;
cur_k += k_iter_size * (k % b_sh_wr_iters);
int k_blocks = cur_k / 16;
int cur_group_id = 0;
// Suppress bogus and persistent divide-by-zero warning
#pragma nv_diagnostic push
#pragma nv_diag_suppress divide_by_zero
cur_group_id = k_blocks / group_blocks;
#pragma nv_diagnostic pop
int4* sh_zp_stage = sh_zp + zp_sh_stage * pipe;
sh_zp_stage += cur_group_id * zp_sh_stride;
for (int i = 0; i < num_ints_per_thread; i++) {
frag_qzp[k % 2][i] =
(reinterpret_cast<int*>(sh_zp_stage))[zp_sh_rd + i];
}
}
}
};
// Execute the actual tensor core matmul of a sub-tile.
auto matmul = [&](int k) {
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
if constexpr (num_bits == 4) {
int zp_quant = frag_qzp[k % 2][0];
int zp_quant_shift = zp_quant >> 8;
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
} else {
int zp_quant_0 = frag_qzp[k % 2][0];
int zp_quant_1 = frag_qzp[k % 2][1];
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
}
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
frag_zp[3] = frag_zp_1[1];
}
// We have the m dimension as the inner loop in order to encourage overlapping
// dequantization and matmul operations.
#pragma unroll
......@@ -944,16 +1215,32 @@ __global__ void Marlin(
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
if constexpr (has_zp) {
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
} else {
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
}
} else {
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
if constexpr (has_zp) {
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
} else {
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
}
}
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
}
// Apply scale to frag_b0
......@@ -967,6 +1254,11 @@ __global__ void Marlin(
}
}
// Apply zero-point to frag_b1
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b1, frag_zp[j], 1);
}
// Apply scale to frag_b1
if constexpr (has_act_order) {
scale4<scalar_t>(frag_b1, act_frag_s[k % 2][0][j],
......@@ -1048,7 +1340,7 @@ __global__ void Marlin(
// finally have to globally reduce over the results. As the striped
// partitioning minimizes the number of such reductions and our outputs are
// usually rather small, we perform this reduction serially in L2 cache.
auto global_reduce = [&](bool first = false, bool last = false) {
auto global_reduce_fp16 = [&](bool first = false, bool last = false) {
// We are very careful here to reduce directly in the output buffer to
// maximize L2 cache utilization in this step. To do this, we write out
// results in FP16 (but still reduce with FP32 compute).
......@@ -1109,6 +1401,53 @@ __global__ void Marlin(
}
};
// Globally reduce over threadblocks that compute the same column block.
// We use a tmp C buffer to reduce in full fp32 precision.
auto global_reduce_fp32 = [&](bool first = false, bool last = false) {
constexpr int tb_m = thread_m_blocks * 16;
constexpr int tb_n = thread_n_blocks * 16;
constexpr int c_size = tb_m * tb_n * sizeof(float) / 16;
constexpr int active_threads = 32 * thread_n_blocks / 4;
bool is_th_active = threadIdx.x < active_threads;
int par_offset = c_size * n_tiles * par_id;
int slice_offset = c_size * slice_col;
constexpr int num_floats = thread_m_blocks * 4 * 2 * 4;
constexpr int th_size = num_floats * sizeof(float) / 16;
int c_cur_offset = par_offset + slice_offset;
if (!is_th_active) {
return;
}
if (!first) {
float* frag_c_ptr = reinterpret_cast<float*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
sh[threadIdx.x] =
C_tmp[c_cur_offset + active_threads * k + threadIdx.x];
float* sh_c_ptr = reinterpret_cast<float*>(&sh[threadIdx.x]);
#pragma unroll
for (int f = 0; f < 4; f++) {
frag_c_ptr[k * 4 + f] += sh_c_ptr[f];
}
}
}
if (!last) {
int4* frag_c_ptr = reinterpret_cast<int4*>(&frag_c);
#pragma unroll
for (int k = 0; k < th_size; k++) {
C_tmp[c_cur_offset + active_threads * k + threadIdx.x] = frag_c_ptr[k];
}
}
};
// Write out the reduce final result in the correct layout. We only actually
// reshuffle matrix fragments in this step, the reduction above is performed
// in fragment layout.
......@@ -1189,6 +1528,12 @@ __global__ void Marlin(
}
fetch_scales_to_shared(true, g_idx[slice_k_start], g_idx[last_g_idx]);
}
if constexpr (has_zp && group_blocks == -1) {
if (i == 0) {
fetch_zp_to_shared();
}
}
fetch_to_shared(i, i, i < slice_iters);
}
......@@ -1197,6 +1542,7 @@ __global__ void Marlin(
init_same_group(0);
fetch_to_registers(0, 0);
fetch_scales_to_registers(0, 0);
fetch_zp_to_registers(0, 0);
a_gl_rd += a_gl_rd_delta_o * (stages - 1);
slice_k_start_shared_fetch += tb_k * (stages - 1);
};
......@@ -1217,6 +1563,7 @@ __global__ void Marlin(
for (int k = 0; k < b_sh_wr_iters; k++) {
fetch_to_registers(k + 1, pipe % stages);
fetch_scales_to_registers(k + 1, pipe);
fetch_zp_to_registers(k + 1, pipe);
if (k == b_sh_wr_iters - 2) {
fetch_to_shared((pipe + stages - 1) % stages, pipe,
slice_iters >= stages);
......@@ -1325,7 +1672,11 @@ __global__ void Marlin(
if (slice_count > 1) { // only globally reduce if there is more than one
// block in a slice
barrier_acquire(&locks[slice_col], slice_idx);
global_reduce(slice_idx == 0, last);
if (use_fp32_reduce) {
global_reduce_fp32(slice_idx == 0, last);
} else {
global_reduce_fp16(slice_idx == 0, last);
}
barrier_release(&locks[slice_col], last);
}
if (last) // only the last block in a slice actually writes the result
......@@ -1354,6 +1705,7 @@ __global__ void Marlin(
} else {
s_gl_rd = s_sh_stride * slice_col + threadIdx.x;
zp_gl_rd = zp_sh_stride * slice_col + threadIdx.x;
}
start_pipes();
......@@ -1363,22 +1715,24 @@ __global__ void Marlin(
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, GROUP_BLOCKS, NUM_THREADS) \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && group_blocks == GROUP_BLOCKS && \
num_threads == NUM_THREADS) { \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS>, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
GROUP_BLOCKS><<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, s_ptr, g_idx_ptr, num_groups, prob_m, prob_n, \
prob_k, locks); \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
A_ptr, B_ptr, C_ptr, C_tmp_ptr, s_ptr, zp_ptr, g_idx_ptr, \
num_groups, prob_m, prob_n, prob_k, locks, use_fp32_reduce); \
}
typedef struct {
......@@ -1517,6 +1871,27 @@ bool is_valid_config(thread_config_t const& th_config, int max_m_blocks,
return true;
}
int determine_reduce_max_m(int prob_m, int max_par) {
constexpr int tile_m_size = 16;
if (prob_m <= tile_m_size) {
return tile_m_size;
} else if (prob_m <= tile_m_size * 2) {
return tile_m_size * 2;
} else if (prob_m <= tile_m_size * 3) {
return tile_m_size * 3;
} else if (prob_m <= tile_m_size * 4) {
return tile_m_size * 4;
} else {
int cur_par = min(div_ceil(prob_m, tile_m_size * 4), max_par);
return tile_m_size * 4 * cur_par;
}
}
exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
int num_bits, int group_size,
bool has_act_order, bool is_k_full,
......@@ -1548,44 +1923,77 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
#define CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, 8, NUM_THREADS)
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace, int num_bits,
bool has_act_order, bool is_k_full, int num_groups,
int group_size, int dev, cudaStream_t stream, int thread_k,
int thread_n, int sms, int max_par) {
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
void* zp, void* g_idx, void* perm, void* a_tmp, int prob_m,
int prob_n, int prob_k, void* workspace,
vllm::ScalarType const& q_type, bool has_act_order,
bool is_k_full, bool has_zp, int num_groups, int group_size,
int dev, cudaStream_t stream, int thread_k, int thread_n,
int sms, int max_par, bool use_fp32_reduce) {
if (has_zp) {
TORCH_CHECK(
q_type == vllm::kU4 || q_type == vllm::kU8,
"q_type must be u4 or u8 when has_zp = True. Got = ", q_type.str());
} else {
TORCH_CHECK(
q_type == vllm::kU4B8 || q_type == vllm::kU8B128,
"q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
q_type.str());
}
TORCH_CHECK(prob_m > 0 && prob_n > 0 && prob_k > 0, "Invalid MNK = [", prob_m,
", ", prob_n, ", ", prob_k, "]");
// TODO: remove alias when we start supporting other 8bit types
int num_bits = q_type.size_bits();
int tot_m = prob_m;
int tot_m_blocks = div_ceil(tot_m, 16);
int pad = 16 * tot_m_blocks - tot_m;
......@@ -1664,7 +2072,9 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
const int4* A_ptr = (const int4*)A;
const int4* B_ptr = (const int4*)B;
int4* C_ptr = (int4*)C;
int4* C_tmp_ptr = (int4*)C_tmp;
const int4* s_ptr = (const int4*)s;
const int4* zp_ptr = (const int4*)zp;
const int* g_idx_ptr = (const int*)g_idx;
const int* perm_ptr = (const int*)perm;
int4* a_tmp_ptr = (int4*)a_tmp;
......@@ -1701,28 +2111,33 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
if (false) {
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
GPTQ_CALL_IF(4, 16, 4, 256)
GPTQ_CALL_IF(4, 8, 8, 256)
GPTQ_CALL_IF(4, 8, 4, 128)
GPTQ_CALL_IF(4, 4, 8, 128)
GPTQ_CALL_IF(8, 16, 4, 256)
GPTQ_CALL_IF(8, 8, 8, 256)
GPTQ_CALL_IF(8, 8, 4, 128)
GPTQ_CALL_IF(8, 4, 8, 128)
AWQ_CALL_IF(4, 16, 4, 256)
AWQ_CALL_IF(4, 8, 8, 256)
AWQ_CALL_IF(4, 8, 4, 128)
AWQ_CALL_IF(4, 4, 8, 128)
AWQ_CALL_IF(8, 16, 4, 256)
AWQ_CALL_IF(8, 8, 8, 256)
AWQ_CALL_IF(8, 8, 4, 128)
AWQ_CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
", num_groups = ", num_groups, ", group_size = ", group_size,
", thread_m_blocks = ", thread_m_blocks,
", thread_n_blocks = ", thread_n_blocks,
", thread_k_blocks = ", thread_k_blocks,
", num_bits = ", num_bits);
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
......@@ -1730,17 +2145,28 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
}
}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& b_scales, torch::Tensor& g_idx,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
int pack_factor = 32 / num_bits;
torch::Tensor& b_scales, torch::Tensor& b_zeros,
torch::Tensor& g_idx, torch::Tensor& perm,
torch::Tensor& workspace,
vllm::ScalarTypeTorchPtr const& b_q_type,
int64_t size_m, int64_t size_n, int64_t size_k,
bool is_k_full, bool has_zp,
bool use_fp32_reduce) {
if (has_zp) {
TORCH_CHECK(*b_q_type == vllm::kU4 || *b_q_type == vllm::kU8,
"b_q_type must be u4 or u8 when has_zp = True. Got = ",
b_q_type->str());
} else {
TORCH_CHECK(
*b_q_type == vllm::kU4B8 || *b_q_type == vllm::kU8B128,
"b_q_type must be uint4b8 or uint8b128 when has_zp = False. Got = ",
b_q_type->str());
}
int pack_factor = 32 / b_q_type->size_bits();
// Verify A
TORCH_CHECK(a.size(0) == size_m, "Shape mismatch: a.size(0) = ", a.size(0),
......@@ -1749,16 +2175,15 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
", size_k = ", size_k);
// Verify B
TORCH_CHECK(size_k % gptq_marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK((size_k / gptq_marlin::tile_size) == b_q_weight.size(0),
TORCH_CHECK(size_k % marlin::tile_size == 0, "size_k = ", size_k,
" is not divisible by tile_size = ", marlin::tile_size);
TORCH_CHECK((size_k / marlin::tile_size) == b_q_weight.size(0),
"Shape mismatch: b_q_weight.size(0) = ", b_q_weight.size(0),
", size_k = ", size_k, ", tile_size = ", gptq_marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % gptq_marlin::tile_size == 0,
", size_k = ", size_k, ", tile_size = ", marlin::tile_size);
TORCH_CHECK(b_q_weight.size(1) % marlin::tile_size == 0,
"b_q_weight.size(1) = ", b_q_weight.size(1),
" is not divisible by tile_size = ", gptq_marlin::tile_size);
int actual_size_n =
(b_q_weight.size(1) / gptq_marlin::tile_size) * pack_factor;
" is not divisible by tile_size = ", marlin::tile_size);
int actual_size_n = (b_q_weight.size(1) / marlin::tile_size) * pack_factor;
TORCH_CHECK(size_n == actual_size_n, "size_n = ", size_n,
", actual_size_n = ", actual_size_n);
......@@ -1772,6 +2197,9 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU");
TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous");
TORCH_CHECK(b_zeros.device().is_cuda(), "b_zeros is not on GPU");
TORCH_CHECK(b_zeros.is_contiguous(), "b_zeros is not contiguous");
TORCH_CHECK(g_idx.device().is_cuda(), "g_idx is not on GPU");
TORCH_CHECK(g_idx.is_contiguous(), "g_idx is not contiguous");
......@@ -1784,6 +2212,17 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
// Alloc C tmp buffer that is going to be used for the global reduce
int reduce_max_m = marlin::determine_reduce_max_m(size_m, marlin::max_par);
int reduce_n = size_n;
auto options_fp32 =
torch::TensorOptions().dtype(at::kFloat).device(a.device());
if (!use_fp32_reduce) {
reduce_max_m = 0;
reduce_n = 0;
}
torch::Tensor c_tmp = torch::empty({reduce_max_m, reduce_n}, options_fp32);
// thread_k: `k` size of a thread_tile in `weights` (can usually be left as
// auto -1)
int thread_k = -1;
......@@ -1805,8 +2244,8 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
int group_size = -1;
bool has_act_order = g_idx.size(0) != 0;
int b_rank = b_scales.sizes().size();
TORCH_CHECK(b_rank == 2, "b_scales rank = ", b_rank, " is not 2");
int rank = b_scales.sizes().size();
TORCH_CHECK(rank == 2, "b_scales rank = ", rank, " is not 2");
TORCH_CHECK(b_scales.size(1) == size_n, "b_scales dim 1 = ", b_scales.size(1),
" is not size_n = ", size_n);
num_groups = b_scales.size(0);
......@@ -1832,34 +2271,45 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
}
// Verify b_zeros
if (has_zp) {
int rank = b_zeros.sizes().size();
TORCH_CHECK(rank == 2, "b_zeros rank = ", rank, " is not 2");
TORCH_CHECK(b_zeros.size(0) == num_groups,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
// Verify workspace size
TORCH_CHECK(
size_n % gptq_marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", gptq_marlin::min_thread_n);
int min_workspace_size =
(size_n / gptq_marlin::min_thread_n) * gptq_marlin::max_par;
TORCH_CHECK(size_n % marlin::min_thread_n == 0, "size_n = ", size_n,
", is not divisible by min_thread_n = ", marlin::min_thread_n);
int min_workspace_size = (size_n / marlin::min_thread_n) * marlin::max_par;
TORCH_CHECK(workspace.numel() >= min_workspace_size,
"workspace.numel = ", workspace.numel(),
" is below min_workspace_size = ", min_workspace_size);
int dev = a.get_device();
if (a.scalar_type() == at::ScalarType::Half) {
gptq_marlin::marlin_mm_f16i4<half>(
marlin::marlin_mm<half>(
a.data_ptr<at::Half>(), b_q_weight.data_ptr(), c.data_ptr<at::Half>(),
b_scales.data_ptr<at::Half>(), g_idx.data_ptr(), perm.data_ptr(),
c_tmp.data_ptr<float>(), b_scales.data_ptr<at::Half>(),
b_zeros.data_ptr(), g_idx.data_ptr(), perm.data_ptr(),
a_tmp.data_ptr<at::Half>(), size_m, size_n, size_k,
workspace.data_ptr(), num_bits, has_act_order, is_k_full, num_groups,
group_size, dev, at::cuda::getCurrentCUDAStream(dev), thread_k,
thread_n, sms, gptq_marlin::max_par);
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else if (a.scalar_type() == at::ScalarType::BFloat16) {
gptq_marlin::marlin_mm_f16i4<nv_bfloat16>(
marlin::marlin_mm<nv_bfloat16>(
a.data_ptr<at::BFloat16>(), b_q_weight.data_ptr(),
c.data_ptr<at::BFloat16>(), b_scales.data_ptr<at::BFloat16>(),
g_idx.data_ptr(), perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(),
size_m, size_n, size_k, workspace.data_ptr(), num_bits, has_act_order,
is_k_full, num_groups, group_size, dev,
at::cuda::getCurrentCUDAStream(dev), thread_k, thread_n, sms,
gptq_marlin::max_par);
c.data_ptr<at::BFloat16>(), c_tmp.data_ptr<float>(),
b_scales.data_ptr<at::BFloat16>(), b_zeros.data_ptr(), g_idx.data_ptr(),
perm.data_ptr(), a_tmp.data_ptr<at::BFloat16>(), size_m, size_n, size_k,
workspace.data_ptr(), *b_q_type, has_act_order, is_k_full, has_zp,
num_groups, group_size, dev, at::cuda::getCurrentCUDAStream(dev),
thread_k, thread_n, sms, marlin::max_par, use_fp32_reduce);
} else {
TORCH_CHECK(false, "gpt_marlin_gemm only supports bfloat16 and float16");
}
......
#include "gptq_marlin.cuh"
namespace gptq_marlin {
static constexpr int repack_stages = 8;
static constexpr int repack_threads = 256;
static constexpr int tile_k_size = tile_size;
static constexpr int tile_n_size = tile_k_size * 4;
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace gptq_marlin
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
......@@ -29,8 +22,10 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
#else
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void marlin_repack_kernel(
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {
......@@ -259,28 +254,28 @@ __global__ void marlin_repack_kernel(
}
}
} // namespace gptq_marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, \
NUM_BITS, HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
gptq_marlin::marlin_repack_kernel<gptq_marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, gptq_marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
// Verify compatibility with marlin tile of 16x64
TORCH_CHECK(size_k % gptq_marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", gptq_marlin::tile_k_size);
TORCH_CHECK(size_n % gptq_marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", gptq_marlin::tile_n_size);
TORCH_CHECK(size_k % marlin::tile_k_size == 0, "size_k = ", size_k,
" is not divisible by tile_k_size = ", marlin::tile_k_size);
TORCH_CHECK(size_n % marlin::tile_n_size == 0, "size_n = ", size_n,
" is not divisible by tile_n_size = ", marlin::tile_n_size);
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
......@@ -308,10 +303,9 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
torch::Tensor out =
torch::empty({size_k / gptq_marlin::tile_size,
size_n * gptq_marlin::tile_size / pack_factor},
options);
torch::Tensor out = torch::empty(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
// Detect if there is act_order
bool has_perm = perm.size(0) != 0;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment