Commit ad385667 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.6.3.post1-dev'

parents be0967c1 903593d3
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void awq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel(
}
uint32_t vals[8];
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
......@@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS) \
else if (num_bits == NUM_BITS) { \
cudaFuncSetAttribute( \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::awq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
int64_t size_n, int64_t num_bits) {
......@@ -266,4 +247,22 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k,
return out;
}
#endif
torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight,
c10::SymInt size_k, c10::SymInt size_n,
int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("awq_marlin_repack", &awq_marlin_repack_meta);
}
\ No newline at end of file
......@@ -23,6 +23,8 @@
#include "marlin_dtypes.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
......@@ -42,8 +44,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
int4* __restrict__ out_int4_ptr, int size_m,
int size_k, int block_rows) {}
template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -151,20 +153,21 @@ __device__ inline uint32_t prmt(uint32_t a) {
return res;
}
// Efficiently dequantize an int32 value into a full B-fragment of 4 fp16
// values. We mostly follow the strategy in the link below, with some small
// changes:
template <typename scalar_t, vllm::ScalarTypeId w_type_id>
__device__ inline typename ScalarType<scalar_t>::FragB dequant(int q);
//
// Efficiently dequantize 4bit values packed in an int32 value into a full
// B-fragment of 4 fp16 values. We mostly follow the strategy in the link below,
// with some small changes:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L215-L287
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L327-L385
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
__device__ inline typename ScalarType<half>::FragB
dequant<half, vllm::kU4B8.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
......@@ -187,7 +190,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_4bit<half>(int q) {
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit<nv_bfloat16>(int q) {
dequant<nv_bfloat16, vllm::kU4B8.id()>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
......@@ -210,19 +213,64 @@ dequant_4bit<nv_bfloat16>(int q) {
return frag_b;
}
template <>
__device__ inline typename ScalarType<half>::FragB
dequant<half, vllm::kU4.id()>(int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant<nv_bfloat16, vllm::kU4.id()>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
//
// Fast Int8ToFp16/Int8ToBf16: Efficiently dequantize 8bit int values to fp16 or
// bf16 Reference:
// - FP16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L53-L85
// - BF16:
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h#L125-L175
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
//
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
__device__ inline typename ScalarType<half>::FragB
dequant<half, vllm::kU8B128.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
......@@ -242,7 +290,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit<half>(int q) {
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit<nv_bfloat16>(int q) {
dequant<nv_bfloat16, vllm::kU8B128.id()>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
......@@ -269,68 +317,9 @@ dequant_8bit<nv_bfloat16>(int q) {
return frag_b;
}
// Zero-point dequantizers
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_4bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_4bit_zp<half>(
int q) {
const int LO = 0x000f000f;
const int HI = 0x00f000f0;
const int EX = 0x64006400;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, LO, EX);
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, HI, EX);
const int SUB = 0x64006400;
const int MUL = 0x2c002c00;
const int ADD = 0xd400d400;
typename ScalarType<half>::FragB frag_b;
frag_b[0] = __hsub2(*reinterpret_cast<half2*>(&lo),
*reinterpret_cast<const half2*>(&SUB));
frag_b[1] = __hfma2(*reinterpret_cast<half2*>(&hi),
*reinterpret_cast<const half2*>(&MUL),
*reinterpret_cast<const half2*>(&ADD));
return frag_b;
}
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_4bit_zp<nv_bfloat16>(int q) {
static constexpr uint32_t MASK = 0x000f000f;
static constexpr uint32_t EX = 0x43004300;
// Guarantee that the `(a & b) | c` operations are LOP3s.
int lo = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
q >>= 4;
int hi = lop3<(0xf0 & 0xcc) | 0xaa>(q, MASK, EX);
typename ScalarType<nv_bfloat16>::FragB frag_b;
static constexpr uint32_t MUL = 0x3F803F80;
static constexpr uint32_t ADD = 0xC300C300;
frag_b[0] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&lo),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
frag_b[1] = __hfma2(*reinterpret_cast<nv_bfloat162*>(&hi),
*reinterpret_cast<const nv_bfloat162*>(&MUL),
*reinterpret_cast<const nv_bfloat162*>(&ADD));
return frag_b;
}
template <typename scalar_t>
__device__ inline typename ScalarType<scalar_t>::FragB dequant_8bit_zp(int q) {
STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t);
}
template <>
__device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
int q) {
__device__ inline typename ScalarType<half>::FragB
dequant<half, vllm::kU8.id()>(int q) {
static constexpr uint32_t mask_for_elt_01 = 0x5250;
static constexpr uint32_t mask_for_elt_23 = 0x5351;
static constexpr uint32_t start_byte_for_fp16 = 0x64646464;
......@@ -350,7 +339,7 @@ __device__ inline typename ScalarType<half>::FragB dequant_8bit_zp<half>(
template <>
__device__ inline typename ScalarType<nv_bfloat16>::FragB
dequant_8bit_zp<nv_bfloat16>(int q) {
dequant<nv_bfloat16, vllm::kU8.id()>(int q) {
typename ScalarType<nv_bfloat16>::FragB frag_b;
float fp32_intermediates[4];
......@@ -517,8 +506,8 @@ __global__ void permute_cols_kernel(int4 const* __restrict__ a_int4_ptr,
}
}
template <typename scalar_t, // compute dtype, half or nv_float16
const int num_bits, // number of bits used for weights
template <typename scalar_t, // compute dtype, half or nv_float16
const vllm::ScalarTypeId w_type_id, // weight ScalarType id
const int threads, // number of threads in a threadblock
const int thread_m_blocks, // number of 16x16 blocks in the m
// dimension (batchsize) of the
......@@ -568,7 +557,9 @@ __global__ void Marlin(
using FragS = typename ScalarType<scalar_t>::FragS;
using FragZP = typename ScalarType<scalar_t>::FragZP;
constexpr int pack_factor = 32 / num_bits;
static constexpr auto w_type = vllm::ScalarType::from_id(w_type_id);
constexpr int pack_factor = 32 / w_type.size_bits();
// For larger GEMMs we run multiple batchsize 64 versions in parallel for a
// better partitioning with less reductions
......@@ -670,7 +661,7 @@ __global__ void Marlin(
// B sizes/strides
int b_gl_stride = 16 * prob_n / (pack_factor * 4);
constexpr int b_sh_stride = ((thread_n_blocks * 16) * 16 / pack_factor) / 4;
constexpr int b_thread_vecs = num_bits == 4 ? 1 : 2;
constexpr int b_thread_vecs = w_type.size_bits() == 4 ? 1 : 2;
constexpr int b_sh_stride_threads = b_sh_stride / b_thread_vecs;
int b_gl_rd_delta_o = b_gl_stride * thread_k_blocks;
......@@ -1186,19 +1177,20 @@ __global__ void Marlin(
if constexpr (has_zp) {
FragB frag_zp_0;
FragB frag_zp_1;
if constexpr (num_bits == 4) {
int zp_quant = frag_qzp[k % 2][0];
int zp_quant_shift = zp_quant >> 8;
frag_zp_0 = dequant_4bit_zp<scalar_t>(zp_quant);
frag_zp_1 = dequant_4bit_zp<scalar_t>(zp_quant_shift);
int zp_quant_0, zp_quant_1;
if constexpr (w_type.size_bits() == 4) {
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = zp_quant_0 >> 8;
} else {
int zp_quant_0 = frag_qzp[k % 2][0];
int zp_quant_1 = frag_qzp[k % 2][1];
frag_zp_0 = dequant_8bit_zp<scalar_t>(zp_quant_0);
frag_zp_1 = dequant_8bit_zp<scalar_t>(zp_quant_1);
static_assert(w_type.size_bits() == 8);
zp_quant_0 = frag_qzp[k % 2][0];
zp_quant_1 = frag_qzp[k % 2][1];
}
frag_zp_0 = dequant<scalar_t, w_type_id>(zp_quant_0);
frag_zp_1 = dequant<scalar_t, w_type_id>(zp_quant_1);
frag_zp[0] = frag_zp_0[0];
frag_zp[1] = frag_zp_0[1];
frag_zp[2] = frag_zp_1[0];
......@@ -1211,33 +1203,21 @@ __global__ void Marlin(
for (int j = 0; j < 4; j++) {
FragB frag_b0;
FragB frag_b1;
if constexpr (num_bits == 4) {
int b_quant = frag_b_quant[k % 2][0][j];
int b_quant_shift = b_quant >> 8;
if constexpr (has_zp) {
frag_b0 = dequant_4bit_zp<scalar_t>(b_quant);
frag_b1 = dequant_4bit_zp<scalar_t>(b_quant_shift);
} else {
frag_b0 = dequant_4bit<scalar_t>(b_quant);
frag_b1 = dequant_4bit<scalar_t>(b_quant_shift);
}
int b_quant_0, b_quant_1;
if constexpr (w_type.size_bits() == 4) {
b_quant_0 = frag_b_quant[k % 2][0][j];
b_quant_1 = b_quant_0 >> 8;
} else {
static_assert(w_type.size_bits() == 8);
int* frag_b_quant_ptr = reinterpret_cast<int*>(frag_b_quant[k % 2]);
int b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
int b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
if constexpr (has_zp) {
frag_b0 = dequant_8bit_zp<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit_zp<scalar_t>(b_quant_1);
} else {
frag_b0 = dequant_8bit<scalar_t>(b_quant_0);
frag_b1 = dequant_8bit<scalar_t>(b_quant_1);
}
b_quant_0 = frag_b_quant_ptr[j * 2 + 0];
b_quant_1 = frag_b_quant_ptr[j * 2 + 1];
}
frag_b0 = dequant<scalar_t, w_type_id>(b_quant_0);
frag_b1 = dequant<scalar_t, w_type_id>(b_quant_1);
// Apply zero-point to frag_b0
if constexpr (has_zp) {
sub_zp<scalar_t>(frag_b0, frag_zp[j], 0);
......@@ -1477,7 +1457,8 @@ __global__ void Marlin(
// For per-column quantization we finally apply the scale here (only for
// 4-bit)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 4) {
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 4) {
res = __hmul2(res, s[0]);
}
......@@ -1605,7 +1586,7 @@ __global__ void Marlin(
// For per-column scales, we only fetch them here in the final step before
// write-out
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (num_bits == 8) {
if constexpr (w_type.size_bits() == 8) {
if (s_sh_wr_pred) {
cp_async4(&sh_s[s_sh_wr], &scales_ptr[s_gl_rd]);
}
......@@ -1622,7 +1603,7 @@ __global__ void Marlin(
thread_block_reduce();
if constexpr (!has_act_order && group_blocks == -1) {
if constexpr (num_bits == 8) {
if constexpr (w_type.size_bits() == 8) {
cp_async_wait<0>();
__syncthreads();
if (threadIdx.x / 32 < thread_n_blocks / 4) {
......@@ -1645,7 +1626,8 @@ __global__ void Marlin(
// For 8-bit channelwise, we apply the scale before the global reduction
// that converts the fp32 results to fp16 (so that we avoid possible
// overflow in fp16)
if constexpr (!has_act_order && group_blocks == -1 && num_bits == 8) {
if constexpr (!has_act_order && group_blocks == -1 &&
w_type.size_bits() == 8) {
if (threadIdx.x / 32 < thread_n_blocks / 4) {
#pragma unroll
for (int i = 0; i < thread_m_blocks; i++) {
......@@ -1714,20 +1696,19 @@ __global__ void Marlin(
}
}
#define __CALL_IF(NUM_BITS, THREAD_M_BLOCKS, THREAD_N_BLOCKS, \
THREAD_K_BLOCKS, HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, \
NUM_THREADS) \
else if (num_bits == NUM_BITS && thread_m_blocks == THREAD_M_BLOCKS && \
#define __CALL_IF(W_TYPE, THREAD_M_BLOCKS, THREAD_N_BLOCKS, THREAD_K_BLOCKS, \
HAS_ACT_ORDER, HAS_ZP, GROUP_BLOCKS, NUM_THREADS) \
else if (q_type == W_TYPE && thread_m_blocks == THREAD_M_BLOCKS && \
thread_n_blocks == THREAD_N_BLOCKS && \
thread_k_blocks == THREAD_K_BLOCKS && \
has_act_order == HAS_ACT_ORDER && has_zp == HAS_ZP && \
group_blocks == GROUP_BLOCKS && num_threads == NUM_THREADS) { \
cudaFuncSetAttribute( \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
Marlin<scalar_t, NUM_BITS, NUM_THREADS, THREAD_M_BLOCKS, \
Marlin<scalar_t, W_TYPE.id(), NUM_THREADS, THREAD_M_BLOCKS, \
THREAD_N_BLOCKS, THREAD_K_BLOCKS, pipe_stages, HAS_ACT_ORDER, \
HAS_ZP, GROUP_BLOCKS> \
<<<blocks, NUM_THREADS, max_shared_mem, stream>>>( \
......@@ -1923,52 +1904,52 @@ exec_config_t determine_thread_config(int prob_m, int prob_n, int prob_k,
return exec_config_t{0, {-1, -1, -1}};
}
#define GPTQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(NUM_BITS, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(NUM_BITS, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
#define GPTQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, true, false, 0, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, false, 8, NUM_THREADS)
#define AWQ_CALL_IF(W_TYPE, N_BLOCKS, K_BLOCKS, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 1, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 2, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 3, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS) \
\
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, -1, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 2, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 4, NUM_THREADS) \
__CALL_IF(W_TYPE, 4, N_BLOCKS, K_BLOCKS, false, true, 8, NUM_THREADS)
template <typename scalar_t>
void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
......@@ -2113,23 +2094,23 @@ void marlin_mm(const void* A, const void* B, void* C, void* C_tmp, void* s,
if (false) {
}
GPTQ_CALL_IF(4, 16, 4, 256)
GPTQ_CALL_IF(4, 8, 8, 256)
GPTQ_CALL_IF(4, 8, 4, 128)
GPTQ_CALL_IF(4, 4, 8, 128)
GPTQ_CALL_IF(8, 16, 4, 256)
GPTQ_CALL_IF(8, 8, 8, 256)
GPTQ_CALL_IF(8, 8, 4, 128)
GPTQ_CALL_IF(8, 4, 8, 128)
AWQ_CALL_IF(4, 16, 4, 256)
AWQ_CALL_IF(4, 8, 8, 256)
AWQ_CALL_IF(4, 8, 4, 128)
AWQ_CALL_IF(4, 4, 8, 128)
AWQ_CALL_IF(8, 16, 4, 256)
AWQ_CALL_IF(8, 8, 8, 256)
AWQ_CALL_IF(8, 8, 4, 128)
AWQ_CALL_IF(8, 4, 8, 128)
GPTQ_CALL_IF(vllm::kU4B8, 16, 4, 256)
GPTQ_CALL_IF(vllm::kU4B8, 8, 8, 256)
GPTQ_CALL_IF(vllm::kU4B8, 8, 4, 128)
GPTQ_CALL_IF(vllm::kU4B8, 4, 8, 128)
GPTQ_CALL_IF(vllm::kU8B128, 16, 4, 256)
GPTQ_CALL_IF(vllm::kU8B128, 8, 8, 256)
GPTQ_CALL_IF(vllm::kU8B128, 8, 4, 128)
GPTQ_CALL_IF(vllm::kU8B128, 4, 8, 128)
AWQ_CALL_IF(vllm::kU4, 16, 4, 256)
AWQ_CALL_IF(vllm::kU4, 8, 8, 256)
AWQ_CALL_IF(vllm::kU4, 8, 4, 128)
AWQ_CALL_IF(vllm::kU4, 4, 8, 128)
AWQ_CALL_IF(vllm::kU8, 16, 4, 256)
AWQ_CALL_IF(vllm::kU8, 8, 8, 256)
AWQ_CALL_IF(vllm::kU8, 8, 4, 128)
AWQ_CALL_IF(vllm::kU8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [", prob_m, ", ", prob_n,
", ", prob_k, "]", ", has_act_order = ", has_act_order,
......@@ -2279,7 +2260,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
"b_zeros dim 0 = ", b_zeros.size(0),
" is not num_groups = ", num_groups);
TORCH_CHECK(b_zeros.size(1) == size_n / pack_factor,
"b_zeros dim 1 = ", b_scales.size(1),
"b_zeros dim 1 = ", b_zeros.size(1),
" is not size_n / pack_factor = ", size_n / pack_factor);
}
......@@ -2318,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
}
#endif
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_gemm", &gptq_marlin_gemm);
}
\ No newline at end of file
#include "marlin.cuh"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
namespace marlin {
template <int const num_threads, int const num_bits, bool const has_perm>
__global__ void gptq_marlin_repack_kernel(
uint32_t const* __restrict__ b_q_weight_ptr,
uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr,
int size_k, int size_n) {}
} // namespace marlin
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
int64_t num_bits) {
TORCH_CHECK_NOT_IMPLEMENTED(
false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0");
return torch::empty({1, 1});
}
#else
#include "core/registration.h"
namespace marlin {
......@@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t b1_vals[tile_ints];
uint32_t b2_vals[tile_ints];
#pragma unroll
#pragma unroll
for (int i = 0; i < tile_ints; i++) {
b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i];
b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i];
}
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
int cur_elem = tc_row + tc_offsets[i];
int cur_int = cur_elem / pack_factor;
......@@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel(
constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7};
uint32_t res = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 8; i++) {
res |= vals[pack_idx[i]] << (i * 4);
}
......@@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel(
uint32_t res1 = 0;
uint32_t res2 = 0;
#pragma unroll
#pragma unroll
for (int i = 0; i < 4; i++) {
res1 |= vals[pack_idx[i]] << (i * 8);
res2 |= vals[4 + pack_idx[i]] << (i * 8);
......@@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel(
};
auto start_pipes = [&](int k_tile_id, int n_tile_id) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages - 1; pipe++) {
fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe);
}
wait_for_stage();
};
#pragma unroll
#pragma unroll
for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) {
int n_tile_id = 0;
......@@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel(
start_pipes(k_tile_id, n_tile_id);
while (n_tile_id < n_tiles) {
#pragma unroll
#pragma unroll
for (int pipe = 0; pipe < repack_stages; pipe++) {
fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id,
n_tile_id + pipe + repack_stages - 1);
......@@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel(
} // namespace marlin
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
#define CALL_IF(NUM_BITS, HAS_PERM) \
else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \
cudaFuncSetAttribute( \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \
marlin::gptq_marlin_repack_kernel<marlin::repack_threads, NUM_BITS, \
HAS_PERM> \
<<<blocks, marlin::repack_threads, max_shared_mem, stream>>>( \
b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \
}
torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t size_k, int64_t size_n,
......@@ -341,4 +321,22 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
return out;
}
#endif
torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight,
torch::Tensor& perm, c10::SymInt size_k,
c10::SymInt size_n, int64_t num_bits) {
int const pack_factor = 32 / num_bits;
auto options = torch::TensorOptions()
.dtype(b_q_weight.dtype())
.device(b_q_weight.device());
return torch::empty_symint(
{size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor},
options);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack);
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) {
m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta);
}
\ No newline at end of file
# Machete (Mixed Precision Cutlass-Based GEMM)
Machete is a spiritual successor to the Marlin kernel but optimized for Hopper architectures and based on Cutlass. Being based on Cutlass, new type pairs and epilogues are easier to add compared to Marlin.
## Overview
Machete effectively performs
```
scale_type = w_s.dtype
compute_type = a.dtype
out = (w_q.to(scale_type) * w_s - w_z.to(scale_type)) @ a
```
Where `w_q` is a quantized weight matrix, `w_s` is the quantization scales, and
`w_z` is the quantization zeropoints.
> **_NOTE:_** `w_z` is added after the scales so we can
use FMA operations, but this means they must have the scales pre-applied if the
supplied zeropoints assume that they will be subtracted before the scales are
applied.
## API
The main optimization within Machete is prepacking the weight matrix to more closely match the tensor core layouts, allowing for wider shared memory loads when loading the weight matrix. This means that the weight matrix must be prepacked before calling `machete_gemm`. The flow looks something like:
```
from vllm import _custom_ops as ops
...
W_q_packed = ops.machete_prepack_B(w_q, wtype)
output = ops.machete_gemm(
a,
b_q=W_q_packed,
b_type=wtype,
b_scales=w_s,
b_group_size=group_size
)
```
## Code Generation
Since Machete is based on Cutlass, we can generate multiple type pairs and different tile shapes using the same kernel template. We generate multiple instantiations of this template using `generate.py`.
New type pairs (`TypeConfig`s) can be appended to `impl_configs` (in `generate()`), and these will get automatically generated (assuming they can be supported without issues). For each `TypeConfig`, you must also provide an `ImplConfig`, which bundles a `TypeConfig` with a list of `ScheduleConfig`s, `Specialization`s, and a default heuristic. The `ScheduleConfig`s (which contain info on tile shapes, tile scheduler, etc.) can perform differently for different problem shapes, and there is almost never one `ScheduleConfig` that works well for all problem shapes, so it is generally beneficial to generate different `ScheduleConfig`s for different potential problem shapes. This is where the heuristic comes in. For each `TypeConfig`, a default heuristic should be provided. This maps different problem shapes to different `ScheduleConfig`s and is used when the user does not provide the `schedule` parameter to `machete_gemm`. The `Specialization`s define what feature combinations to generate, i.e., `with_zeropoints`, `with_scales`, etc. We can reduce compile times and the final binary size by limiting the set of feature combinations we generate.
\ No newline at end of file
import itertools
import math
import os
import shutil
from collections.abc import Iterable
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import jinja2
# yapf conflicts with isort for this block
# yapf: disable
from vllm_cutlass_library_extension import (DataType, EpilogueScheduleTag,
EpilogueScheduleType,
MixedInputKernelScheduleType,
TileSchedulerTag,
TileSchedulerType, VLLMDataType,
VLLMDataTypeNames, VLLMDataTypeTag,
VLLMKernelScheduleTag)
# yapf: enable
#
# Generator templating
#
DISPATCH_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
using GemmDispatcher_ = GemmDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
{% for s in schedules %}extern torch::Tensor
impl_{{type_name}}_sch_{{ gen_sch_name(s) }}(PyTorchArguments args);
{% endfor %}
template <>
torch::Tensor GemmDispatcher_::dispatch(PyTorchArguments args) {
[[maybe_unused]] auto M = args.A.size(0);
[[maybe_unused]] auto N = args.B.size(1);
[[maybe_unused]] auto K = args.A.size(1);
if (!args.schedule) {
{%- for cond, s in heuristic %}
{%if cond is not none%}if ({{cond}})
{%- else %}else
{%- endif %}
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);{% endfor %}
}
{% for s in schedules %}
if (*args.schedule == "{{ gen_sch_name(s) }}") {
return impl_{{ type_name }}_sch_{{ gen_sch_name(s) }}(args);
}
{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(false, "machete_gemm(..) is not implemented for "
"schedule = ", *args.schedule);
}
template <>
std::vector<std::string> GemmDispatcher_::supported_schedules() {
return {
{% for s in schedules -%}
"{{ gen_sch_name(s) }}"{{ ",
" if not loop.last }}{%- endfor %}
};
}
}; // namespace machete
"""
IMPL_TEMPLATE = """
#include "../machete_mm_launcher.cuh"
namespace machete {
template <typename Config, bool with_C, bool with_scales, bool with_zeropoints>
using Kernel = MacheteKernelTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}, // Zeropoints
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput,
Config, with_C, with_scales, with_zeropoints>;
{% for sch in schedules %}
{% set schedule_name = gen_sch_name(sch) -%}
struct sch_{{schedule_name}} {
using TileShapeNM = Shape<{{
to_cute_constant(sch.tile_shape_mn)|join(', ')}}>;
using ClusterShape = Shape<{{
to_cute_constant(sch.cluster_shape_mnk)|join(', ')}}>;
// TODO: Reimplement
// using KernelSchedule = {{KernelScheduleTag[sch.kernel_schedule]}};
using EpilogueSchedule = {{EpilogueScheduleTag[sch.epilogue_schedule]}};
using TileScheduler = {{TileSchedulerTag[sch.tile_scheduler]}};
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;
};
torch::Tensor
impl_{{type_name}}_sch_{{schedule_name}}(PyTorchArguments args) {
bool with_C = args.C.has_value(), with_scales = args.scales.has_value(),
with_zeropoints = args.zeros.has_value();
{% for s in specializations %}
if (with_C == {{s.with_C|lower}}
&& with_zeropoints == {{s.with_zeropoints|lower}}
&& with_scales == {{s.with_scales|lower}}) {
return run_impl<Kernel<sch_{{schedule_name}}, {{s.with_C|lower}},
{{s.with_scales|lower}}, {{s.with_zeropoints|lower}}>>(args);
}{% endfor %}
TORCH_CHECK_NOT_IMPLEMENTED(
false, "for the sake of compile times and binary size machete_mm(..) is "
" not implemented for with_C=", with_C, ", with_scales=", with_scales,
", with_zeropoints=", with_zeropoints,
" (for {{type_name}}_sch_{{schedule_name}})");
}
{% endfor %}
}; // namespace machete
"""
PREPACK_TEMPLATE = """
#include "../machete_prepack_launcher.cuh"
namespace machete {
using PrepackBDispatcher_ = PrepackBDispatcher<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
{{DataTypeTag[type_config.element_b_scale]}}, // Scales
{{DataTypeTag[type_config.element_b_zeropoint]}}>; // Zeropoints
using PrepackedLayoutB = PrepackedLayoutBTemplate<
{{DataTypeTag[type_config.element_a]}}, // ElementA
{{DataTypeTag[type_config.element_b]}}, // ElementB
{{DataTypeTag[type_config.element_d]}}, // ElementD
{{DataTypeTag[type_config.accumulator]}}, // Accumulator
cutlass::layout::ColumnMajor,
cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput>;
template <>
torch::Tensor PrepackBDispatcher_::dispatch(torch::Tensor B) {
return prepack_impl<PrepackedLayoutB>(B);
}
}; // namespace machete
"""
TmaMI = MixedInputKernelScheduleType.TmaWarpSpecializedCooperativeMixedInput
TmaCoop = EpilogueScheduleType.TmaWarpSpecializedCooperative
@dataclass(frozen=True)
class ScheduleConfig:
tile_shape_mn: Tuple[int, int]
cluster_shape_mnk: Tuple[int, int, int]
kernel_schedule: MixedInputKernelScheduleType
epilogue_schedule: EpilogueScheduleType
tile_scheduler: TileSchedulerType
@dataclass
class TypeConfig:
element_a: DataType
element_b: Union[DataType, VLLMDataType]
element_b_scale: DataType
element_b_zeropoint: DataType
element_d: DataType
accumulator: DataType
@dataclass
class Specialization:
with_C: bool
with_zeropoints: bool
with_scales: bool
@dataclass
class ImplConfig:
type_config: TypeConfig
schedule_configs: List[ScheduleConfig]
specializations: List[Specialization]
heuristic: List[Tuple[Optional[str], ScheduleConfig]]
def generate_schedule_name(schedule_config: ScheduleConfig) -> str:
tile_shape = (
f"{schedule_config.tile_shape_mn[0]}x{schedule_config.tile_shape_mn[1]}"
)
cluster_shape = (f"{schedule_config.cluster_shape_mnk[0]}" +
f"x{schedule_config.cluster_shape_mnk[1]}" +
f"x{schedule_config.cluster_shape_mnk[2]}")
kernel_schedule = VLLMKernelScheduleTag[schedule_config.kernel_schedule]\
.split("::")[-1]
epilogue_schedule = EpilogueScheduleTag[
schedule_config.epilogue_schedule].split("::")[-1]
tile_scheduler = TileSchedulerTag[schedule_config.tile_scheduler]\
.split("::")[-1]
return (f"{tile_shape}_{cluster_shape}_{kernel_schedule}" +
f"_{epilogue_schedule}_{tile_scheduler}")
# mostly unique shorter schedule_name
def generate_terse_schedule_name(schedule_config: ScheduleConfig) -> str:
kernel_terse_names_replace = {
"KernelTmaWarpSpecializedCooperativeMixedInput_": "TmaMI_",
"TmaWarpSpecializedCooperative_": "TmaCoop_",
"StreamKScheduler": "streamK",
}
schedule_name = generate_schedule_name(schedule_config)
for orig, terse in kernel_terse_names_replace.items():
schedule_name = schedule_name.replace(orig, terse)
return schedule_name
# unique type_name
def generate_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
element_d = VLLMDataTypeNames[kernel_type_config.element_d]
accumulator = VLLMDataTypeNames[kernel_type_config.accumulator]
element_scale = VLLMDataTypeNames[kernel_type_config.element_b_scale]
element_zeropoint = VLLMDataTypeNames[
kernel_type_config.element_b_zeropoint]
return (f"{element_a}{element_b}{element_d}"
f"{accumulator}{element_scale}{element_zeropoint}")
# non-unique shorter type_name
def generate_terse_type_signature(kernel_type_config: TypeConfig):
element_a = VLLMDataTypeNames[kernel_type_config.element_a]
element_b = VLLMDataTypeNames[kernel_type_config.element_b]
return f"{element_a}{element_b}"
def is_power_of_two(n):
return (n != 0) and (n & (n - 1) == 0)
def to_cute_constant(value: List[int]):
def _to_cute_constant(value: int):
if is_power_of_two(value):
return f"_{value}"
else:
return f"Int<{value}>"
if isinstance(value, Iterable):
return [_to_cute_constant(value) for value in value]
else:
return _to_cute_constant(value)
template_globals = {
"DataTypeTag": VLLMDataTypeTag,
"KernelScheduleTag": VLLMKernelScheduleTag,
"EpilogueScheduleTag": EpilogueScheduleTag,
"TileSchedulerTag": TileSchedulerTag,
"to_cute_constant": to_cute_constant,
"gen_sch_name": generate_terse_schedule_name,
}
def create_template(template_str):
template = jinja2.Template(template_str)
template.globals.update(template_globals)
return template
mm_dispatch_template = create_template(DISPATCH_TEMPLATE)
mm_impl_template = create_template(IMPL_TEMPLATE)
prepack_dispatch_template = create_template(PREPACK_TEMPLATE)
def create_sources(impl_config: ImplConfig, num_impl_files=1):
sources = []
type_name = generate_type_signature(impl_config.type_config)
terse_type_name = generate_terse_type_signature(impl_config.type_config)
sources.append((
f"machete_mm_{terse_type_name}",
mm_dispatch_template.render(type_name=type_name,
type_config=impl_config.type_config,
schedules=impl_config.schedule_configs,
heuristic=impl_config.heuristic),
))
sources.append((
f"machete_prepack_{terse_type_name}",
prepack_dispatch_template.render(
type_name=type_name,
type_config=impl_config.type_config,
),
))
num_schedules = len(impl_config.schedule_configs)
schedules_per_file = math.ceil(num_schedules / num_impl_files)
for part, i in enumerate(range(0, num_schedules, schedules_per_file)):
file_schedules = impl_config.schedule_configs[i:i + schedules_per_file]
sources.append((
f"machete_mm_{terse_type_name}_impl_part{part}",
mm_impl_template.render(
type_name=type_name,
type_config=impl_config.type_config,
schedules=file_schedules,
specializations=impl_config.specializations,
),
))
return sources
def generate():
# See csrc/quantization/machete/Readme.md, the Codegeneration for more info
# about how this works
SCRIPT_DIR = os.path.dirname(__file__)
schedule_common_params = dict(
kernel_schedule=TmaMI,
epilogue_schedule=TmaCoop,
tile_scheduler=TileSchedulerType.StreamK,
)
# For now we use the same heuristic for all types
# Heuristic is currently tuned for H100s
default_heuristic = [
#### M = 257+
(
"M > 256 && K <= 16384 && N <= 4096",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 256",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 129-256
(
"M > 128 && K <= 4096 && N <= 4096",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 128 && K <= 8192 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 128",
ScheduleConfig(
tile_shape_mn=(128, 256),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 65-128
(
"M > 64 && K <= 4069 && N <= 4069",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K <= 4069 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64 && K >= 8192 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 64",
ScheduleConfig(
tile_shape_mn=(128, 128),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 33-64
(
"M > 32 && K <= 6144 && N <= 6144",
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32 && K >= 16384 && N >= 12288",
ScheduleConfig(
tile_shape_mn=(256, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 32",
ScheduleConfig(
tile_shape_mn=(128, 64),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 17-32
(
"M > 16 && K <= 12288 && N <= 8192",
ScheduleConfig(
tile_shape_mn=(128, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
(
"M > 16",
ScheduleConfig(
tile_shape_mn=(256, 32),
cluster_shape_mnk=(2, 1, 1),
**schedule_common_params # type: ignore
)),
#### M = 1-16
(
"N >= 26624",
ScheduleConfig(
tile_shape_mn=(256, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
(
None,
ScheduleConfig(
tile_shape_mn=(128, 16),
cluster_shape_mnk=(1, 1, 1),
**schedule_common_params # type: ignore
)),
]
# Do not use schedules = list(set(...)) because we need to make sure
# the output list is deterministic; otherwise the generated kernel file
# will be non-deterministic and causes ccache miss.
schedules = []
for _, schedule_config in default_heuristic:
if schedule_config not in schedules:
schedules.append(schedule_config)
impl_configs = []
GPTQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for element_a in (DataType.f16, DataType.bf16)))
GPTQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=False, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(GPTQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(GPTQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
AWQ_kernel_type_configs = list(
(TypeConfig(
element_a=element_a,
element_b=element_b,
element_b_scale=element_a,
element_b_zeropoint=element_a,
element_d=element_a,
accumulator=DataType.f32,
) for element_b in (DataType.u4, DataType.u8)
for element_a in (DataType.f16, DataType.bf16)))
AWQ_kernel_specializations = [
Specialization(with_C=False, with_zeropoints=True, with_scales=True)
]
impl_configs += [
ImplConfig(x[0], x[1], x[2], x[3])
for x in zip(AWQ_kernel_type_configs, itertools.repeat(schedules),
itertools.repeat(AWQ_kernel_specializations),
itertools.repeat(default_heuristic))
]
output_dir = os.path.join(SCRIPT_DIR, "generated")
# Delete the "generated" directory if it exists
if os.path.exists(output_dir):
shutil.rmtree(output_dir)
# Create the "generated" directory
os.makedirs(output_dir)
# Render each group of configurations into separate files
for impl_config in impl_configs:
for filename, code in create_sources(impl_config):
filepath = os.path.join(output_dir, f"{filename}.cu")
with open(filepath, "w") as output_file:
output_file.write(code)
print(f"Rendered template to {filepath}")
if __name__ == "__main__":
generate()
#pragma once
#include "cutlass_extensions/vllm_collective_builder.cuh"
#include "machete_mainloop.cuh"
namespace cutlass::gemm::collective {
using namespace cute;
struct MacheteKernelTag {};
template <class ElementPairA_, class GmemLayoutA_, int AlignmentA,
class ElementPairB_, class GmemLayoutB_, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
class StageCountType, class KernelScheduleType>
struct VLLMCollectiveBuilder<
MacheteKernelTag, arch::Sm90, arch::OpClassTensorOp, ElementPairA_,
GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_, AlignmentB,
ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType,
KernelScheduleType,
cute::enable_if_t<(
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>)>> {
using CollectiveOp = machete::MacheteCollectiveMma<
ElementPairA_, GmemLayoutA_, AlignmentA, ElementPairB_, GmemLayoutB_,
AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK,
StageCountType, KernelScheduleType>;
};
}; // namespace cutlass::gemm::collective
\ No newline at end of file
#pragma once
#include "cutlass/cutlass.h"
#include "cute/layout.hpp"
namespace machete {
using namespace cute;
// get an interleaved block layout where each element consecutive element has a
// stride of bit_stride and the block width is blk_bit_width,
// examples:
// size_bits<T> = 8, bit_stride = 8, blk_bit_width = 32 -> 4:1
// size_bits<T> = 8, bit_stride = 16, blk_bit_width = 32 -> (2, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 8, blk_bit_width = 32 -> (4, 2):(2, 1)
// size_bits<T> = 4, bit_stride = 16, blk_bit_width = 32 -> (2, 4):(4, 1)
template <typename T, int bit_stride, int blk_bit_width>
CUTE_HOST_DEVICE static constexpr auto get_interleaved_blk_layout() {
static_assert(blk_bit_width % bit_stride == 0);
static_assert(bit_stride % cute::sizeof_bits_v<T> == 0);
constexpr auto elems_per_blk = blk_bit_width / cute::sizeof_bits_v<T>;
if constexpr (cute::sizeof_bits_v<T> == bit_stride) {
// identity layout
return Layout<Shape<Int<elems_per_blk>>>{};
} else {
constexpr auto elems_per_stride = bit_stride / cute::sizeof_bits_v<T>;
constexpr auto num_strides = elems_per_blk / elems_per_stride;
return Layout<Shape<Int<num_strides>, Int<elems_per_stride>>,
Stride<Int<elems_per_stride>, Int<1>>>{};
}
}
}; // namespace machete
//
// Based off of:
// cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Specifically:
// https://github.com/NVIDIA/cutlass/tree/06b21349bcf6ddf6a1686a47a137ad1446579db9/include/cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp
// Referred to as upstream from in the comments
//
// The main optimization machete implements compared to upstream is to prepack
// the weight matrix to more closely match the shape of the wgmma instructions
// allowing for wider (ideally 128bit) shared memory loads. For subbyte types
// this is done by packing values from multiple wgmma loads (for a single
// thread) into a single 128bit load. This is very similar to layout used in
// Marlin, although specific to the wgmma instructions.
//
// Since the wgmma instructions only support sourcing from registers for the A
// operand, and we want to upconvert/decompress the weight values/elements
// before feeding them into the tensor cores in registers, we need the weight
// matrix to be A. To achieve this we compute the transpose of Y = XW^t as
// Y^t = W^tX^t. This is mostly done outside of this file in
// csrc/quantization/machete/machete_mm_kernel.cuh, but this why A is the
// quantized/narrow type and has the prepacked layout despite the API being:
// B_prepacked = machete_prepack_B(B)
// Y = machete_mm(A, B_prepacked)
//
#pragma once
// clang-format off
#include "cutlass/cutlass.h"
#include "cutlass/numeric_conversion.h"
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/detail/dependent_false.hpp"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/detail/layout.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/atom/copy_traits_sm90_tma.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/tensor_predicate.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/transform/collective/sm90_wgmma_transpose.hpp"
#include "cutlass/trace.h"
#include "cutlass/detail/collective.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
namespace machete {
using namespace cute;
using namespace cutlass;
using namespace cutlass::gemm;
using namespace cutlass::gemm::collective;
using namespace cutlass::gemm::collective::detail;
template <class ElementATuple_, class GmemLayoutA, int AlignmentA,
class ElementB_, class GmemLayoutB, int AlignmentB,
class ElementAccumulator_, class TileShape_MNK,
class ClusterShape_MNK, class StageCountType,
class KernelScheduleType>
struct MacheteCollectiveMma {
using Schedule = KernelScheduleType;
static_assert(
cute::is_same_v<Schedule, KernelTmaWarpSpecialized> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedPingpong> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedPingpongMixedInput> ||
cute::is_same_v<Schedule, KernelTmaWarpSpecializedCooperative> ||
cute::is_same_v<Schedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
"KernelSchedule must be one of the warp specialized policies");
public:
static constexpr bool ALayoutIsPrepacked = true;
// Prepacked block shape (N is M in the transposed problem)
using PPBlockShape_MK = typename GmemLayoutA::PPBlockShape_NK;
// Prepacked blocks per dim for a single MMA tile
using PPBlocksPerTile_MK = decltype(make_shape(
size<0>(TileShape_MNK{}) / size<0>(PPBlockShape_MK{}),
size<2>(TileShape_MNK{}) / size<1>(PPBlockShape_MK{})));
using IlvdBlkLayout = typename GmemLayoutA::IlvdBlkLayout;
static_assert(size<0>(TileShape_MNK{}) % size<0>(PPBlockShape_MK{}) == 0,
"M in PPBlockShape_MK must evenly divide M TileShape_MNK");
static_assert(size<2>(TileShape_MNK{}) % size<1>(PPBlockShape_MK{}) == 0,
"K in PPBlockShape_MK must evenly divide K TileShape_MNK");
using ArchTag = arch::Sm90;
using TileShape = TileShape_MNK;
using ClusterShape = ClusterShape_MNK;
using ElementA = deduce_mixed_width_dtype_t<0, ElementATuple_>;
using StrideA = TagToStrideA_t<layout::RowMajor>;
using ElementB = ElementB_;
using StrideB = TagToStrideB_t<GmemLayoutB>;
using ElementAccumulator = ElementAccumulator_;
using ElementMma = ElementB;
using ElementATuple =
cute::conditional_t<!cute::is_tuple<ElementATuple_>::value,
cute::tuple<ElementA>, ElementATuple_>;
static constexpr cute::GMMA::Major GmmaMajorA =
gmma_rs_tag_to_major_A<layout::RowMajor>();
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<GmemLayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelScheduleType,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
TileShape_MNK, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
private:
//
// the setup section (until "section setup end") contains a combination of
// modified code from (used as a starting point):
// `cutlass/gemm/collective/builders/sm90_gmma_builder.inl`
// `cutlass/gemm/collective/sm90_mma_tma_gmma_rs_warpspecialized_mixed_input.hpp`
// (upstream)
//
// however in-order to simplify the code we combine a lot of the logic from
// `CollectiveMma` and `CollectiveBuilder` into this class, this also makes
// sense given that we have flexibility on layouts here. We also simplify the
// code by only supporting scales and zeros for A (in the transposed problem,
// B from an API perspective), also since we force A to be the narrow type
// (i.e. the type to be upconverted) we can remove all the `SwapAB` logic in
// the upstream also simplifying the code. This section includes new logic
// (compared ustream) for handling the prepacked-A layouts (in the transposed
// problem, B from an API perspective)
//
using ElementScale = deduce_mixed_width_dtype_t<1, ElementATuple_>;
using ElementZero = deduce_mixed_width_dtype_t<2, ElementATuple_>;
static constexpr bool IsANarrow = cutlass::sizeof_bits<ElementA>::value <
cutlass::sizeof_bits<ElementB>::value;
static_assert(IsANarrow,
"A must be the narrow one since its the one that flows through "
"registers.");
public:
static constexpr int PipelineStages =
compute_stage_count_or_override_single_affine_transformed_input<
sm90_smem_capacity_bytes, ElementA, ElementB, ElementScale,
ElementZero, TileShape_MNK>(StageCountType{});
struct DispatchPolicy {
constexpr static int Stages = PipelineStages;
using ClusterShape = ClusterShape_MNK;
using Schedule = KernelScheduleType;
};
using GmemTiledCopyA =
decltype(sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB =
decltype(sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
// ((T, V), (BlocksM, BlocksK), pipe) -> offset
using SmemLayoutA = decltype(GmemLayoutA::TVbNbKL_to_offset(
make_shape(size<0>(TileShape_MNK{}), size<2>(TileShape_MNK{}),
Int<DispatchPolicy::Stages>{})));
using SmemLayoutAtomARowMajor =
decltype(rs_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomScale = Layout<
Shape<decltype(cute::shape<0>(SmemLayoutAtomARowMajor{})), cute::Int<1>>>;
using SmemLayoutAtomB =
decltype(rs_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})),
decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemCopyAtomA = Copy_Atom<cute::DefaultCopy, ElementA>;
using SmemCopyAtomB = void;
//
// Validity checks
//
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(is_aligned<ElementA, AlignmentA, ElementB, AlignmentB,
tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>,
"Unsupported Toolkit for SM90 Collective Builder\n");
#endif
private:
enum class ConversionMode {
DirectConvert,
ConvertAndScale,
ConvertAndScaleWithZero
};
public:
//
// Type Aliases
//
using KernelSchedule = KernelScheduleType;
// For cases where we can't have a void type, we can use this to allow the
// code to compile when the scale / zero is void.
using NonVoidElementScale =
cute::conditional_t<cute::is_void_v<ElementScale>, float, ElementScale>;
using NonVoidElementZero =
cute::conditional_t<cute::is_void_v<ElementZero>, float, ElementZero>;
// These are always MN major
using StrideScale = cute::Stride<cute::Int<1>, int64_t, int64_t>;
// For cases where we can't have a void scale, we can use this to allow the
// code to compile when the scale is void.
using NonVoidStrideScale =
cute::conditional_t<cute::is_void_v<StrideScale>,
cute::Stride<_1, int64_t, int64_t>, StrideScale>;
static_assert((cutlass::gemm::detail::is_k_major<StrideA>()),
"The transformed matrix (A) must be K-major.");
static_assert((sizeof(ElementB) == 2) ||
(cutlass::gemm::detail::is_k_major<StrideA>() &&
cutlass::gemm::detail::is_k_major<StrideB>()),
"The unscaled element (matrix B) must be 2 bytes OR both "
"inputs must be K-major");
static_assert(cutlass::gemm::detail::is_mn_major<NonVoidStrideScale>(),
"Scale must be MN major [Col Major if A is scaled, Row Major "
"if B is scaled].");
static_assert(std::is_same_v<typename TiledMma::ValTypeC, ElementAccumulator>,
"TiledMma::ValTypeC must be the same as ElementAccumulator.");
using GmemTiledCopyScale = cute::SM90_TMA_LOAD;
using SmemCopyAtomScale = Copy_Atom<cute::DefaultCopy, NonVoidElementScale>;
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any
// rounding by TMA.
static constexpr bool ConvertF32toTF32A = cute::is_same_v<float, ElementA>;
static constexpr bool ConvertF32toTF32B = cute::is_same_v<float, ElementB>;
using InternalElementA =
cute::conditional_t<ConvertF32toTF32A, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementA>>>;
using InternalElementB =
cute::conditional_t<ConvertF32toTF32B, tfloat32_t,
uint_bit_t<sizeof_bits_v<ElementB>>>;
using TransformA = cute::identity;
using TransformB = cute::identity;
static constexpr int IsSubbyteA = cute::sizeof_bits_v<InternalElementA> < 8;
using TmaElementA =
cute::conditional_t<IsSubbyteA, uint8_t, InternalElementA>;
using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
using PipelineParams = typename MainloopPipeline::Params;
using ScaleTileShape = decltype(make_shape(shape<0>(TileShape{}),
shape<1>(SmemLayoutAtomScale{})));
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtom must be rank 2 (M/N, K)");
static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0,
"SmemLayoutAtom must evenly divide tile shape.");
static_assert(rank(SmemLayoutAtomScale{}) == 2,
"SmemLayoutAtomScale must be rank 2");
static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must equal the tile shape.");
static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomScale{})) == 0,
"SmemLayoutAtomScale must evenly divide tile k shape.");
// Tile along modes in a way that maximizes the TMA box size.
using SmemLayoutACopy = decltype(tile_to_shape(
SmemLayoutAtomARowMajor{},
make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
using SmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
// It is assumed that the scales and zero-points share the same smem layout
using SmemLayoutScale = decltype(tile_to_shape(
SmemLayoutAtomScale{},
make_shape(shape<0>(ScaleTileShape{}), shape<1>(ScaleTileShape{}),
Int<PipelineStages>{})));
// If A mn-layout and B mn-layout, transposing B matrix since WGMMA is k-major
// only (e.g. tf32, fp32, fp8, int8).
static constexpr bool IsLayoutAmnBmn =
cute::is_same_v<gemm::detail::StrideToLayoutTagA_t<StrideA>,
layout::ColumnMajor> &&
cute::is_same_v<gemm::detail::StrideToLayoutTagB_t<StrideB>,
layout::RowMajor>;
static_assert(DispatchPolicy::Stages >= 2,
"Specialization requires Stages set to value 2 or more.");
static_assert(not cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeA>::value &&
cute::is_base_of<cute::GMMA::DescriptorIterator,
typename TiledMma::FrgTypeB>::value,
"MMA atom must source A from rmem and B operand from smem_desc "
"for this mainloop.");
static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> ||
cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
"GmemTiledCopy - invalid SM90 TMA copy atom specified.");
using GmmaSmemLayoutB = decltype(tile_to_shape(
SmemLayoutAtomB{},
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}),
Int<DispatchPolicy::Stages>{}),
conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(),
Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
// These two restrictions are related, so we place the assertions together.
// To relax them, we need to handle loading more than 1 row of scales for
// every main loop iteration. We must also handle updating the pipeline
// transaction bytes on the fly. NOTE: Deleting this assertion without
// required changes will cause the code to hang.
static_assert(size<1>(SmemLayoutAtomScale{}) == 1,
"size<1>(SmemLayoutAtomScale) must be 1.");
private:
static constexpr ConversionMode get_conversion_mode() {
if constexpr (cute::is_void_v<ElementScale>) {
return ConversionMode::DirectConvert;
} else if constexpr (cute::is_void_v<ElementZero>) {
return ConversionMode::ConvertAndScale;
} else {
return ConversionMode::ConvertAndScaleWithZero;
}
}
static constexpr ConversionMode KernelConversionMode = get_conversion_mode();
static constexpr bool ModeHasScales =
KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode == ConversionMode::ConvertAndScaleWithZero;
// Same as upstream, should be kept the same when possible
static constexpr auto elements_per_smem_scale() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return 0;
} else if constexpr (ModeHasScales) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Type not handled in scale smem allocation.");
}
}
// Same as upstream, should be kept the same when possible
static constexpr auto elements_per_smem_zero() {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale) {
return 0;
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return cute::cosize_v<SmemLayoutScale>;
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Type not handled in scale smem allocation.");
}
}
// Same as upstream, should be kept the same when possible, not formatte for
// easier comparison
// clang-format off
// These methods use some the public members of the class. For that reason, we define them after the public section.
static constexpr uint32_t
compute_tma_transaction_bytes_mk() {
constexpr uint32_t baseline_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementA>));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return baseline_bytes;
}
else if constexpr (ModeHasScales) {
constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementScale>));
static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return baseline_bytes + scale_tx_bytes;
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
// Scale and zero share smem layout
constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast<uint32_t>(cute::sizeof_bits_v<ElementZero>));
static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA
return baseline_bytes + scale_tx_bytes + zero_tx_bytes;
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Type not handled in tma transaction bytes computation.");
}
}
static constexpr uint32_t
compute_tma_transaction_bytes_nk() {
return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(cute::sizeof_bits_v<InternalElementB>));
}
// clang-format on
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
using PrepackedStrideA = decltype(stride(GmemLayoutA::TVbNbKL_to_offset(
make_shape(int32_t(0), int32_t(0), int32_t(0)))));
using ATensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementA const*>(nullptr)),
shape(GmemLayoutA::TVbNbKL_to_offset(
make_shape(int32_t(0), int32_t(0), int32_t(0)))),
PrepackedStrideA{}));
using BTensor = decltype(make_tensor(
get_logical_ptr(static_cast<InternalElementB const*>(nullptr)),
repeat_like(StrideB{}, int32_t(0)), StrideB{}));
using ScaleTensor = decltype(make_tensor(
get_logical_ptr(static_cast<NonVoidElementScale const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
using ZeroTensor = decltype(make_tensor(
get_logical_ptr(static_cast<NonVoidElementZero const*>(nullptr)),
repeat_like(NonVoidStrideScale{}, int32_t(0)), NonVoidStrideScale{}));
static constexpr auto make_tma_copy_A(ATensor tensor_a = ATensor{}) {
return make_tma_copy<TmaElementA>(
GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
shape(SmemLayoutA{}(_, _, cute::Int<0>{})),
size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_scale(
ScaleTensor tensor_scale = ScaleTensor{}) {
return make_tma_copy(GmemTiledCopyScale{}, tensor_scale,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_zero(
ZeroTensor tensor_zero = ZeroTensor{}) {
return make_tma_copy(GmemTiledCopyScale{}, tensor_zero,
SmemLayoutScale{}(_, _, cute::Int<0>{}),
ScaleTileShape{},
_1{}); // mcast along N mode for this M load, if any
}
static constexpr auto make_tma_copy_B(BTensor tensor_b = BTensor{}) {
return make_tma_copy(
GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
}
public:
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `RealInternalElementA` -> `ElementA` since we support `SwapAB` logic
// clang-format off
static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{});
static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{});
// Just pick the max alignment of A and B since it is required to be at least 128B
static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB);
static_assert(SmemAlignmentA >= 128 and SmemAlignmentB >= 128, "Require at least 128B alignment");
struct SharedStorage
{
static constexpr int scale_elements = elements_per_smem_scale();
static constexpr int zero_elements = elements_per_smem_zero();
struct TensorStorage : cute::aligned_struct<cute::max(SmemAlignmentA, SmemAlignmentB)> {
cute::ArrayEngine<ElementA, cute::cosize_v<SmemLayoutA>> smem_A;
cute::ArrayEngine<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
cute::ArrayEngine<NonVoidElementScale, scale_elements> smem_scale;
cute::ArrayEngine<NonVoidElementZero, zero_elements> smem_zero;
} tensors;
using PipelineStorage = typename MainloopPipeline::SharedStorage;
PipelineStorage pipeline;
};
using TensorStorage = typename SharedStorage::TensorStorage;
using PipelineStorage = typename SharedStorage::PipelineStorage;
// Host side kernel arguments
struct Arguments {
ElementA const* ptr_A = nullptr;
StrideA dA{};
ElementB const* ptr_B = nullptr;
StrideB dB{};
ElementScale const* ptr_S = nullptr;
NonVoidStrideScale dS{};
int group_size = 0;
ElementZero const* ptr_Z = nullptr;
uint32_t mma_promotion_interval = 4;
};
// clang-format on
//
// section setup end
//
// Similar (but not idendtical) to upstream, should be kept the same when
// possible
// compared to upstream we use `make_tma_copy_A`, `make_tma_copy_B` etc. to
// define the TMA types
// Device side kernel params
struct Params {
public:
// Assumption: StrideA is congruent with Problem_MK
using TMA_A = decltype(make_tma_copy_A());
using TMA_Scale = decltype(make_tma_copy_scale());
using TMA_Zero = decltype(make_tma_copy_zero());
using TMA_B = decltype(make_tma_copy_B());
// required by outer loop: i.e.
// cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp
TMA_A tma_load_a;
TMA_B tma_load_b;
TMA_Scale tma_load_scale;
TMA_Zero tma_load_zero;
int64_t scale_k;
int group_size;
uint32_t tma_transaction_bytes = TmaTransactionBytes;
uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK;
uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK;
};
//
// Methods
//
// Similar (but not idendtical) to upstream, should be kept the same when
// possible
// compared to upstream we use `make_tma_copy_A` and `TVbNbKL_to_offset` here
// to handle the prepacked layout
template <class ProblemShape>
static constexpr Params to_underlying_arguments(
ProblemShape const& problem_shape, Arguments const& args,
void* workspace) {
(void)workspace;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is
// only rank-3 (MNK)
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M, N, K, L] = problem_shape_MNKL;
auto ptr_A = reinterpret_cast<InternalElementA const*>(args.ptr_A);
auto ptr_B = reinterpret_cast<InternalElementB const*>(args.ptr_B);
auto make_logical_tensor = [&](auto ptr, auto shape, auto stride) {
return make_tensor(get_logical_ptr(ptr), make_layout(shape, stride));
};
typename Params::TMA_A tma_load_a;
typename Params::TMA_B tma_load_b;
typename Params::TMA_Scale tma_load_scale;
typename Params::TMA_Zero tma_load_zero;
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
tma_load_a = make_tma_copy_A(
make_logical_tensor(ptr_A, shape(layout), stride(layout)));
tma_load_b = make_tma_copy_B(
make_logical_tensor(ptr_B, make_shape(N, K, L), args.dB));
int32_t scale_k =
(ModeHasScales) ? (K + args.group_size - 1) / args.group_size : 0;
int32_t group_size = (ModeHasScales) ? args.group_size : 0;
if constexpr (ModeHasScales) {
tma_load_scale = make_tma_copy_scale(
make_logical_tensor(args.ptr_S, make_shape(M, scale_k, L), args.dS));
}
if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
tma_load_zero = make_tma_copy_zero(
make_logical_tensor(args.ptr_Z, make_shape(M, scale_k, L), args.dS));
}
if constexpr (KernelConversionMode == ConversionMode::DirectConvert ||
KernelConversionMode == ConversionMode::ConvertAndScale ||
KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return {tma_load_a, tma_load_b, tma_load_scale,
tma_load_zero, scale_k, group_size};
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in to_underlying_arguments.");
}
}
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// with `SwapAB ? N : M -> M` since we dont support SwapAB
// clang-format off
template<class ProblemShape>
static bool
can_implement(
ProblemShape const& problem_shape,
[[maybe_unused]] Arguments const& args) {
constexpr int tma_alignment_bits = 128;
auto problem_shape_MNKL = append<4>(problem_shape, 1);
auto [M,N,K,L] = problem_shape_MNKL;
bool implementable = true;
constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M,K,L), StrideA{});
constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N,K,L), StrideB{});
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
implementable = implementable && (args.ptr_S == nullptr);
implementable = implementable && (args.ptr_Z == nullptr);
}
else if constexpr (ModeHasScales) {
const int scale_mn = M;
const int scale_k = (K + args.group_size - 1) / args.group_size;
constexpr int min_tma_aligned_elements_scale = tma_alignment_bits / cutlass::sizeof_bits<ElementScale>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_scale>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.group_size == K || ((args.group_size % size<2>(TileShape{})) == 0));
implementable = implementable && args.group_size != 0;
implementable = implementable && (args.ptr_S != nullptr);
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
implementable = implementable && (args.ptr_Z == nullptr);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
constexpr int min_tma_aligned_elements_zero = tma_alignment_bits / cutlass::sizeof_bits<ElementZero>::value;
implementable = implementable && cutlass::detail::check_alignment<min_tma_aligned_elements_zero>(cute::make_shape(scale_mn,scale_k,L), StrideScale{});
implementable = implementable && (args.ptr_Z != nullptr);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in can_implement.");
}
if (!implementable) {
CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
}
return implementable;
}
static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
static constexpr uint32_t TmaTransactionBytesMK = compute_tma_transaction_bytes_mk();
static constexpr uint32_t TmaTransactionBytesNK = compute_tma_transaction_bytes_nk();
static constexpr uint32_t TmaTransactionBytes = TmaTransactionBytesMK + TmaTransactionBytesNK;
/// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
CUTLASS_DEVICE
static void prefetch_tma_descriptors(Params const& mainloop_params) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
cute::prefetch_tma_descriptor(mainloop_params.tma_load_scale.get_tma_descriptor());
cute::prefetch_tma_descriptor(mainloop_params.tma_load_zero.get_tma_descriptor());
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA prefetch.");
}
}
// clang-format off
// Modified from upstream, should be kept close to that when possible
// the main difference is special handling for the prepacked A layout
//
// Set up the data needed by this collective for load and mma.
// Returns a tuple of tensors. The collective and the kernel layer have the
// contract Returned tuple must contain at least two elements, with the first
// two elements being: gA_mkl - The tma tensor, A after a local tile so it
// has shape (TILE_V,TILE_B,m,k,l) gB_nkl - The tma tensor, B after a local
// tile so it has shape (TILE_N,TILE_K,n,k,l) The rest of the tensors can be
// specified as needed by this collective.
// NOTE: TILE_B is the prepacked block index within a tile. TILE_V is the
// values within a prepacked block.
template <class ProblemShape_MNKL>
CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL,
Params const& mainloop_params) const {
using X = Underscore;
auto M = get<0>(problem_shape_MNKL), N = get<1>(problem_shape_MNKL),
K = get<2>(problem_shape_MNKL), L = get<3>(problem_shape_MNKL);
// (TILE_V,TILE_B,m,k,l)
auto make_gA_mkl = [&]() {
// ((athrid, val), (BlocksM, BlockK), L) -> (storage_idx)
auto layout = GmemLayoutA::TVbNbKL_to_offset(make_shape(M, K, L));
Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(shape(layout));
return local_tile(mA_mkl,
make_shape(size<0>(layout), PPBlocksPerTile_MK{}),
make_coord(0, make_coord(_, _)));
};
// (TILE_N,TILE_K,n,k,l)
auto make_gB_nkl = [&]() {
Tensor mB_nkl =
mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L));
return local_tile(mB_nkl, TileShape{}, make_coord(_, _, _),
Step<X, _1, _1>{});
};
// (TILE_M,TILE_Scale_K,m,scale_k,l)
auto make_gS_mkl = [&]() {
auto scale_k = mainloop_params.scale_k;
Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(
make_shape(M, scale_k, L));
return local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _));
};
// (TILE_M,TILE_Scale_K,m,scale_k,l)
auto make_gZ_mkl = [&]() {
auto scale_k = mainloop_params.scale_k;
Tensor mZ_mkl = mainloop_params.tma_load_zero.get_tma_tensor(
make_shape(M, scale_k, L));
return local_tile(mZ_mkl, ScaleTileShape{}, make_coord(_, _));
};
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl());
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScale) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl());
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
return cute::make_tuple(make_gA_mkl(), make_gB_nkl(), make_gS_mkl(),
make_gZ_mkl());
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in load_init.");
}
}
// Similar to upstream, should be kept close to that when possible
// the main difference is in the layout comments
// clang-format off
/// Perform a collective-scoped matrix multiply-accumulate
/// Producer Perspective
/// This overload gets triggered when we have scales.
template <
class... Ts,
class KTileIterator, class BlockCoord
>
CUTLASS_DEVICE void
load(
Params const& mainloop_params,
MainloopPipeline pipeline,
PipelineState smem_pipe_write,
cute::tuple<Ts...> const& load_inputs,
BlockCoord const& blk_coord,
KTileIterator k_tile_iter, int k_tile_count,
int thread_idx,
uint32_t block_rank_in_cluster,
TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
static_assert(sizeof... (Ts) == 2, "Direct convert needs two inputs");
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
static_assert(sizeof... (Ts) == 3, "Scaled convert needs three inputs");
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
static_assert(sizeof... (Ts) == 4, "Scaled and zero convert needs four inputs");
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in TMA load.");
}
int lane_predicate = cute::elect_one_sync();
if (lane_predicate) {
Tensor sA_ = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
Tensor sB_ = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
Tensor sA = as_position_independent_swizzle_tensor(sA_); // (BLK_M,BLK_K,PIPE)
Tensor sB = as_position_independent_swizzle_tensor(sB_); // (BLK_N,BLK_K,PIPE)
//
// Prepare the TMA loads for A, B and Scales
//
constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
Tensor gA_mkl = get<0>(load_inputs);
Tensor gB_nkl = get<1>(load_inputs);
auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
// Partition the inputs based on the current block coordinates.
auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (TILE_V,TILE_B,k)
Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (TILE_N,TILE_K,k)
// Applies the mapping from block_tma_a
Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)
Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)
uint16_t mcast_mask_a = 0;
uint16_t mcast_mask_b = 0;
uint16_t mcast_mask_s = 0;
// Issue TmaLoads
// Maps the tile -> block, value
if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int n = 0; n < size<1>(block_layout); ++n) {
mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{}));
}
}
if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>) {
auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
for (int m = 0; m < size<0>(block_layout); ++m) {
mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{}));
}
}
auto extra_input_partitions = partition_extra_tma_inputs(mainloop_params, load_inputs, shared_tensors, cluster_local_block_id, m_coord, l_coord);
// Mainloop
CUTLASS_PRAGMA_NO_UNROLL
for ( ; k_tile_count > 0; --k_tile_count) {
// LOCK smem_pipe_write for _writing_
pipeline.producer_acquire(smem_pipe_write);
//
// Copy gmem to smem for *k_tile_iter
//
using BarrierType = typename MainloopPipeline::ProducerBarrierType;
BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);
int write_stage = smem_pipe_write.index();
copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage));
copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage));
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// Nothing extra to do.
}
else if constexpr (ModeHasScales) {
auto tSgS = get<0>(extra_input_partitions);
auto tSsS = get<1>(extra_input_partitions);
// Temporary factor which will determine which k tile to reload from gmem. Needed so we don't modify tma transaction bytes
// on the fly.
// We must do a ceiling divide here to correctly handle with group_size == K. In that case, we don't require that K
// is a multiple of the threadblock tile K
const int ReloadFactor = (mainloop_params.group_size + size<2>(TileShape{}) - 1) / size<2>(TileShape{});
const int scale_load_k = *k_tile_iter / ReloadFactor; // This will always be 0 when group_size == K.
copy(mainloop_params.tma_load_scale.with(*tma_barrier, mcast_mask_s), tSgS(_,_,_,scale_load_k), tSsS(_,_,_,write_stage));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
auto tZgZ = get<2>(extra_input_partitions);
auto tZsZ = get<3>(extra_input_partitions);
copy(mainloop_params.tma_load_zero.with(*tma_barrier, mcast_mask_s), tZgZ(_,_,_,scale_load_k), tZsZ(_,_,_,write_stage));
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for TMA copy op.");
}
++k_tile_iter;
// Advance smem_pipe_write
++smem_pipe_write;
}
}
}
// clang-format off
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
CUTLASS_DEVICE void
load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) {
int lane_predicate = cute::elect_one_sync();
// Issue the epilogue waits
if (lane_predicate) {
/* This helps avoid early exit of blocks in Cluster
* Waits for all stages to either be released (all
* Consumer UNLOCKs), or if the stage was never used
* then would just be acquired since the phase was
* still inverted from make_producer_start_state
*/
pipeline.producer_tail(smem_pipe_write);
}
}
// clang-format on
// Modified from upstream, should be kept close to that when possible
// the main differences are handling the prepacked A layout, and separating
// the loading of A from upcoverting A
//
// Perform a collective-scoped matrix multiply-accumulate
// Consumer Perspective
template <class FrgTensorC>
CUTLASS_DEVICE void mma(MainloopPipeline pipeline,
PipelineState smem_pipe_read, FrgTensorC& accum,
int k_tile_count, int thread_idx,
TensorStorage& shared_tensors,
Params const& mainloop_params) {
static_assert(is_rmem<FrgTensorC>::value,
"C tensor must be rmem resident.");
static_assert(cute::rank(SmemLayoutB{}) == 3,
"Smem layout must be rank 3.");
static_assert(cute::rank(SmemLayoutAtomB{}) == 2,
"SmemLayoutAtomB must be rank 2.");
static_assert(!cute::is_void_v<SmemCopyAtomA>,
"SM90 GMMA mainloops must specify a non-void copy atom for "
"RF sourced instructions.");
static_assert(cute::is_void_v<SmemCopyAtomB>,
"SM90 GMMA mainloops cannot have a non-void copy atom for "
"smem sourced instructions.");
// Obtain warp index
int warp_idx = canonical_warp_idx_sync();
[[maybe_unused]] int warp_group_thread_idx = thread_idx % 128;
// ((T, (FrgV,(RestM, RestK)), (BlocksM, BlocksK), pipe) -> offset
auto constexpr smem_A = SmemLayoutA{};
// convert:
// ((T, (MMA,(MMA_M, MMA_K)), (BlocksM, BlocksK), pipe) -> offset
// to:
// (T, MMA, ((MMA_M, BlocksM), (MMA_K, BlocksK)), pipe) -> offset
// which can be thought of as:
// (T, MMA, (MMA_M, MMA_K), pipe) -> offset
auto constexpr smem_A_mma_ =
make_layout(get<0, 0>(smem_A), get<0, 1, 0>(smem_A),
zip(get<0, 1, 1>(smem_A), get<1>(smem_A)), get<2>(smem_A));
// flatten to:
// (T, MMA, MMA_M, MMA_K, pipe) -> offset
auto constexpr smem_A_mma = smem_A_mma_(_, _, make_coord(_, _), _);
Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.begin()),
smem_A_mma); // (T, MMA, MMA_M, MMA_K, pipe)
Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.begin()),
SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
//
// Define C accumulators and A/B partitioning
//
TiledMma tiled_mma;
auto thread_mma = tiled_mma.get_thread_slice(thread_idx);
Tensor tCsA = sA(thread_idx, _, _, _, _); // (MMA,MMA_M,MMA_K,PIPE)
Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)
// Allocate fragments and descriptors
Tensor tCrA_load = make_tensor<ElementA>(
tCsA(_, _, _, Int<0>{}).shape()); // (MMA,MMA_N,MMA_K)
Tensor tCrA_mma = make_fragment_like<ElementMma>(tCrA_load);
Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)
static constexpr int A_CPY_VEC =
decltype(max_common_vector(tCsA, tCrA_load)){};
static constexpr int COVERSION_WIDTH =
std::min(A_CPY_VEC, int(size<0>(tCrA_mma)));
auto load_A_to_registers = [&](int read_stage) {
copy(create_auto_vectorizing_copy<ElementA, decltype(A_CPY_VEC)>(),
tCsA(_, _, _, read_stage), tCrA_load(_, _, _));
};
// Partition of thread -> shared and thread -> RF
auto partitioned_extra_info =
partition_extra_mma_info(thread_mma, shared_tensors);
auto copy_partitions_extra_info = retile_extra_mma_info(
tiled_mma, partitioned_extra_info, warp_group_thread_idx);
CUTE_STATIC_ASSERT_V(size<1>(tCrA_mma) == size<1>(accum)); // MMA_M
CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N
CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K
CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE
CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB)); // PIPE
//
// PIPELINED MAIN LOOP
//
auto convert_A = [&, a_vec = Int<COVERSION_WIDTH>{}](int k_block,
int read_stage) {
load_extra_info_to_registers(partitioned_extra_info,
copy_partitions_extra_info, k_block,
read_stage);
transform_A_kblock(tCrA_load, a_vec, tCrA_mma, partitioned_extra_info,
k_block);
};
// We release buffers to producer warps(dma load) with some mmas in flight
PipelineState smem_pipe_release = smem_pipe_read;
tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
warpgroup_fence_operand(accum);
constexpr int K_BLOCK_MAX = size<2>(tCrA_load);
ConsumerToken barrier_token = {BarrierStatus::WaitAgain};
// first k tile
{
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
pipeline.consumer_wait(smem_pipe_read, barrier_token);
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
// copy smem->rmem for A operand
load_A_to_registers(read_stage);
convert_A(0, read_stage);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
if (k_block < K_BLOCK_MAX - 1) {
convert_A(k_block + 1, smem_pipe_read.index());
}
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
}
--k_tile_count;
if (k_tile_count > 0) {
// Wait for K_BLOCK_MAX - 1 to be in flight to ensure that it is safe to
// overwrite the A registers for the first mma.
warpgroup_wait<K_BLOCK_MAX - 1>();
pipeline.consumer_wait(smem_pipe_read, barrier_token);
load_A_to_registers(smem_pipe_read.index());
convert_A(0, smem_pipe_read.index());
}
}
if (k_tile_count == 0) {
return;
}
warpgroup_fence_operand(accum);
// Mainloop GMMAs
CUTLASS_PRAGMA_NO_UNROLL
for (; k_tile_count > 1; --k_tile_count) {
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
++smem_pipe_read;
warpgroup_fence_operand(accum);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>();
if (k_block == K_BLOCK_MAX - 1) {
// We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage,
// so we can release prior barrier
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
// on it
++smem_pipe_release;
}
if (k_block == 0) {
barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
}
if (k_block == K_BLOCK_MAX - 1) {
pipeline.consumer_wait(smem_pipe_read, barrier_token);
load_A_to_registers(smem_pipe_read.index());
convert_A(0, smem_pipe_read.index());
} else {
convert_A(k_block + 1, read_stage);
}
}
warpgroup_fence_operand(accum);
}
warpgroup_fence_operand(accum);
{
//
// Compute on k_tile
//
int read_stage = smem_pipe_read.index();
warpgroup_fence_operand(accum);
// Unroll the K mode manually to set scale D to 1
CUTLASS_PRAGMA_UNROLL
for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) {
warpgroup_arrive();
// (V,M) x (V,N) => (V,M,N)
cute::gemm(tiled_mma, tCrA_mma(_, _, k_block),
tCrB(_, _, k_block, read_stage), accum);
tiled_mma.accumulate_ = GMMA::ScaleOut::One;
warpgroup_commit_batch();
warpgroup_wait<K_BLOCK_MAX - 1>();
if (k_block == K_BLOCK_MAX - 1) {
// release prior barrier
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_
// on it
++smem_pipe_release;
}
if (k_block < K_BLOCK_MAX - 1) {
convert_A(k_block + 1, read_stage);
}
}
}
warpgroup_fence_operand(accum);
}
// Perform a Consumer Epilogue to release all buffers
CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline,
PipelineState smem_pipe_release,
int k_tile_count) {
// Prologue GMMAs
int prologue_mma_count = 1;
k_tile_count -= prologue_mma_count;
smem_pipe_release.advance(k_tile_count);
// Wait on all GMMAs to complete
warpgroup_wait<0>();
for (int count = 0; count < prologue_mma_count; ++count) {
pipeline.consumer_release(
smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on
// it
++smem_pipe_release;
}
}
private:
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Utilities for any additional inputs inside of the TMA load
template <class... Ts>
CUTLASS_DEVICE
auto partition_extra_tma_inputs(
Params const& mainloop_params,
cute::tuple<Ts...> const& load_inputs,
TensorStorage& shared_tensors,
uint2 const& cluster_local_block_id,
int const m_coord,
int const l_coord) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gS_mkl = get<2>(load_inputs);
auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y);
Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k)
Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tSgS, tSsS);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE)
Tensor gZ_mkl = get<3>(load_inputs);
auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y);
Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k)
Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k)
Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE)
return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled for input partitioning.");
}
}
// clang-format off
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Utilities for partitioning extra inputs for loading from smem in the mainloop.
template <class ThreadMma>
CUTLASS_DEVICE
auto partition_extra_mma_info(
ThreadMma const& mma_thread_slice,
TensorStorage& shared_tensors) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsS = mma_thread_slice.partition_A(sS);
Tensor tCrS = make_tensor<ElementScale>(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).shape());
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(tCsS, tCrS);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE)
Tensor tCsZ = mma_thread_slice.partition_A(sZ);
Tensor tCrZ = make_tensor<ElementZero>(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).shape());
return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
// clang-format on
// Same as upstream, should be kept the same when possible, not formatted for
// easier comparison
// clang-format off
/// Returns the tiled copy and copy views for the extra inputs.
template <class TiledMma, class... Ts>
CUTLASS_DEVICE
auto retile_extra_mma_info(
TiledMma const& tiled_mma,
cute::tuple<Ts...>& partitioned_extra_info,
int const warp_group_thread_idx) {
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
return cute::make_tuple();
}
else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma);
auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx);
Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view);
}
else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) {
Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K)
return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view);
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>, "Conversion mode not handled in A -> RF path.");
}
}
// clang-format on
// Similar to `copy_A_and_extra_info` upstream, should be kept the same when
// possible
// the main differences this only loads the extra info into registers and
// not A (since we now preload more of A in the main pipeline)
// Load scales and zeros into registers if required
template <class... Ts, class... Us>
CUTLASS_DEVICE void load_extra_info_to_registers(
cute::tuple<Ts...> const& partitioned_mma_extra_info,
cute::tuple<Us...> const& tiled_copy_and_views, int k_block,
int read_stage) {
if (k_block == 0) {
// We are starting a new k-tile so copy the scale
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
// nothing to do
} else if constexpr (ModeHasScales) {
auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views);
auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views);
auto tCsS = cute::get<0>(partitioned_mma_extra_info);
copy(smem_tiled_copy_S, tCsS(_, _, k_block, read_stage),
tCrS_copy_view(_, _, k_block));
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) {
// Nothing extra to do
} else if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
auto tCsZ = cute::get<2>(partitioned_mma_extra_info);
auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views);
copy(smem_tiled_copy_S, tCsZ(_, _, k_block, read_stage),
tCrZ_copy_view(_, _, k_block));
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in A -> RF path.");
}
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"Conversion mode not handled in A -> RF path.");
}
}
}
// Similar to upstream, should be kept the same when possible.
// the main differences are that `convert_tensor` supports interleaved
// layouts and bfloat16 has been optimized. `transform_internal_A` has also
// been inlined for code simplicity.
// Utilities to transform A.
template <class TCrA_load, int VectorWidthA, class TCrA_mma, class... Ts>
CUTLASS_DEVICE void transform_A_kblock(
TCrA_load const& tCrA_load, cute::Int<VectorWidthA> vec_A,
TCrA_mma& tCrA_mma, cute::tuple<Ts...> const& partitioned_extra_info,
int const k_block) {
auto in = tCrA_load(_, _, k_block);
auto out = tCrA_mma(_, _, k_block);
if constexpr (KernelConversionMode == ConversionMode::DirectConvert) {
convert_tensor<IlvdBlkLayout>(in, out, vec_A);
} else if constexpr (ModeHasScales) {
auto tCrS = cute::get<1>(partitioned_extra_info);
auto converted_inputs =
make_fragment_like<ElementScale>(tCrA_mma)(_, _, k_block);
auto scales = tCrS(_, _, 0);
// First, we upcast the inputs to the scale type
convert_tensor<IlvdBlkLayout>(in, converted_inputs, vec_A);
// Apply scales and broadcast across inputs, store in converted_inputs
// We need to cast to nv_bfloat16 for the multiply since
// `cutlass::bfloat16_t` has an overloaded operator* that upconverts to
// float, which nvcc will not optimize to using vectorized fma
// instructions (i.e. hfma.bf16_v2)
if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
cute::transform(
recast<nv_bfloat16>(converted_inputs), recast<nv_bfloat16>(scales),
recast<nv_bfloat16>(converted_inputs), cute::multiplies{});
} else {
cute::transform(converted_inputs, scales, converted_inputs,
cute::multiplies{});
}
// Apply zeros if required
if constexpr (KernelConversionMode ==
ConversionMode::ConvertAndScaleWithZero) {
auto tCrZ = cute::get<3>(partitioned_extra_info);
auto converted_zeros = make_fragment_like<ElementScale>(tCrZ)(_, _, 0);
convert_tensor<void>(tCrZ(_, _, 0), converted_zeros);
if constexpr (std::is_same_v<ElementScale, cutlass::bfloat16_t>) {
cute::transform(recast<nv_bfloat16>(converted_inputs),
recast<nv_bfloat16>(converted_zeros),
recast<nv_bfloat16>(converted_inputs), cute::plus{});
} else {
cute::transform(converted_inputs, converted_zeros, converted_inputs,
cute::plus{});
}
}
// Finally, we convert the scaled inputs to the mma type.
convert_tensor<void>(converted_inputs, out);
} else {
static_assert(cutlass::detail::dependent_false<KernelSchedule>,
"No A data is loaded.");
}
}
// Modified from upstream, should be kept the same when possible
// the main differences is that this version supports interleaved converts
// Utilities for transforming the A operand prior to issuing tensorcore math.
template <typename IlvdBlkLayout, class EngineIn, class EngineOut,
class TensorLayout,
int ConversionVectorWidth = cosize_v<TensorLayout>>
CUTLASS_DEVICE void convert_tensor(
Tensor<EngineIn, TensorLayout> const& in,
Tensor<EngineOut, TensorLayout>& out,
cute::Int<ConversionVectorWidth> width = {}) {
// This is an element-wise conversion where we expect both tensors to have
// the same layout. As a result, we can cast as a cutlass array to use the
// fast numeric converters without worrying about indexing into the layout.
constexpr int N = cosize_v<TensorLayout>;
// The inputs must be backed by registers & be statically sized.
static_assert(is_rmem<EngineIn>::value,
"Input tensor for A conversion must come from registers");
static_assert(is_rmem<EngineOut>::value,
"Output tensor for A conversion must come from registers");
static_assert(is_static_v<TensorLayout>,
"Tensor layout for the conversion must be static");
static_assert(cosize_v<TensorLayout> == size(TensorLayout{}),
"Cosize and size of the layout must be equal.");
static_assert(
N % ConversionVectorWidth == 0,
"Conversion vector width must divide cosize of the tensor layout.");
using SrcType = typename EngineIn::value_type;
using DstType = typename EngineOut::value_type;
using SrcArray = cutlass::Array<SrcType, ConversionVectorWidth>;
using DstArray = cutlass::Array<DstType, ConversionVectorWidth>;
constexpr cutlass::FloatRoundStyle RoundStyle =
cutlass::FloatRoundStyle::round_to_nearest;
using Converter = cutlass::InterleavedNumericArrayConverter<
IlvdBlkLayout, DstType, SrcType, ConversionVectorWidth, RoundStyle>;
constexpr int NumIterations = N / ConversionVectorWidth;
for (int ii = 0; ii < NumIterations; ++ii) {
SrcArray const* src_array_ptr =
reinterpret_cast<SrcArray const*>(raw_pointer_cast(in.data())) + ii;
DstArray* dst_array_ptr =
reinterpret_cast<DstArray*>(raw_pointer_cast(out.data())) + ii;
*dst_array_ptr = Converter::convert(*src_array_ptr);
}
}
};
} // namespace machete
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/vllm_numeric_conversion.cuh"
#include "machete_collective_builder.cuh"
#include "machete_prepacked_layout.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
// NOTE This kernel computes D = alpha * A * B + beta * C by computing
// D^t = alpha * B^t * A^t + beta * C^t, this is because the wgmma
// instructions only support sourcing from registers for the left-hand
// operand, we want to upconvert/decompress the quantized operand in
// register. Since the primary use case we want to support is Y = XW^t where
// W is quantized, in this situation or right-hand operand is quantized so
// we compute the transpose to move it to the left-hand side.
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, typename ScaleT, typename ZeroT,
class KernelSchedule, typename ScheduleConfig, bool with_C,
bool with_scales, bool with_zeropoints>
struct MacheteKernelTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementC = cute::conditional_t<with_C, ElementD, void>;
using ElementZ = ZeroT;
using ElementS = ScaleT;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementCompute = AccumulatorT; // For Epilogue
using BTypeTuple = cute::conditional_t<
with_scales,
cute::conditional_t<with_zeropoints,
cute::tuple<ElementB, ElementS, ElementZ>,
cute::tuple<ElementB, ElementS>>,
ElementB>;
using LayoutA = cutlass::layout::RowMajor;
using LayoutC = cutlass::layout::RowMajor;
using LayoutD = LayoutC;
using LayoutScale = cutlass::layout::RowMajor;
// not actually used since B has the prepacked layout, but required by cutlass
using _LayoutB = cutlass::layout::ColumnMajor;
// Interface strides expected by create_arguments (will get transposed)
using StrideA = cutlass::detail::TagToStrideA_t<LayoutA>;
using StrideC = cutlass::detail::TagToStrideA_t<LayoutC>;
using StrideD = cutlass::detail::TagToStrideA_t<LayoutD>;
using StrideS = cutlass::detail::TagToStrideA_t<LayoutScale>;
using StrideZ = StrideS;
using LayoutA_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutA>::type;
using LayoutC_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutC>::type;
using LayoutD_Transpose =
typename cutlass::layout::LayoutTranspose<LayoutD>::type;
using ArchTag = cutlass::arch::Sm90;
using OperatorClass = cutlass::arch::OpClassTensorOp;
using PrepackedLayoutB =
PrepackedLayoutBTemplate<ElementA_, ElementB_, ElementD_, AccumulatorT,
LayoutA_Transpose, KernelSchedule>;
static int constexpr TileShapeK =
128 * 8 / cutlass::sizeof_bits<MmaType>::value;
static int constexpr AlignmentA = 128 / cutlass::sizeof_bits_v<ElementA>;
static int constexpr AlignmentB = 128 / cutlass::sizeof_bits_v<ElementB>;
static int constexpr AlignmentC =
(with_C) ? 128 / cutlass::sizeof_bits_v<ElementC> : 0;
static int constexpr AlignmentD = 128 / cutlass::sizeof_bits_v<ElementD>;
using TileShape = decltype(append(typename ScheduleConfig::TileShapeNM{},
cute::Int<TileShapeK>{}));
using ClusterShape = typename ScheduleConfig::ClusterShape;
using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule;
using EpilogueTileType = typename ScheduleConfig::EpilogueTileType;
using TileScheduler = typename ScheduleConfig::TileScheduler;
using CollectiveEpilogue =
typename cutlass::epilogue::collective::CollectiveBuilder<
ArchTag, OperatorClass, TileShape, ClusterShape, EpilogueTileType,
ElementAccumulator, ElementAccumulator, ElementC, LayoutC_Transpose,
AlignmentC, ElementD, LayoutD_Transpose, AlignmentD,
EpilogueSchedule>::CollectiveOp;
using CollectiveMainloop =
typename cutlass::gemm::collective::VLLMCollectiveBuilder<
cutlass::gemm::collective::MacheteKernelTag, ArchTag, OperatorClass,
BTypeTuple, PrepackedLayoutB, AlignmentB, ElementA, LayoutA_Transpose,
AlignmentA, ElementAccumulator, TileShape, ClusterShape,
cutlass::gemm::collective::StageCountAutoCarveout<static_cast<int>(
sizeof(typename CollectiveEpilogue::SharedStorage))>,
KernelSchedule>::CollectiveOp;
using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
Shape<int, int, int, int>, // Indicates ProblemShape
CollectiveMainloop, CollectiveEpilogue, TileScheduler>;
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
// stride_B is unused (since B is prepacked), but still required by cutlass
using _StrideB = cutlass::detail::TagToStrideB_t<_LayoutB>;
using Arguments = typename Gemm::Arguments;
using MainloopArguments = typename GemmKernel::MainloopArguments;
using EpilogueArguments = typename GemmKernel::EpilogueArguments;
template <typename ShapeA, typename ShapeC, typename ShapeD, typename ShapeS,
typename ShapeZ>
static Arguments create_arguments(
cudaStream_t stream,
ElementA const* A_ptr, // A is an MxK matrix
Layout<ShapeA, StrideA> const& layout_A,
ElementB const* B_ptr, // B is an KxN prepacked matrix
ElementD* D_ptr, // D is an MxN matrix
Layout<ShapeD, StrideD> const& layout_D,
ElementC const* C_ptr, // C is an MxN matrix
std::optional<Layout<ShapeC, StrideC>> const& layout_C,
ElementS const* S_ptr, // S is an scale_KxN matrix
std::optional<Layout<ShapeS, StrideS>> const& layout_S,
ElementZ const* Z_ptr, // Z is an scale_KxN matrix
std::optional<Layout<ShapeZ, StrideZ>> const& layout_Z,
ElementCompute alpha, ElementCompute beta,
std::optional<int> maybe_group_size) {
static_assert(!with_zeropoints || with_scales);
int M = size<0>(layout_A), N = size<1>(layout_D), K = size<1>(layout_A);
int const group_size =
maybe_group_size == -1 ? K : maybe_group_size.value_or(K);
int const scale_k = (K + group_size - 1) / group_size;
TORCH_CHECK(size<0>(layout_A) == M && size<1>(layout_A) == K);
TORCH_CHECK(size<0>(layout_D) == M && size<1>(layout_D) == N);
if constexpr (with_C) {
TORCH_CHECK(C_ptr && layout_C);
} else {
TORCH_CHECK(!C_ptr, "C not supported");
}
if constexpr (with_scales) {
TORCH_CHECK(S_ptr && layout_S);
TORCH_CHECK((size<0>(*layout_S) == scale_k && size<1>(*layout_S) == N));
} else {
TORCH_CHECK(!S_ptr, "Scales not supported");
}
if constexpr (with_zeropoints) {
TORCH_CHECK(Z_ptr && layout_Z);
TORCH_CHECK((size<0>(*layout_Z) == scale_k && size<1>(*layout_Z) == N));
TORCH_CHECK(layout_S && *layout_Z == *layout_S,
"Scales and zeros must have the same layout");
} else {
TORCH_CHECK(!Z_ptr, "Zeropoints not supported");
}
// Transpose A and D
// A doesn't need to be transposed since cutlass expects a NxK matrix
// for B (which is At)
auto stride_At = layout_A.stride();
auto stride_Dt = permute_layout<1, 0, 2>(layout_D).stride();
auto stride_Ct = stride_Dt;
if (layout_C) {
stride_Ct = permute_layout<1, 0, 2>(*layout_C).stride();
}
MainloopArguments mainloop_arguments{};
EpilogueArguments epilogue_arguments{
{alpha, beta}, C_ptr, stride_Ct, D_ptr, stride_Dt};
if constexpr (with_scales && with_zeropoints) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At,
S_ptr, stride_S, group_size, Z_ptr};
} else if constexpr (with_scales) {
auto stride_S = permute_layout<1, 0, 2>(*layout_S).stride();
mainloop_arguments = MainloopArguments{
B_ptr, _StrideB{}, A_ptr, stride_At, S_ptr, stride_S, group_size};
} else {
mainloop_arguments =
MainloopArguments{B_ptr, _StrideB{}, A_ptr, stride_At};
}
return Arguments{cutlass::gemm::GemmUniversalMode::kGemm,
{N, M, K, 1},
mainloop_arguments,
epilogue_arguments};
};
static size_t get_workspace_size(Arguments const& args) {
return Gemm::get_workspace_size(args);
}
static bool can_implement(Arguments const& args) {
return Gemm::can_implement(args) == cutlass::Status::kSuccess;
}
static void run(Arguments const& args, void* workspace, cudaStream_t stream) {
Gemm gemm_op;
cutlass::Status status = gemm_op.initialize(args, workspace, stream);
TORCH_CHECK(status == cutlass::Status::kSuccess,
"Machete kernel failed to initialize workspace");
status = gemm_op.run(stream);
TORCH_CHECK(status == cutlass::Status::kSuccess, "Machete kernel failed");
}
};
}; // namespace machete
#pragma once
#include <torch/all.h>
#include <Python.h>
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
struct PyTorchArguments {
torch::Tensor const& A;
torch::Tensor const& B;
c10::optional<torch::Tensor> const& scales;
c10::optional<torch::Tensor> const& zeros;
c10::optional<int64_t> group_size;
c10::optional<torch::Tensor> const& C;
c10::optional<double> alpha;
c10::optional<double> beta;
c10::optional<std::string> schedule;
};
template <typename MacheteKernel>
torch::Tensor run_impl(PyTorchArguments args) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(args.A));
auto device = args.A.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
using EleA = typename MacheteKernel::ElementA;
using EleB = typename MacheteKernel::ElementB;
using EleC = typename MacheteKernel::ElementC;
using EleD = typename MacheteKernel::ElementD;
using EleScale = typename MacheteKernel::ElementS;
using EleZero = typename MacheteKernel::ElementZ;
using StrideA = typename MacheteKernel::StrideA;
using StrideC = typename MacheteKernel::StrideC;
using StrideD = typename MacheteKernel::StrideD;
using StrideS = typename MacheteKernel::StrideS;
using StrideZ = typename MacheteKernel::StrideZ;
int M = args.A.size(0);
int N = args.B.size(1);
int K = args.A.size(1);
// Allocate output
torch::Tensor D =
torch::empty({M, N}, torch::TensorOptions()
.dtype(equivalent_scalar_type_v<EleD>)
.device(device));
auto const &A = args.A, &B = args.B;
auto const &C = args.C, &scales = args.scales, &zeros = args.zeros;
auto layout_A = make_cute_layout<StrideA>(A, "A");
auto layout_D = make_cute_layout<StrideD>(D, "D");
auto layout_C = maybe_make_cute_layout<StrideC>(C, "C");
auto layout_S = maybe_make_cute_layout<StrideS>(scales, "scales");
auto layout_Z = maybe_make_cute_layout<StrideZ>(zeros, "zeros");
auto A_ptr = static_cast<EleA const*>(A.const_data_ptr());
auto B_ptr = static_cast<EleB const*>(B.const_data_ptr());
auto D_ptr = static_cast<EleD*>(D.mutable_data_ptr());
auto C_ptr = static_cast<EleC const*>(C ? C->const_data_ptr() : nullptr);
auto S_ptr =
static_cast<EleScale const*>(scales ? scales->const_data_ptr() : nullptr);
auto Z_ptr =
static_cast<EleZero const*>(zeros ? zeros->const_data_ptr() : nullptr);
auto arguments = MacheteKernel::create_arguments(
stream, A_ptr, layout_A, B_ptr, D_ptr, layout_D, C_ptr, layout_C, S_ptr,
layout_S, Z_ptr, layout_Z, args.alpha.value_or(1), args.beta.value_or(0),
args.group_size);
TORCH_CHECK(MacheteKernel::can_implement(arguments),
"Machete kernel cannot be run with these arguments");
size_t workspace_size = MacheteKernel::get_workspace_size(arguments);
torch::Tensor workspace = torch::empty(
workspace_size, torch::TensorOptions().dtype(torch::kU8).device(device));
MacheteKernel::run(arguments, workspace.mutable_data_ptr(), stream);
return D;
};
template <typename ElementA, typename ElementB, typename ElementD = ElementA,
typename AccumulatorT = float, typename ScaleT = ElementA,
typename ZeroT = ElementA>
struct GemmDispatcher {
static torch::Tensor dispatch(PyTorchArguments args);
static std::vector<std::string> supported_schedules();
};
}; // namespace machete
\ No newline at end of file
#pragma once
#include "machete_mm_kernel.cuh"
#include "cutlass_extensions/cute_utils.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename TileShapeNKL, typename ElementB, typename BInTensor,
typename BTiledOutTensor>
static __global__ void prepack_B_kernel(BInTensor B_in,
BTiledOutTensor B_tiled_out) {
auto tB_in = local_tile(B_in, TileShapeNKL{},
make_coord(blockIdx.x, blockIdx.y, blockIdx.z));
auto tB_out = B_tiled_out(make_coord(_, _),
make_coord(blockIdx.x, blockIdx.y), blockIdx.z);
auto tiled_copy = make_tiled_copy(Copy_Atom<DefaultCopy, ElementB>{},
Layout<Shape<_4, _32>, Stride<_32, _1>>{},
Layout<Shape<_1, _2>>{});
auto thr_copy = tiled_copy.get_thread_slice(threadIdx.x);
Tensor thr_tile_S = thr_copy.partition_S(tB_in);
Tensor thr_tile_D = thr_copy.partition_D(tB_out);
// Construct a register-backed Tensor with the same shape as each thread's
// partition
auto fragment = make_tensor<ElementB>(shape(thr_tile_D));
// Copy from GMEM to RMEM and from RMEM to GMEM
copy(tiled_copy, thr_tile_S, fragment);
copy(Copy_Atom<DefaultCopy, uint8_t>{}, fragment, thr_tile_D);
}
template <typename PrepackedLayoutB, typename InLayout>
static void prepack_B_template(
cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr,
InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) {
using TileShapeNKL =
decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{}));
auto ilvd_NKbNbKL_to_offset =
PrepackedLayoutB::ilvd_NKbNbKL_to_offset(shape(B_layout));
TORCH_CHECK(size<0>(B_layout) % size<0>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<1>(B_layout) % size<1>(TileShapeNKL{}) == 0);
TORCH_CHECK(size<2>(B_layout) % size<2>(TileShapeNKL{}) == 0);
auto N_tiles = size<0>(B_layout) / size<0>(TileShapeNKL{});
auto K_tiles = size<1>(B_layout) / size<1>(TileShapeNKL{});
auto L_tiles = size<2>(B_layout) / size<2>(TileShapeNKL{});
auto B_in = make_tensor(get_logical_ptr(B_in_ptr), B_layout);
auto B_tiled_out =
make_tensor(get_logical_ptr(B_out_ptr), ilvd_NKbNbKL_to_offset);
prepack_B_kernel<TileShapeNKL, typename PrepackedLayoutB::ElementB>
<<<dim3(N_tiles, K_tiles, L_tiles), 128, 0, stream>>>(B_in, B_tiled_out);
}
}; // namespace machete
\ No newline at end of file
#pragma once
#include "machete_prepack_kernel.cuh"
#include "cutlass_extensions/torch_utils.hpp"
namespace machete {
template <typename PrepackedLayoutB>
torch::Tensor prepack_impl(torch::Tensor const B) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(B));
using ElementB = typename PrepackedLayoutB::ElementB;
using PPBlockShape_NK = typename PrepackedLayoutB::PPBlockShape_NK;
auto device = B.device();
auto stream = at::cuda::getCurrentCUDAStream(device.index());
auto B_ptr = static_cast<ElementB const*>(B.const_data_ptr());
// elements per storage item for B
auto eles_per_storage =
(B.dtype().itemsize() * 8) / cute::sizeof_bits_v<ElementB>;
// torch B passed in is/should be (packed_K,N), the kernel expects (N,K,L) (to
// match cutlass using (N,K,L) for B), so we transpose B to (N,packed_K,L)
auto Bt_packed = B.t();
TORCH_CHECK(
(B.size(0) * eles_per_storage) % size<1>(PPBlockShape_NK{}) == 0,
"B.shape[0] (in terms of unpacked elements) must be a multiple of ",
size<1>(PPBlockShape_NK{}));
TORCH_CHECK(B.size(1) % size<0>(PPBlockShape_NK{}) == 0,
"B.shape[1] must be a multiple of ", size<0>(PPBlockShape_NK{}));
using StrideB = cutlass::detail::TagToStrideB_t<cutlass::layout::ColumnMajor>;
auto const l_Bt_packed = make_cute_layout<StrideB>(Bt_packed, "B");
// convert (N,packed_K,L) layout to (N,K,L) layout
// in effect we want to do: blocked_product(layout_Bt_packed,
// make_ordered_layout(make_shape(_1{}, eles_per_storage, _1{}),
// Step<_1, _0, _2>{}));
// but blocked_product does not support dynamic strides so we implement the
// equivalent manually,
// new_shape = (N, packed_K, L) * (1, eles_per_storage, 1) -> (N, K, L)
// new_stride = (s0, s1, s2) * (eles_per_storage, 1, eles_per_storage)
// when s1 == 1
TORCH_CHECK(stride<1>(l_Bt_packed) == 1);
// clang-format off
auto const layout_Bt = make_layout(
transform_with_idx(l_Bt_packed.shape(), [&](auto ele, auto idx) {
return idx == 1 ? ele * eles_per_storage : ele;
}),
transform_with_idx(l_Bt_packed.stride(), [&](auto ele, auto idx) {
return idx != 1 ? ele * eles_per_storage : ele;
}));
// clang-format on
// Allocate output
torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous);
prepack_B_template<PrepackedLayoutB>(
stream, B_ptr, layout_Bt, static_cast<ElementB*>(D.mutable_data_ptr()));
return D;
};
template <typename ElementA, typename ElementB, typename ElementD,
typename AccumulatorT = float, typename ScaleT = cutlass::half_t,
typename ZeroT = cutlass::half_t>
struct PrepackBDispatcher {
static torch::Tensor dispatch(torch::Tensor B);
};
}; // namespace machete
\ No newline at end of file
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
// clang-format off
// The cutlass include order matters (annoyingly)
#include "cutlass/cutlass.h"
#include "cute/tensor.hpp"
#include "cutlass/tensor_ref.h"
#include "cutlass/epilogue/collective/default_epilogue.hpp"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
// clang-format on
#include "cutlass_extensions/cute_utils.cuh"
#include "machete_collective_builder.cuh"
#include "machete_interleaving_utils.cuh"
namespace machete {
using namespace cute;
struct IlvBlkLayoutAuto {};
// This defines a prepacked layout for the B matrix, where the matrix is broken
// up into PPBlockShape_NK blocks. The data within each block is then compactly
// stored in memory such that when performing a TiledMMA operation with the same
// shape as prepacked block, all the data for a given thread is contiguous in
// memory. This allows us to use wider shared memory loads when loading B from
// shared memory. The values within a thread are also potentially interlaeved
// inorder to allow for more efficient upconverting.
//
// The contract here is that the `TiledMma` determined below matches the one
// ultimately used in the kernel. (this is also why the other element types are
// required along with the kernel schedule)
template <typename ElementA_, typename ElementB_, typename ElementD_,
typename AccumulatorT, class LayoutB, class KernelSchedule,
typename IlvBlkLayout_ = IlvBlkLayoutAuto>
// clang-format on
struct PrepackedLayoutBTemplate {
using MmaType = ElementA_;
using ElementA = ElementA_;
using ElementB = ElementB_;
using ElementD = ElementD_;
using ElementAccumulator =
AccumulatorT; // Element type for internal accumulation
using ElementMma = MmaType;
// Only use interleaved layouts for subbyte weights, prmt instructions makes
// non-interleaved layouts for 8bit+ weights efficient enough we don't need
// iterleaved layouts
using IlvdBlkLayout = std::conditional_t<
std::is_same_v<IlvBlkLayout_, IlvBlkLayoutAuto>,
std::conditional_t<sizeof_bits_v<ElementB> <= 4,
decltype(get_interleaved_blk_layout<
ElementB, sizeof_bits_v<ElementA>, 32>()),
void>,
IlvBlkLayout_>;
// TODO (LucasWilkinson): compare the performance for other sizes
// Prepacked block shape, smallest layout atom for loading into registers
// (can contain multiple wgmma instructions worth of data in one block)
// We ideally want this to be configured such that a thread can perform 128bit
// loads, i.e. we amount of data associated with each thread within a
// prepacked block is a multiple of 128bits, when using a cooperative sechdule
// we have 256 threads working a single block at a time, this means each
// thread works on `sizeof_bits_v<ElementB> * (128*64) / 256` bits of data,
// for a 4bit type this would be 128bits
using PPBlockShape_NK = Shape<_128, _64>;
// Create the shape of the tile anticipated to be used by the GEMM kernel,
// when the kernel executes we will compute `Ct = Bt * At` since the
// quantized weights (B), must be the lhs operand so the flow through
// registers.
// The _128 here doesn't actually impact the shape of the stored tile directly
// but may impact the op selected by rs_op_selector
using GemmTileShape = decltype(make_shape(size<0>(PPBlockShape_NK{}), _128{},
size<1>(PPBlockShape_NK{})));
static constexpr cute::GMMA::Major GmmaMajorB =
gmma_rs_tag_to_major_B<LayoutB>();
// For coop schedules we have two warp groups cooperatively issuing wgmma
// instructions so we use 2 atoms along the M dim (one for each warpgroup)
using AtomLayoutMNK = cute::conditional_t<
cute::is_same_v<KernelSchedule,
KernelTmaWarpSpecializedCooperativeMixedInput>,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::rs_op_selector<ElementMma, ElementMma, ElementAccumulator,
GemmTileShape, GMMA::Major::K, GmmaMajorB>(),
AtomLayoutMNK{}));
// Prepacked block, (athrid, val) -> (N,K)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (N,K)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_NK() {
return TiledMma{}.thrfrg_A(make_layout(PPBlockShape_NK{}));
}
// Prepacked block, (N,K) -> (athrid, val)
// i.e. (N,K) -> ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...)))
CUTE_HOST_DEVICE static constexpr auto ppblock_NK_to_TV() {
return right_inverse(ppblock_TV_to_NK()).with_shape(PPBlockShape_NK{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_TV_to_offset() {
// Return iterleaved layout
return make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
}
// Prepacked block, (athrid, val) -> (storage_offset)
// i.e. ((ThrV,(ThrM,ThrK)),(IlvdFrgV,(RestM,RestK,...))) -> (storage_idx)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_TV_to_offset() {
auto layout_no_interleave =
make_ordered_layout(shape(ppblock_TV_to_NK()), Step<_1, _0>{});
if constexpr (std::is_same_v<IlvdBlkLayout, void>) {
return layout_no_interleave;
} else {
// interleave by transforming FrgV into interleaved blocks where each
// block has the layout IlvdBlkLayout, for example if IlvdBlkLayout is
// (2, 2) : (2, 1) then we get: ((2, 2), size(FrgV) / 4) : ((2, 1), 4)
// if FrgV is {A, B, C, D, E, F, G, H}
// then ((IlvBlk), FrgB) is {A, C, B, D, C, G, D, H}
auto frgV = get<1, 0>(layout_no_interleave);
auto ilvdBlk = IlvdBlkLayout{};
static_assert(size(frgV) % 4 == 0, "FrgV must be divisible by 4");
auto ilvd_FrgV = make_layout(
make_shape(shape(ilvdBlk), Int<size(frgV) / size(ilvdBlk)>{}),
make_stride(stride(ilvdBlk), size(ilvdBlk)));
// Return iterleaved layout
return make_layout(
get<0>(layout_no_interleave),
make_layout(ilvd_FrgV, get<1, 1>(layout_no_interleave)));
}
}
// Prepacked block, (M,K) -> (storage_offset)
CUTE_HOST_DEVICE static constexpr auto ppblock_ilvd_NK_to_offset() {
// do (M,K) -> (athrid, val) -> (storage_idx)
return ppblock_ilvd_TV_to_offset().compose(ppblock_NK_to_TV());
}
// ((athrid, val), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto TVbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_TV_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L))
// => ((athrid, val), (BlocksN, BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((BlockN, BlockK), (BlocksN, BlocksK), L) -> (storage_idx)
template <typename Shape_NKL>
CUTE_HOST_DEVICE static constexpr auto ilvd_NKbNbKL_to_offset(
Shape_NKL shape_mkl) {
constexpr auto block_layout = ppblock_ilvd_NK_to_offset();
// (BlocksN, BlocksK, L)
auto blocks_shape =
cute::transform(shape_mkl, append(PPBlockShape_NK{}, _1{}),
[](auto x, auto y) { return x / y; });
// ((athrid, val), (BlocksN, BlocksK, L)) -> (storage_idx)
auto result = make_layout(
block_layout,
make_layout(blocks_shape,
compact_col_major(blocks_shape, size(block_layout))));
// ((athrid, val), (BlocksN, BlocksK, L)) => ((athrid, val), (BlocksN,
// BlocksK), L)
return group<1, 3>(result(_, repeat<rank<1>(result)>(_)));
}
// ((athrid, val), (BlocksN, BlocksK, L)) -> (N, K, L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto TVbNbK_to_NKL(Shape_NKL shape_mkl) {
auto tile = make_tile(make_layout(size<0>(PPBlockShape_NK{})),
make_layout(size<1>(PPBlockShape_NK{})));
// ((BlockN, BlockK), (BlocksN, BlocksK, L)) -> (N, K, L)
auto tiled_A = zipped_divide(make_layout(shape_mkl), tile);
return tiled_A.compose(ppblock_TV_to_NK(), _);
}
// (N, K, L) -> ((athrid, val), (BlocksN, BlocksK), L)
template <class Shape_NKL>
CUTE_HOST_DEVICE static auto NKL_to_TVbNbK(Shape_NKL shape_mkl) {
auto TVbNbK_to_NKL_layout = TVbNbK_to_NKL(shape_mkl);
return blocked_product(ppblock_NK_to_TV(),
make_layout(shape<1>(TVbNbK_to_NKL_layout)));
}
};
}; // namespace machete
\ No newline at end of file
#include "machete_mm_launcher.cuh"
#include "machete_prepack_launcher.cuh"
#include "core/scalar_type.hpp"
#include "core/registration.h"
namespace machete {
using namespace vllm;
//
// Utils (type dispatching)
//
template <typename Fn>
static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
if (type == vllm::kU4) {
return fn(cutlass::uint4b_t{});
} else if (type == vllm::kU8) {
return fn(cutlass::uint8_t{});
} else if (type == vllm::kU4B8) {
return fn(cutlass::vllm_uint4b8_t{});
} else if (type == vllm::kU8B128) {
return fn(cutlass::vllm_uint8b128_t{});
} else {
TORCH_CHECK(false, "Unsupported type ", type.str());
}
}
#define AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(...) \
AT_DISPATCH_CASE_REDUCED_FLOATING_TYPES(__VA_ARGS__)
#define AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, \
AT_DISPATCH_CASE_SUPPORTED_COMPUTE_TYPES(__VA_ARGS__))
//
// Interface
//
std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
ScalarTypeTorchPtr const& btype,
c10::optional<torch::Tensor> const& scales,
c10::optional<torch::Tensor> const& zeros,
c10::optional<int64_t> group_size,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
.zeros = zeros,
.group_size = group_size,
.C = C,
.alpha = alpha,
.beta = beta,
.schedule = schedule};
return scalar_type_dispatch(*btype, [&](auto BType) {
return AT_DISPATCH_SUPPORTED_COMPUTE_TYPES(
A.scalar_type(), "machete_gemm", [&] {
using ComputeType = equivalent_cutlass_type_t<scalar_t>;
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}
torch::Tensor prepack_B(torch::Tensor const& B,
vllm::ScalarTypeTorchPtr const& btype) {
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("machete_prepack_B", &prepack_B);
m.impl("machete_gemm", &gemm);
}
// use CatchAll since supported_schedules has no tensor arguments
TORCH_LIBRARY_IMPL(TORCH_EXTENSION_NAME, CatchAll, m) {
m.impl("machete_supported_schedules", &supported_schedules);
}
}; // namespace machete
......@@ -26,6 +26,7 @@
#include <iostream>
#include "common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "common/mem.h"
......@@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_gemm", &marlin_gemm);
}
......@@ -30,6 +30,7 @@
#include <iostream>
#include "../dense/common/base.h"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#include "../dense/common/mem.h"
......@@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a,
return d;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("marlin_qqq_gemm", &marlin_qqq_gemm);
}
......@@ -28,6 +28,7 @@
#include "common/base.h"
#include "core/scalar_type.hpp"
#include "core/registration.h"
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
......@@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
return c;
}
TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) {
m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm);
}
#include <torch/all.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
// half-tensor
#include <c10/cuda/CUDAStream.h>
#include <ATen/cuda/CUDATensorMethods.cuh>
#include <c10/cuda/CUDAGuard.h>
#define BLOCKWIDTH 128
#define BLOCKHEIGHT4 16
namespace vllm {
namespace squeezellm {
__device__ inline unsigned int as_unsigned(int i) {
return *reinterpret_cast<unsigned int*>(&i);
}
// 4-bit matvec kernel (LUT-based)
__global__ void NUQ4MatMulKernel(
#ifndef USE_ROCM
const half2* __restrict__ vec,
#else
const __half2* __restrict__ vec,
#endif
const int* __restrict__ mat,
#ifndef USE_ROCM
half2* __restrict__ mul,
#else
float2* __restrict__ mul,
#endif
const __half* __restrict__ lookup_table, int height, int width, int batch,
int vec_height) {
const int blockwidth2 = BLOCKWIDTH / 2;
int row = BLOCKHEIGHT4 * blockIdx.x;
int col = BLOCKWIDTH * blockIdx.y + threadIdx.x;
#ifndef USE_ROCM
__shared__ half2 blockvec[blockwidth2];
#else
__shared__ __half2 blockvec[blockwidth2];
#endif
__shared__ __half deq2[16][BLOCKWIDTH];
int off = threadIdx.x;
int column_offset = col * 16;
for (int val = 0; val < 16; val += 1) {
int lut_index = column_offset + val;
deq2[val][off] = lookup_table[lut_index];
}
__half res;
#ifndef USE_ROCM
half2 res2;
half2 tmp2;
#else
__half2 res2;
__half2 tmp2;
#endif
int i;
int k;
unsigned int tmp1;
unsigned int lut_index1, lut_index2;
for (int b = 0; b < batch; ++b) {
i = width * row + col;
res = __int2half_rd(0);
k = 0;
__syncthreads();
if (threadIdx.x < blockwidth2)
blockvec[threadIdx.x] =
vec[b * vec_height / 2 + (row / BLOCKHEIGHT4) * blockwidth2 +
threadIdx.x];
__syncthreads();
while (k < blockwidth2) {
tmp1 = as_unsigned(mat[i]);
#ifndef USE_ROCM
res2 = {};
tmp2 = {};
#else
res2.x = __half_as_ushort(__float2half(0));
res2.y = __half_as_ushort(__float2half(0));
tmp2.x = __half_as_ushort(__float2half(0));
tmp2.y = __half_as_ushort(__float2half(0));
#endif
lut_index1 = tmp1 & 0xF;
lut_index2 = (tmp1 >> 4) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 0], res2);
lut_index1 = (tmp1 >> 8) & 0xF;
lut_index2 = (tmp1 >> 12) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 1], res2);
lut_index1 = (tmp1 >> 16) & 0xF;
lut_index2 = (tmp1 >> 20) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 2], res2);
lut_index1 = (tmp1 >> 24) & 0xF;
lut_index2 = (tmp1 >> 28) & 0xF;
#ifndef USE_ROCM
tmp2.x = deq2[lut_index1][off];
tmp2.y = deq2[lut_index2][off];
#else
tmp2.x = __half_as_ushort(deq2[lut_index1][off]);
tmp2.y = __half_as_ushort(deq2[lut_index2][off]);
#endif
res2 = __hfma2(tmp2, blockvec[k + 3], res2);
#ifndef USE_ROCM
res = __hadd(__hadd(res2.x, res2.y), res);
#else
res = __hadd(__hadd(__ushort_as_half(res2.x), __ushort_as_half(res2.y)),
res);
#endif
i += width;
k += 4;
}
// col%2 -> only set one of the two values
#ifndef USE_ROCM
half2 res3 = {};
if (col % 2 == 0) {
res3.x = res;
} else {
res3.y = res;
}
#else
__half2 res3;
res3.x = __half_as_ushort(__float2half(0));
res3.y = __half_as_ushort(__float2half(0));
if (col % 2 == 0) {
res3.x = __half_as_ushort(res);
} else {
res3.y = __half_as_ushort(res);
}
#endif
#ifndef USE_ROCM
atomicAdd(&mul[b * width / 2 + col / 2], res3);
#else
int tmp_addr = b * width / 2 + col / 2;
atomicAdd(&(mul[tmp_addr].x), __half2float(__ushort_as_half(res3.x)));
atomicAdd(&(mul[tmp_addr].y), __half2float(__ushort_as_half(res3.y)));
#endif
}
}
} // namespace squeezellm
} // namespace vllm
// 4-bit matvec kernel (LUT-based)
void squeezellm_gemm(torch::Tensor vec, torch::Tensor mat, torch::Tensor mul,
torch::Tensor lookup_table) {
int height = mat.size(0);
int width = mat.size(1);
int batch = vec.size(0);
int vec_height = vec.size(1);
dim3 blocks((height + BLOCKHEIGHT4 - 1) / BLOCKHEIGHT4,
(width + BLOCKWIDTH - 1) / BLOCKWIDTH);
dim3 threads(BLOCKWIDTH);
const at::cuda::OptionalCUDAGuard device_guard(device_of(vec));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
vllm::squeezellm::NUQ4MatMulKernel<<<blocks, threads, 0, stream>>>(
#ifndef USE_ROCM
(half2*)vec.data_ptr<at::Half>(),
#else
(__half2*)vec.data_ptr<at::Half>(),
#endif
mat.data_ptr<int>(),
#ifndef USE_ROCM
(half2*)mul.data_ptr<at::Half>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#else
(float2*)mul.data_ptr<float>(),
(__half*)lookup_table.data_ptr<at::Half>(),
#endif
height, width, batch, vec_height);
}
#undef BLOCKWIDTH
#undef BLOCKHEIGHT4
/*
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/reduce_kernel_utils.cuh
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_compat.h"
namespace vllm {
namespace detail {
template <typename T>
__inline__ __device__ T _max(T a, T b) {
return max(a, b);
}
template <typename T>
__inline__ __device__ T _sum(T a, T b) {
return a + b;
}
} // namespace detail
template <typename T>
using ReduceFnType = T (*)(T, T);
// Helper function to return the next largest power of 2
static constexpr int _nextPow2(unsigned int num) {
if (num <= 1) return num;
return 1 << (CHAR_BIT * sizeof(num) - __builtin_clz(num - 1));
}
template <typename T, int numLanes = WARP_SIZE>
__inline__ __device__ T warpReduce(T val, ReduceFnType<T> fn) {
static_assert(numLanes > 0 && (numLanes & (numLanes - 1)) == 0,
"numLanes is not a positive power of 2!");
static_assert(numLanes <= WARP_SIZE);
#pragma unroll
for (int mask = numLanes >> 1; mask > 0; mask >>= 1)
val = fn(val, VLLM_SHFL_XOR_SYNC(val, mask));
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduce(T val, ReduceFnType<T> fn) {
static_assert(maxBlockSize <= 1024);
if constexpr (maxBlockSize > WARP_SIZE) {
val = warpReduce<T>(val, fn);
// Calculates max number of lanes that need to participate in the last
// warpReduce
constexpr int maxActiveLanes = (maxBlockSize + WARP_SIZE - 1) / WARP_SIZE;
static __shared__ T shared[maxActiveLanes];
int lane = threadIdx.x % WARP_SIZE;
int wid = threadIdx.x / WARP_SIZE;
if (lane == 0) shared[wid] = val;
__syncthreads();
val = (threadIdx.x < blockDim.x / float(WARP_SIZE)) ? shared[lane]
: (T)(0.0f);
val = warpReduce<T, _nextPow2(maxActiveLanes)>(val, fn);
} else {
// A single warpReduce is equal to blockReduce
val = warpReduce<T, _nextPow2(maxBlockSize)>(val, fn);
}
return val;
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceMax(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_max<T>);
}
template <typename T, int maxBlockSize = 1024>
__inline__ __device__ T blockReduceSum(T val) {
return blockReduce<T, maxBlockSize>(val, detail::_sum<T>);
}
} // namespace vllm
/*
* Copyright (c) 2024, The vLLM team.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <hip/hip_bf16.h>
#include "cuda_compat.h"
#include <algorithm>
#include "../attention/dtype_fp8.cuh"
#include "../quantization/fp8/amd/quant_utils.cuh"
#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300_MI250__
#endif
#if defined(NDEBUG)
#undef NDEBUG
#include <assert.h>
#define UNREACHABLE_CODE assert(false);
#define NDEBUG
#else
#define UNREACHABLE_CODE assert(false);
#endif
#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#if defined(__HIP__MI300_MI250__) // TODO: Add NAVI support
#define GCN_MFMA_INSTR1 __builtin_amdgcn_mfma_f32_16x16x4f32
#define GCN_MFMA_INSTR __builtin_amdgcn_mfma_f32_4x4x4f16
using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
using float16x4 =
__attribute__((__vector_size__(4 * sizeof(_Float16)))) _Float16;
typedef float16x4 _Half4;
typedef struct _Half8 {
_Half4 xy[2];
} _Half8;
using bit16_t = uint16_t;
using bit16x4 = __attribute__((__vector_size__(4 * sizeof(uint16_t)))) uint16_t;
typedef bit16x4 _B16x4;
typedef struct _B16x8 {
_B16x4 xy[2];
} _B16x8;
using _B8x8 = uint2;
////// Non temporal load stores ///////
template <typename T>
__device__ __forceinline__ T load(T* addr) {
return addr[0];
}
template <typename T>
__device__ __forceinline__ void store(T value, T* addr) {
addr[0] = value;
}
template <typename T, int absz, int cbid, int blgp>
__device__ __forceinline__ floatx4 gcn_mfma_instr(const _B16x4& inpA,
const _B16x4& inpB,
const floatx4& inpC) {
if constexpr (std::is_same<T, _Float16>::value) {
return __builtin_amdgcn_mfma_f32_4x4x4f16(inpA, inpB, inpC, absz, cbid,
blgp);
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(inpA, inpB, inpC, absz, cbid,
blgp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ float to_float(const T& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
return (float)inp;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __bfloat162float(inp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ T from_float(const float& inp) {
if constexpr (std::is_same<T, _Float16>::value) {
return (_Float16)inp;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
return __float2bfloat16(inp);
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ _B16x4 from_floatx4(const floatx4& inp) {
union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
} t16;
_B16x4 ret;
if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t16.f = (_Float16)inp[i];
ret[i] = t16.u;
}
return ret;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t16.b = __float2bfloat16(inp[i]);
ret[i] = t16.u;
}
return ret;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T>
__device__ __forceinline__ _B16x4 addx4(const _B16x4& inp1,
const _B16x4& inp2) {
union tmpcvt {
uint16_t u;
_Float16 f;
__hip_bfloat16 b;
} t1, t2, res;
_B16x4 ret;
if constexpr (std::is_same<T, _Float16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t1.u = inp1[i];
t2.u = inp2[i];
res.f = t1.f + t2.f;
ret[i] = res.u;
}
return ret;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
#pragma unroll
for (int i = 0; i < 4; i++) {
t1.u = inp1[i];
t2.u = inp2[i];
res.b = t1.b + t2.b;
ret[i] = res.u;
}
return ret;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
template <typename T, vllm::Fp8KVCacheDataType KV_DTYPE>
__device__ __forceinline__ _B16x8 scaled_convert_b8x8(const _B8x8 input,
const float scale) {
union alignas(16) {
uint4 u4;
_B16x8 u16x8;
vllm::bf16_8_t b16x8;
} tmp;
if constexpr (std::is_same<T, _Float16>::value) {
tmp.u4 = vllm::fp8::scaled_convert<uint4, _B8x8, KV_DTYPE>(input, scale);
return tmp.u16x8;
} else if constexpr (std::is_same<T, __hip_bfloat16>::value) {
tmp.b16x8 = vllm::fp8::scaled_convert<vllm::bf16_8_t, _B8x8, KV_DTYPE>(
input, scale);
return tmp.u16x8;
} else {
static_assert(false, "unsupported 16b dtype");
}
}
///////////////////////////////////////
// grid (num_seqs, num_partitions,num_heads/gqa_ratio)
// block (partition size)
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
constexpr int NWARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
const int lane4id = laneid % 4;
const int seq_idx = blockIdx.x;
const int partition_idx = blockIdx.y;
const int partition_size = blockDim.x;
const int max_num_partitions = gridDim.y;
const int context_len = context_lens[seq_idx];
const int partition_start_token_idx = partition_idx * partition_size;
// exit if partition is out of context for seq
if (partition_start_token_idx >= context_len) {
return;
}
constexpr int QHLOOP =
DIVIDE_ROUND_UP(GQA_RATIO, 4); // each 4 lanes fetch 4 different qheads,
// total qheads =8, so qhloop is 2
constexpr int GQA_RATIO4 = 4 * QHLOOP;
__shared__ float shared_qk_max[NWARPS][GQA_RATIO4 + 1];
__shared__ float shared_exp_sum[NWARPS][GQA_RATIO4 + 1];
_B16x8 Qlocal[QHLOOP];
constexpr int x = 16 / sizeof(scalar_t);
constexpr int KHELOOP = HEAD_SIZE / x;
_B16x8 Klocal[KHELOOP];
_B8x8 Klocalb8[KHELOOP];
constexpr int VHELOOP =
HEAD_SIZE /
WARP_SIZE; // v head_size dimension is distributed across lanes
constexpr int VTLOOP = 8; // 16 separate 4xtokens across warp -> 16/2
// 8xtokens
_B16x8 Vlocal[VHELOOP][VTLOOP];
_B8x8 Vlocalb8[VHELOOP][VTLOOP];
floatx4 dout[QHLOOP];
float qk_max[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
dout[h] = {0};
qk_max[h] = -FLT_MAX;
}
const int wg_start_head_idx = blockIdx.z * GQA_RATIO;
const int wg_start_kv_head_idx = blockIdx.z;
const int warp_start_token_idx =
partition_start_token_idx + warpid * WARP_SIZE;
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int h = 0; h < GQA_RATIO4; h++) {
shared_qk_max[warpid][h] = -FLT_MAX;
shared_exp_sum[warpid][h] = 0.0f;
}
} else { // warp within context
const int num_context_blocks = DIVIDE_ROUND_UP(context_len, BLOCK_SIZE);
const int last_ctx_block = num_context_blocks - 1;
const int* block_table = block_tables + seq_idx * max_num_blocks_per_seq;
const int local_token_idx = threadIdx.x;
const int global_token_idx = partition_start_token_idx + local_token_idx;
const int block_idx = (global_token_idx < context_len)
? global_token_idx / BLOCK_SIZE
: last_ctx_block;
// fetch block number for q and k
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t physical_block_number =
static_cast<int64_t>(block_table[block_idx]);
// fetch vphysical block numbers up front
constexpr int VBLOCKS = 8 * VTLOOP / BLOCK_SIZE;
int vphysical_blocks[VBLOCKS];
const int warp_start_block_idx = warp_start_token_idx / BLOCK_SIZE;
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
const int vblock_idx = warp_start_block_idx + b;
const int vblock_idx_ctx =
(vblock_idx <= last_ctx_block) ? vblock_idx : last_ctx_block;
vphysical_blocks[b] = block_table[vblock_idx_ctx];
}
// each 4 lanes fetch 8 helems, so warp fetches 8*16 = 128 helems
const scalar_t* q_ptr =
q + seq_idx * q_stride + wg_start_head_idx * HEAD_SIZE;
const _B16x8* q_ptrh8 = reinterpret_cast<const _B16x8*>(q_ptr);
const int qhead_elemh8 = laneid / 4;
#pragma unroll
for (int h = 0; h < QHLOOP - 1; h++) {
const int qhead_idx = h * 4 + lane4id;
Qlocal[h] = q_ptrh8[qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
}
const int final_qhead_idx = 4 * (QHLOOP - 1) + lane4id;
if (final_qhead_idx < GQA_RATIO) {
Qlocal[QHLOOP - 1] =
q_ptrh8[final_qhead_idx * HEAD_SIZE / 8 + qhead_elemh8];
} else {
Qlocal[QHLOOP - 1].xy[0] = {0};
Qlocal[QHLOOP - 1].xy[1] = {0};
}
const cache_t* k_ptr = k_cache + physical_block_number * kv_block_stride +
wg_start_kv_head_idx * kv_head_stride;
const int physical_block_offset =
local_token_idx % BLOCK_SIZE; // since x=half8, physical_block_offset
// is already cast as _H8
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* k_ptrh8 = reinterpret_cast<const _B16x8*>(k_ptr);
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] = k_ptrh8[d * BLOCK_SIZE + physical_block_offset];
}
} else {
constexpr int X = 16 / sizeof(cache_t);
const cache_t* k_ptr2 = k_ptr + physical_block_offset * X;
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
const int head_elem = d * 8;
const int offset1 = head_elem / X;
const int offset2 = head_elem % X;
const cache_t* k_ptr3 = k_ptr2 + offset1 * BLOCK_SIZE * X + offset2;
Klocalb8[d] = *reinterpret_cast<const _B8x8*>(k_ptr3);
}
}
float alibi_slope[QHLOOP];
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int qhead_idx = h * 4 + lane4id;
alibi_slope[h] = (qhead_idx < GQA_RATIO)
? alibi_slopes[wg_start_head_idx + qhead_idx]
: 0.f;
}
}
const cache_t* v_ptr = v_cache + wg_start_kv_head_idx * kv_head_stride;
if constexpr (KV_DTYPE == vllm::Fp8KVCacheDataType::kAuto) {
const _B16x8* v_ptrh8 = reinterpret_cast<const _B16x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number =
static_cast<int64_t>(vphysical_blocks[b]);
const _B16x8* v_ptrh8b =
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for (int h = 0; h < VHELOOP; h++) {
const int head_size_elem = h * WARP_SIZE + laneid;
const _B16x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
Vlocal[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
}
}
}
} else {
const _B8x8* v_ptrh8 = reinterpret_cast<const _B8x8*>(v_ptr);
// iterate over each v block
#pragma unroll
for (int b = 0; b < VBLOCKS; b++) {
// int32 physical_block_number leads to overflow when multiplied with
// kv_block_stride
const int64_t vphysical_block_number =
static_cast<int64_t>(vphysical_blocks[b]);
const _B8x8* v_ptrh8b =
v_ptrh8 + (vphysical_block_number * kv_block_stride) / 8;
// iterate over each head elem (within head_size)
#pragma unroll
for (int h = 0; h < VHELOOP; h++) {
const int head_size_elem = h * WARP_SIZE + laneid;
const _B8x8* v_ptrh8be = v_ptrh8b + head_size_elem * BLOCK_SIZE / 8;
// iterate over all velems within block
#pragma unroll
for (int d = 0; d < BLOCK_SIZE / 8; d++) {
// Vlocalb8[h][b * BLOCK_SIZE / 8 + d] = v_ptrh8be[d];
const _B8x8 Vlocalb8 = v_ptrh8be[d];
Vlocal[h][b * BLOCK_SIZE / 8 + d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Vlocalb8, v_scale);
}
}
}
}
if constexpr (KV_DTYPE != vllm::Fp8KVCacheDataType::kAuto) {
#pragma unroll
for (int d = 0; d < KHELOOP; d++) {
Klocal[d] =
scaled_convert_b8x8<scalar_t, KV_DTYPE>(Klocalb8[d], k_scale);
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[0],
Klocal[0].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 0, 0>(Qlocal[h].xy[1],
Klocal[0].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[0],
Klocal[1].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 1, 0>(Qlocal[h].xy[1],
Klocal[1].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[0],
Klocal[2].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 2, 0>(Qlocal[h].xy[1],
Klocal[2].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[0],
Klocal[3].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 3, 0>(Qlocal[h].xy[1],
Klocal[3].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[0],
Klocal[4].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 4, 0>(Qlocal[h].xy[1],
Klocal[4].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[0],
Klocal[5].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 5, 0>(Qlocal[h].xy[1],
Klocal[5].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[0],
Klocal[6].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 6, 0>(Qlocal[h].xy[1],
Klocal[6].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[0],
Klocal[7].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 7, 0>(Qlocal[h].xy[1],
Klocal[7].xy[1], dout[h]);
if constexpr (KHELOOP > 8) {
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[0],
Klocal[8].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 8, 0>(Qlocal[h].xy[1],
Klocal[8].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[0],
Klocal[9].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 9, 0>(Qlocal[h].xy[1],
Klocal[9].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[0],
Klocal[10].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 10, 0>(Qlocal[h].xy[1],
Klocal[10].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[0],
Klocal[11].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 11, 0>(Qlocal[h].xy[1],
Klocal[11].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[0],
Klocal[12].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 12, 0>(Qlocal[h].xy[1],
Klocal[12].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[0],
Klocal[13].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 13, 0>(Qlocal[h].xy[1],
Klocal[13].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[0],
Klocal[14].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 14, 0>(Qlocal[h].xy[1],
Klocal[14].xy[1], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[0],
Klocal[15].xy[0], dout[h]);
dout[h] = gcn_mfma_instr<scalar_t, 4, 15, 0>(Qlocal[h].xy[1],
Klocal[15].xy[1], dout[h]);
} // KHELOOP>8
dout[h] *= scale;
}
// transpose dout so that 4 token ids are in each lane, and 4 heads are across
// 4 lanes
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
floatx4 tmp = {0};
#pragma unroll
for (int i = 0; i < 4; i++) {
const float B = (lane4id == i) ? 1.0f : 0.0f;
// const float A = (global_token_idx < context_len) ? dout[h][i] : 0.0f;
tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(dout[h][i], B, tmp, 0, 0, 0);
// tmp = __builtin_amdgcn_mfma_f32_4x4x1f32(A, B, tmp, 0, 0, 0);
}
dout[h] = tmp;
}
const int lane4_token_idx = 4 * (global_token_idx >> 2);
const int alibi_offset = lane4_token_idx - context_len + 1;
if (alibi_slopes != nullptr) {
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] += alibi_slope[h] * (alibi_offset + i);
}
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
qk_max[h] = -FLT_MAX;
#pragma unroll
for (int i = 0; i < 4; i++) {
qk_max[h] = (lane4_token_idx + i < context_len)
? fmaxf(qk_max[h], dout[h][i])
: qk_max[h];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
qk_max[h] = fmaxf(qk_max[h], __shfl_xor(qk_max[h], mask));
}
}
float exp_sum[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
exp_sum[h] = 0.0f;
#pragma unroll
for (int i = 0; i < 4; i++) {
dout[h][i] = (lane4_token_idx + i < context_len)
? __expf(dout[h][i] - qk_max[h])
: 0.0f;
exp_sum[h] += dout[h][i];
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 4; mask /= 2) {
exp_sum[h] += __shfl_xor(exp_sum[h], mask);
}
}
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
const int head_idx = 4 * h + lane4id;
shared_qk_max[warpid][head_idx] = qk_max[h];
shared_exp_sum[warpid][head_idx] = exp_sum[h];
}
} // warp within context
__syncthreads();
const int num_heads = gridDim.z * GQA_RATIO;
float* max_logits_ptr =
max_logits + seq_idx * num_heads * max_num_partitions + partition_idx;
float* exp_sums_ptr =
exp_sums + seq_idx * num_heads * max_num_partitions + partition_idx;
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
float global_qk_max = -FLT_MAX;
float warp_qk_max[NWARPS];
const int head_idx = 4 * h + lane4id;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
warp_qk_max[w] = shared_qk_max[w][head_idx];
global_qk_max = fmaxf(global_qk_max, warp_qk_max[w]);
}
float global_exp_sum = 0.0f;
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
global_exp_sum +=
shared_exp_sum[w][head_idx] * __expf(warp_qk_max[w] - global_qk_max);
}
if (head_idx < GQA_RATIO) {
max_logits_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_qk_max;
exp_sums_ptr[(wg_start_head_idx + head_idx) * max_num_partitions] =
global_exp_sum;
}
const float global_inv_sum_scale = __fdividef(1.f, global_exp_sum + 1e-6f) *
__expf(qk_max[h] - global_qk_max);
dout[h] *= global_inv_sum_scale;
}
// logits[h] -> every 4 lanes hold 4 heads, each lane holds 4 tokens, there
// are 4x16 tokens across warp
_B16x4 logits[QHLOOP];
#pragma unroll
for (int h = 0; h < QHLOOP; h++) {
logits[h] = from_floatx4<scalar_t>(dout[h]);
}
__shared__ _B16x4 vout_shared[QHLOOP][VHELOOP][WARP_SIZE][NWARPS + 1];
if (warp_start_token_idx >= context_len) { // warp out of context
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout_shared[qh][vh][laneid][warpid] = {0};
}
}
} else { // warp in context
// iterate across heads
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
floatx4 acc = {0};
// iterate over tokens
acc = gcn_mfma_instr<scalar_t, 4, 0, 0>(logits[qh], Vlocal[vh][0].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 1, 0>(logits[qh], Vlocal[vh][0].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 2, 0>(logits[qh], Vlocal[vh][1].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 3, 0>(logits[qh], Vlocal[vh][1].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 4, 0>(logits[qh], Vlocal[vh][2].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 5, 0>(logits[qh], Vlocal[vh][2].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 6, 0>(logits[qh], Vlocal[vh][3].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 7, 0>(logits[qh], Vlocal[vh][3].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 8, 0>(logits[qh], Vlocal[vh][4].xy[0],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 9, 0>(logits[qh], Vlocal[vh][4].xy[1],
acc);
acc = gcn_mfma_instr<scalar_t, 4, 10, 0>(logits[qh],
Vlocal[vh][5].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 11, 0>(logits[qh],
Vlocal[vh][5].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 12, 0>(logits[qh],
Vlocal[vh][6].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 13, 0>(logits[qh],
Vlocal[vh][6].xy[1], acc);
acc = gcn_mfma_instr<scalar_t, 4, 14, 0>(logits[qh],
Vlocal[vh][7].xy[0], acc);
acc = gcn_mfma_instr<scalar_t, 4, 15, 0>(logits[qh],
Vlocal[vh][7].xy[1], acc);
vout_shared[qh][vh][laneid][warpid] = from_floatx4<scalar_t>(acc);
}
}
} // warp in context
__syncthreads();
if (warpid == 0) {
_B16x4 vout[QHLOOP][VHELOOP];
// iterate across heads
scalar_t* out_ptr;
int out_num_partitions;
if (context_len > partition_size) {
out_num_partitions = max_num_partitions;
out_ptr = out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
partition_idx * HEAD_SIZE;
} else {
out_num_partitions = 1;
out_ptr = final_out + seq_idx * num_heads * HEAD_SIZE;
}
#pragma unroll
for (int qh = 0; qh < QHLOOP; qh++) {
// iterate over each v head elem (within head_size)
#pragma unroll
for (int vh = 0; vh < VHELOOP; vh++) {
vout[qh][vh] = {0};
#pragma unroll
for (int w = 0; w < NWARPS; w++) {
vout[qh][vh] =
addx4<scalar_t>(vout[qh][vh], vout_shared[qh][vh][laneid][w]);
}
const int head_size_elem = vh * WARP_SIZE + laneid;
bit16_t* out_ptr_b16 = reinterpret_cast<bit16_t*>(out_ptr);
#pragma unroll
for (int i = 0; i < 4; i++) {
const int head_idx = 4 * qh + i;
if (head_idx < GQA_RATIO) {
out_ptr_b16[(wg_start_head_idx + head_idx) * out_num_partitions *
HEAD_SIZE +
head_size_elem] = vout[qh][vh][i];
}
}
}
}
}
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions) {
const int num_heads = gridDim.x;
const int head_idx = blockIdx.x;
const int seq_idx = blockIdx.y;
const int context_len = context_lens[seq_idx];
const int num_partitions = DIVIDE_ROUND_UP(context_len, PARTITION_SIZE);
if (num_partitions == 1) {
// if num_partitions==1, main kernel will write to out directly, no work in
// reduction kernel
return;
}
constexpr int NUM_WARPS = NUM_THREADS / WARP_SIZE;
const int warpid = threadIdx.x / WARP_SIZE;
const int laneid = threadIdx.x % WARP_SIZE;
__shared__ float shared_global_exp_sum;
__shared__ float shared_exp_sums[2 * WARP_SIZE];
if (warpid == 0) {
const float* max_logits_ptr = max_logits +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
// valid partition is the last valid partition in case threadid > num
// partitions
const int valid_partition =
(threadIdx.x < num_partitions) ? threadIdx.x : num_partitions - 1;
const int valid_partition2 = (WARP_SIZE + threadIdx.x < num_partitions)
? WARP_SIZE + threadIdx.x
: num_partitions - 1;
float reg_max_logit = max_logits_ptr[valid_partition];
float reg_max_logit2 = max_logits_ptr[valid_partition2];
float max_logit = fmaxf(reg_max_logit, reg_max_logit2);
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
max_logit = fmaxf(max_logit, __shfl_xor(max_logit, mask));
}
const float* exp_sums_ptr = exp_sums +
seq_idx * num_heads * max_num_partitions +
head_idx * max_num_partitions;
float global_exp_sum = 0.0f;
float rescaled_exp_sum = exp_sums_ptr[valid_partition];
float rescaled_exp_sum2 = exp_sums_ptr[valid_partition2];
rescaled_exp_sum *=
(threadIdx.x < num_partitions) ? expf(reg_max_logit - max_logit) : 0.0f;
rescaled_exp_sum2 *= (threadIdx.x + WARP_SIZE < num_partitions)
? expf(reg_max_logit2 - max_logit)
: 0.0f;
global_exp_sum += rescaled_exp_sum + rescaled_exp_sum2;
shared_exp_sums[threadIdx.x] = rescaled_exp_sum;
shared_exp_sums[threadIdx.x + WARP_SIZE] = rescaled_exp_sum2;
#pragma unroll
for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
global_exp_sum += __shfl_xor(global_exp_sum, mask);
}
if (threadIdx.x == 0) {
shared_global_exp_sum = global_exp_sum;
}
} // warpid == 0
const scalar_t* tmp_out_ptr =
tmp_out + seq_idx * num_heads * max_num_partitions * HEAD_SIZE +
head_idx * max_num_partitions * HEAD_SIZE + threadIdx.x;
constexpr int MAX_NPAR = 64;
scalar_t tmps[MAX_NPAR];
const float dzero = 0.0f;
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
tmps[j] = from_float<scalar_t>(dzero);
}
const int last_partition_offset = (num_partitions - 1) * HEAD_SIZE;
const int num_partition_offset = (num_partitions)*HEAD_SIZE;
int idx = 0;
constexpr int JCHUNK = 16;
#pragma unroll
for (int j = 0; j < JCHUNK * HEAD_SIZE; j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
__syncthreads();
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK * HEAD_SIZE; j < 2 * JCHUNK * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK * HEAD_SIZE; j < MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
}
} // num_partitions > JCHUNK
// Aggregate tmp_out to out.
float acc = 0.0f;
#pragma unroll
for (int j = 0; j < JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > JCHUNK) {
#pragma unroll
for (int j = JCHUNK; j < 2 * JCHUNK; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
if (num_partitions > 2 * JCHUNK) {
#pragma unroll
for (int j = 2 * JCHUNK; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j];
}
}
}
if (num_partitions > MAX_NPAR) {
idx = 0;
#pragma unroll
for (int j = MAX_NPAR * HEAD_SIZE; j < 2 * MAX_NPAR * HEAD_SIZE;
j += HEAD_SIZE) {
// lastj is last valid partition
const int lastj_offset =
(j < num_partition_offset) ? j : last_partition_offset;
tmps[idx] = tmp_out_ptr[lastj_offset];
idx++;
}
#pragma unroll
for (int j = 0; j < MAX_NPAR; j++) {
acc += to_float<scalar_t>(tmps[j]) * shared_exp_sums[j + MAX_NPAR];
}
}
const float inv_global_exp_sum =
__fdividef(1.0f, shared_global_exp_sum + 1e-6f);
acc *= inv_global_exp_sum;
scalar_t* out_ptr =
out + seq_idx * num_heads * HEAD_SIZE + head_idx * HEAD_SIZE;
out_ptr[threadIdx.x] = from_float<scalar_t>(acc);
}
#else // !defined(__HIP__MI300_MI250__) TODO: Add NAVI support
template <typename scalar_t, typename cache_t,
vllm::Fp8KVCacheDataType KV_DTYPE, int BLOCK_SIZE, int HEAD_SIZE,
int NUM_THREADS,
int GQA_RATIO>
__global__ __launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_QKV_kernel(
const scalar_t* __restrict__ q, // [num_seqs, num_heads, head_size]
const cache_t* __restrict__ k_cache, // [num_blocks, num_kv_heads,
// head_size/x, block_size, x]
const cache_t* __restrict__ v_cache, // [num_blocks, num_kv_heads,
// head_size, block_size]
const int num_kv_heads, const float scale,
const int* __restrict__ block_tables, // [num_seqs, max_num_blocks_per_seq]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const int q_stride, const int kv_block_stride, const int kv_head_stride,
float* __restrict__ exp_sums, // [num_seqs, num_heads, max_num_partitions]
float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
scalar_t* __restrict__ out, // [num_seqs, num_heads, max_num_partitions,
// head_size]
scalar_t* __restrict__ final_out, // [num_seqs, num_heads, head_size]
int max_ctx_blocks, float k_scale, float v_scale) {
UNREACHABLE_CODE
}
// Grid: (num_heads, num_seqs).
template <typename scalar_t, int HEAD_SIZE, int NUM_THREADS,
int PARTITION_SIZE>
__global__
__launch_bounds__(NUM_THREADS) void paged_attention_ll4mi_reduce_kernel(
scalar_t* __restrict__ out, // [num_seqs, num_heads, head_size]
const float* __restrict__ exp_sums, // [num_seqs, num_heads,
// max_num_partitions]
const float* __restrict__ max_logits, // [num_seqs, num_heads,
// max_num_partitions]
const scalar_t* __restrict__ tmp_out, // [num_seqs, num_heads,
// max_num_partitions, head_size]
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_partitions){UNREACHABLE_CODE}
#endif // defined(__HIP__MI300_MI250__) TODO: Add NAVI support
#define LAUNCH_CUSTOM_ATTENTION(GQA_RATIO) \
paged_attention_ll4mi_QKV_kernel<T, KVT, KV_DTYPE, BLOCK_SIZE, HEAD_SIZE, \
NTHR, GQA_RATIO> \
<<<grid, block, 0, stream>>>( \
query_ptr, key_cache_ptr, value_cache_ptr, num_kv_heads, scale, \
block_tables_ptr, context_lens_ptr, max_num_blocks_per_seq, \
alibi_slopes_ptr, q_stride, kv_block_stride, kv_head_stride, \
exp_sums_ptr, max_logits_ptr, tmp_out_ptr, out_ptr, max_ctx_blocks, \
k_scale, v_scale);
template <typename T, typename KVT, vllm::Fp8KVCacheDataType KV_DTYPE,
int BLOCK_SIZE, int HEAD_SIZE, int PARTITION_SIZE = 512>
void paged_attention_custom_launcher(
torch::Tensor& out, torch::Tensor& exp_sums, torch::Tensor& max_logits,
torch::Tensor& tmp_out, torch::Tensor& query, torch::Tensor& key_cache,
torch::Tensor& value_cache, const int num_kv_heads, float scale,
torch::Tensor& block_tables, torch::Tensor& context_lens,
int max_context_len, const c10::optional<torch::Tensor>& alibi_slopes,
float k_scale, float v_scale) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
int max_num_blocks_per_seq = block_tables.size(1);
int q_stride = query.stride(0);
int kv_block_stride = key_cache.stride(0);
int kv_head_stride = key_cache.stride(1);
// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr =
alibi_slopes
? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
float* max_logits_ptr = reinterpret_cast<float*>(max_logits.data_ptr());
T* tmp_out_ptr = reinterpret_cast<T*>(tmp_out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
KVT* key_cache_ptr = reinterpret_cast<KVT*>(key_cache.data_ptr());
KVT* value_cache_ptr = reinterpret_cast<KVT*>(value_cache.data_ptr());
int* block_tables_ptr = block_tables.data_ptr<int>();
int* context_lens_ptr = context_lens.data_ptr<int>();
const int max_ctx_blocks = DIVIDE_ROUND_UP(max_context_len, BLOCK_SIZE);
const int max_num_partitions =
DIVIDE_ROUND_UP(max_context_len, PARTITION_SIZE);
const int gqa_ratio = num_heads / num_kv_heads;
assert(num_heads % num_kv_heads == 0);
assert(head_size == HEAD_SIZE);
assert(max_num_partitions <= 128);
constexpr int NTHR = PARTITION_SIZE;
dim3 grid(num_seqs, max_num_partitions, num_kv_heads);
dim3 block(NTHR);
const at::cuda::OptionalCUDAGuard device_guard(device_of(query));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
switch (gqa_ratio) {
case 1:
LAUNCH_CUSTOM_ATTENTION(1);
break;
case 2:
LAUNCH_CUSTOM_ATTENTION(2);
break;
case 3:
LAUNCH_CUSTOM_ATTENTION(3);
break;
case 4:
LAUNCH_CUSTOM_ATTENTION(4);
break;
case 5:
LAUNCH_CUSTOM_ATTENTION(5);
break;
case 6:
LAUNCH_CUSTOM_ATTENTION(6);
break;
case 7:
LAUNCH_CUSTOM_ATTENTION(7);
break;
case 8:
LAUNCH_CUSTOM_ATTENTION(8);
break;
case 9:
LAUNCH_CUSTOM_ATTENTION(9);
break;
case 10:
LAUNCH_CUSTOM_ATTENTION(10);
break;
case 11:
LAUNCH_CUSTOM_ATTENTION(11);
break;
case 12:
LAUNCH_CUSTOM_ATTENTION(12);
break;
case 13:
LAUNCH_CUSTOM_ATTENTION(13);
break;
case 14:
LAUNCH_CUSTOM_ATTENTION(14);
break;
case 15:
LAUNCH_CUSTOM_ATTENTION(15);
break;
case 16:
LAUNCH_CUSTOM_ATTENTION(16);
break;
default:
TORCH_CHECK(false, "Unsupported gqa ratio: ", gqa_ratio);
break;
}
// dim3 grid2(num_heads,num_seqs,head_size/HEAD_ELEMS_PER_WG);
// dim3 block2(1024);
// LAUNCH_CUSTOM_ATTENTION2;
// reduction kernel is only required if max_context_len > partition size,
// otherwise main kernel writes directly to final output
// note there are cases with graphing where max_context_len is the max
// supported by graphing, not the actual max among all the sequences: in that
// case reduction kernel will still run but return immediately
if (max_context_len > PARTITION_SIZE) {
dim3 reduce_grid(num_heads, num_seqs);
dim3 reduce_block(head_size);
paged_attention_ll4mi_reduce_kernel<T, HEAD_SIZE, HEAD_SIZE, PARTITION_SIZE>
<<<reduce_grid, reduce_block, 0, stream>>>(
out_ptr, exp_sums_ptr, max_logits_ptr, tmp_out_ptr,
context_lens_ptr, max_num_partitions);
}
}
#define CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE) \
paged_attention_custom_launcher<T, KVT, KV_DTYPE, BLK_SIZE, HEAD_SIZE>( \
out, exp_sums, max_logits, tmp_out, query, key_cache, value_cache, \
num_kv_heads, scale, block_tables, context_lens, max_context_len, \
alibi_slopes, k_scale, v_scale);
#define CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, HEAD_SIZE) \
switch (block_size) { \
case 16: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 16, HEAD_SIZE); \
break; \
case 32: \
CALL_CUSTOM_LAUNCHER(T, KVT, KV_DTYPE, 32, HEAD_SIZE); \
break; \
default: \
TORCH_CHECK(false, "Unsupported block size: ", block_size); \
break; \
}
#define CALL_CUSTOM_LAUNCHER_BLK_HEAD(T, KVT, KV_DTYPE) \
switch (head_size) { \
case 64: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 64); \
break; \
case 128: \
CALL_CUSTOM_LAUNCHER_BLK(T, KVT, KV_DTYPE, 128); \
break; \
default: \
TORCH_CHECK(false, "Unsupported head size: ", head_size); \
break; \
}
void paged_attention(
torch::Tensor& out, // [num_seqs, num_heads, head_size]
torch::Tensor& exp_sums, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor& max_logits, // [num_seqs, num_heads, max_num_partitions]
torch::Tensor&
tmp_out, // [num_seqs, num_heads, max_num_partitions, head_size]
torch::Tensor& query, // [num_seqs, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, head_size/x, block_size, x]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
int64_t num_kv_heads, double scale,
torch::Tensor& block_tables, // [num_seqs, max_num_blocks_per_seq]
torch::Tensor& context_lens, // [num_seqs]
int64_t block_size, int64_t max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const std::string& kv_cache_dtype, double k_scale, double v_scale) {
const int head_size = query.size(2);
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, _Float16,
vllm::Fp8KVCacheDataType::kAuto);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, __hip_bfloat16,
vllm::Fp8KVCacheDataType::kAuto);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else if (kv_cache_dtype == "fp8" || kv_cache_dtype == "fp8_e4m3") {
if (query.dtype() == at::ScalarType::Half) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(_Float16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else if (query.dtype() == at::ScalarType::BFloat16) {
CALL_CUSTOM_LAUNCHER_BLK_HEAD(__hip_bfloat16, uint8_t,
vllm::Fp8KVCacheDataType::kFp8E4M3);
} else {
TORCH_CHECK(false, "Unsupported data type: ", query.dtype());
}
} else {
TORCH_CHECK(false, "Unsupported KV cache dtype: ", kv_cache_dtype);
}
}
#undef WARP_SIZE
#undef MAX
#undef MIN
#undef DIVIDE_ROUND_UP
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment